"""
Byte pair encoding utilities

Code from https://github.com/openai/gpt-2/blob/master/src/encoder.py

Modified MIT License

Software Copyright (c) 2019 OpenAI

We don't claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
The above copyright notice and this permission notice need not be included
with content created by the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.
"""

import os
import json
import regex as re
from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
    """
    Returns a mapping of utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
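
# Illustration (added comment, not in the original source): printable bytes map to
# themselves, e.g. bytes_to_unicode()[ord('a')] == 'a', while whitespace/control
# bytes are shifted into higher code points, e.g. the space byte 32 maps to 'Ġ'
# (chr(288)), which is why GPT-2 vocab entries contain that character.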

def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
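
# Example (added comment): get_pairs(('h', 'e', 'l', 'l', 'o')) returns
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}, the set of adjacent symbol pairs.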

class Encoder:
    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
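        # Example (added comment): re.findall(self.pat, "Hello world!!") yields
        # ['Hello', ' world', '!!'] -- words keep their leading space and runs of
        # punctuation are split off as separate pre-tokens before BPE is applied.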

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # merge the lowest-ranked (earliest-learned) pair first
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
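
    # Illustration (added comment; actual merges depend on the learned bpe_ranks):
    # for a hypothetical token 'lower' whose ranked merges include ('l', 'o') and
    # then ('lo', 'w'), bpe('lower') would proceed 'l o w e r' -> 'lo w e r'
    # -> 'low e r' and return the space-joined subwords 'low e r'.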

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # map raw bytes to unicode proxy characters, then BPE-merge and look up token ids
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        # invert the token id lookup, then undo the byte-to-unicode mapping
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

def get_encoder(model_name, models_dir):
    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
        encoder = json.load(f)
    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
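
# Minimal usage sketch (added for illustration; not part of the original file). It
# assumes the GPT-2 release files 'encoder.json' and 'vocab.bpe' have already been
# downloaded into models/124M, e.g. with the download script from the openai/gpt-2
# repository; adjust model_name/models_dir to match your local layout.
if __name__ == "__main__":
    enc = get_encoder(model_name="124M", models_dir="models")
    ids = enc.encode("Hello world!!")
    print(ids)              # a list of integer token ids
    print(enc.decode(ids))  # round-trips back to "Hello world!!"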