lambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl'
normalization = 'NFKC'

+
# Note: this task is called "lambada" but it really refers to OpenAI's version
# of the task, which actually differs in some ways from the task described in
# the original paper. So, strictly speaking, accuracy values from this task
@@ -29,31 +30,34 @@ def lambada_create_tokens_data(params, path):
        json.dump(arrays, f)
        return arrays

+
def lambada_read_or_create_tokens_data(params, path):
    # if you tell me where the file should go, i will helpfully create it for you
    if not os.path.exists(path):
        return lambada_create_tokens_data(params, path)
    with open(path) as f:
        return json.load(f)

+
def bin_pack(params, tokens_data):
    eos_token = params['eos_id']
    n_ctx = params['n_ctx']
    dummy_token = 1
    pad_batch_size = params['eval_batch_size']
    bins = []
    for a in tokens_data:
-        if len(bins) == 0 or len(bins[-1])+len(a)+1 > n_ctx:
+        if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
            bins.append([])
        bins[-1] += a
        bins[-1].append(eos_token)
-    while len(bins)%pad_batch_size != 0:
+    while len(bins) % pad_batch_size != 0:
        bins.append([])
    bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)
    for i, b in enumerate(bins):
        bins_array[i, 0:len(b)] = b
    return bins_array

+
def lambada_init(params):
    ds_configs = params['dataset_configs']
    l = [
@@ -67,45 +71,14 @@ def lambada_init(params):
    tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
    bins_array = bin_pack(params, tokens_data)
    params['lambada_tokens_path'] = lt_path
-    params['lambada_n_steps'] = len(bins_array)//params['eval_batch_size']
+    params['lambada_n_steps'] = len(bins_array) // params['eval_batch_size']
+

def lambada_get_task_info(params):
    return {
        'n_steps': params['lambada_n_steps'],
    }

-def wikitext_detokenizer(string):
-    # contractions
-    string = string.replace("s '", "s'")
-    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
-    # number separators
-    string = string.replace(" @-@ ", "-")
-    string = string.replace(" @,@ ", ",")
-    string = string.replace(" @.@ ", ".")
-    # punctuation
-    string = string.replace(" : ", ": ")
-    string = string.replace(" ; ", "; ")
-    string = string.replace(" . ", ". ")
-    string = string.replace(" ! ", "! ")
-    string = string.replace(" ? ", "? ")
-    string = string.replace(" , ", ", ")
-    # double brackets
-    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
-    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
-    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
-    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
-    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
-    # miscellaneous
-    string = string.replace("= = = =", "====")
-    string = string.replace("= = =", "===")
-    string = string.replace("= =", "==")
-    string = string.replace(" " + chr(176) + " ", chr(176))
-    string = string.replace(" \n ", "\n")
-    string = string.replace("\n ", "\n")
-    string = string.replace(" N ", " 1 ")
-    string = string.replace(" 's", "'s")
-
-    return string

# The LAMBADA evaluation code looks at the logits of each position just before an eos_token
def lambada_input(params):
@@ -115,80 +88,19 @@ def lambada_input(params):
    tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
    bins_array = bin_pack(params, tokens_data)
    dataset = tf.data.Dataset.from_tensor_slices(bins_array)
-    def _get_output(bin):
-        bin = tf.cast(bin, dtype=tf.int32)
-        indexes = tf.range(n_ctx)
-        results = tf.gather(bin, (indexes + 1)%n_ctx)
-        eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2)%n_ctx), eos_token)
-        output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
-        bin = tf.reshape(bin, [n_ctx])
-        bin = tf.cast(bin, dtype=tf.int32)
-        output = tf.reshape(output, [n_ctx])
-        output = tf.cast(output, dtype=tf.int32)
-        return bin, output
-    dataset = dataset.map(_get_output)
-    dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
-    dataset = dataset.repeat()
-    return dataset

-def wikitext_create_tokens_data(params, path, version):
-    assert version.lower() in ["wikitext2", "wikitext103"]
-    wikitext2_src = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip"
-    wikitext103_src = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip"
-    version_src = wikitext103_src if version.lower() == "wikitext103" else wikitext2_src
-    with open(path, 'w') as f:
-        wikitext_path = f"./{version}-raw-v1.zip"
-        os.system(f"wget {version_src} -O {wikitext_path}")
-        os.makedirs(f"{version}", exist_ok=True)
-        os.system(f"unzip {wikitext_path} -d {version}")
-        n = 103 if version.lower() == "wikitext103" else 2
-        with open(f"./{version}/wikitext-{n}-raw/wiki.test.raw", 'r') as wt:
-            text = ftfy.fix_text(wikitext_detokenizer(wt.read()))
-        enc = fetch_encoder(params)
-        encoded_text = encode(enc, text)
-        arrays = []
-        for i in range(0, len(encoded_text), params["n_ctx"] - 1):
-            arrays.append(encoded_text[i:i + params["n_ctx"] - 1])
-        json.dump(arrays, f)
-        return arrays
-
-def wikitext_read_or_create_tokens_data(params, path, version):
-    # if you tell me where the file should go, i will helpfully create it for you
-    if not os.path.exists(path):
-        return wikitext_create_tokens_data(params, path, version)
-    with open(path) as f:
-        return json.load(f)
-
-def wikitext_init(params, version):
-    wikitext_path = version + ".json"
-    tokens_data = wikitext_read_or_create_tokens_data(params, wikitext_path, version)
-    bins_array = bin_pack(params, tokens_data)
-    params['wikitext_path'] = wikitext_path
-    params['wikitext_n_steps'] = len(bins_array)//params['eval_batch_size']
-
-def wikitext_get_task_info(params, version):
-    return {
-        'n_steps': params['wikitext_n_steps'],
-    }
-
-def wikitext_input(params, version):
-    eos_token = 50256 if params['n_vocab'] >= 50257 else 0
-    n_ctx = params['n_ctx']
-    wt_path = params['wikitext_path']
-    tokens_data = wikitext_read_or_create_tokens_data(params, wt_path, version)
-    bins_array = bin_pack(params, tokens_data)
-    dataset = tf.data.Dataset.from_tensor_slices(bins_array)
    def _get_output(bin):
        bin = tf.cast(bin, dtype=tf.int32)
        indexes = tf.range(n_ctx)
-        results = tf.gather(bin, (indexes + 1) % n_ctx)
-        eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2) % n_ctx), eos_token)
+        results = tf.gather(bin, (indexes + 1) % n_ctx)
+        eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2) % n_ctx), eos_token)
        output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
        bin = tf.reshape(bin, [n_ctx])
        bin = tf.cast(bin, dtype=tf.int32)
        output = tf.reshape(output, [n_ctx])
        output = tf.cast(output, dtype=tf.int32)
        return bin, output
+
    dataset = dataset.map(_get_output)
    dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
    dataset = dataset.repeat()
@@ -200,15 +112,5 @@ def _get_output(bin):
        'init_fn': lambada_init,
        'get_task_info_fn': lambada_get_task_info,
        'input_fn': lambada_input,
-    },
-    'wikitext2': {
-        'init_fn': partial(wikitext_init, version='wikitext2'),
-        'get_task_info_fn': partial(wikitext_get_task_info, version='wikitext2'),
-        'input_fn': partial(wikitext_input, version='wikitext2'),
-    },
-    'wikitext103': {
-        'init_fn': partial(wikitext_init, version='wikitext103'),
-        'get_task_info_fn': partial(wikitext_get_task_info, version='wikitext103'),
-        'input_fn': partial(wikitext_input, version='wikitext103'),
-    },
+    }
}
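
As a reading aid, not part of this commit: the comment above lambada_input notes that the evaluation only scores the position just before each eos_token. Below is a minimal NumPy sketch of the masking that the surviving _get_output performs, using invented token values and a toy context length.

import numpy as np

# Toy bin of packed LAMBADA examples. 50256 is the eos id used elsewhere in
# this file and 1 is bin_pack's dummy/pad token; the other values are made up.
eos_token = 50256
n_ctx = 8
packed_bin = np.array([11, 12, 13, eos_token, 21, 22, eos_token, 1])

indexes = np.arange(n_ctx)
results = packed_bin[(indexes + 1) % n_ctx]                     # next token at each position
eos_next_positions = packed_bin[(indexes + 2) % n_ctx] == eos_token
output = np.where(eos_next_positions, results, eos_token)

# output == [50256, 13, 50256, 50256, 22, 50256, 50256, 50256]:
# only the positions that must predict a passage's final word (13 and 22, each
# immediately followed by eos) keep a real target; everything else is eos_token.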