This repository was archived by the owner on Feb 25, 2022. It is now read-only.

Commit 6b37cf3

sid committed
remove wikitext tasks (borked)
1 parent 5ce3203 commit 6b37cf3

2 files changed (+14, -111 lines)


main.py (+2, -1)
@@ -217,7 +217,8 @@ def run_eval_tasks():
 
     if args.eval:
         run_eval_tasks()
-        run_eval()
+        if params["eval_steps"] > 0:
+            run_eval()
     return
 
 
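The change above makes the generic eval loop conditional: the registered eval tasks still run whenever args.eval is set, but run_eval() is skipped when eval_steps is 0. A minimal sketch of that control flow, using a stand-in params dict and stub functions rather than the project's real main.py:

# Stand-ins for illustration only; the real params, run_eval_tasks() and
# run_eval() live in main.py.
params = {"eval_steps": 0}

def run_eval_tasks():
    print("running registered eval tasks (e.g. lambada)")

def run_eval():
    print("running the generic eval loop for", params["eval_steps"], "steps")

def handle_eval(eval_requested):
    if eval_requested:
        run_eval_tasks()
        if params["eval_steps"] > 0:  # guard added by this commit
            run_eval()
        return

handle_eval(True)  # only the task evals run here, since eval_steps == 0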

tasks.py (+12, -110)
@@ -11,6 +11,7 @@
 lambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl'
 normalization = 'NFKC'
 
+
 # Note: this task is called "lambada" but it really refers to OpenAI's version
 # of the task, which actually differs in some ways from the task described in
 # the original paper. So, strictly speaking, accuracy values from this task
@@ -29,31 +30,34 @@ def lambada_create_tokens_data(params, path):
         json.dump(arrays, f)
         return arrays
 
+
 def lambada_read_or_create_tokens_data(params, path):
     # if you tell me where the file should go, i will helpfully create it for you
     if not os.path.exists(path):
         return lambada_create_tokens_data(params, path)
     with open(path) as f:
         return json.load(f)
 
+
 def bin_pack(params, tokens_data):
     eos_token = params['eos_id']
     n_ctx = params['n_ctx']
     dummy_token = 1
     pad_batch_size = params['eval_batch_size']
     bins = []
     for a in tokens_data:
-        if len(bins) == 0 or len(bins[-1])+len(a)+1 > n_ctx:
+        if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
            bins.append([])
         bins[-1] += a
         bins[-1].append(eos_token)
-    while len(bins)%pad_batch_size != 0:
+    while len(bins) % pad_batch_size != 0:
         bins.append([])
     bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)
     for i, b in enumerate(bins):
         bins_array[i, 0:len(b)] = b
     return bins_array
 
+
 def lambada_init(params):
     ds_configs = params['dataset_configs']
     l = [
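bin_pack (unchanged above apart from spacing) greedily packs the tokenized eval examples into rows of length n_ctx: each example is appended to the current row followed by an eos token, a new row is started whenever the next example would not fit, the row count is padded up to a multiple of eval_batch_size with empty rows, and unused positions are filled with a dummy token. A self-contained sketch of the same logic with toy values (the real code reads params['eos_id'], params['n_ctx'] and params['eval_batch_size']; dummy_token is 1):

import numpy as np

# Same packing logic as bin_pack above, parameterised directly for illustration.
def bin_pack_sketch(tokens_data, n_ctx, eos_token, pad_batch_size, dummy_token=1):
    bins = []
    for a in tokens_data:
        # start a new row when this example (plus its eos) would overflow n_ctx
        if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
            bins.append([])
        bins[-1] += list(a)
        bins[-1].append(eos_token)
    # pad with empty rows so the row count divides the eval batch size
    while len(bins) % pad_batch_size != 0:
        bins.append([])
    bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)
    for i, b in enumerate(bins):
        bins_array[i, 0:len(b)] = b
    return bins_array

packed = bin_pack_sketch([[5, 6], [7, 8, 9], [10]], n_ctx=8, eos_token=0, pad_batch_size=2)
print(packed.shape)  # (2, 8)
print(packed[0])     # [5 6 0 7 8 9 0 1]  -> two examples, eos-separated, dummy-padded
print(packed[1])     # [10  0  1  1  1  1  1  1]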
@@ -67,45 +71,14 @@ def lambada_init(params):
     tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
     bins_array = bin_pack(params, tokens_data)
     params['lambada_tokens_path'] = lt_path
-    params['lambada_n_steps'] = len(bins_array)//params['eval_batch_size']
+    params['lambada_n_steps'] = len(bins_array) // params['eval_batch_size']
+
 
 def lambada_get_task_info(params):
     return {
         'n_steps': params['lambada_n_steps'],
     }
 
-def wikitext_detokenizer(string):
-    # contractions
-    string = string.replace("s '", "s'")
-    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
-    # number separators
-    string = string.replace(" @-@ ", "-")
-    string = string.replace(" @,@ ", ",")
-    string = string.replace(" @.@ ", ".")
-    # punctuation
-    string = string.replace(" : ", ": ")
-    string = string.replace(" ; ", "; ")
-    string = string.replace(" . ", ". ")
-    string = string.replace(" ! ", "! ")
-    string = string.replace(" ? ", "? ")
-    string = string.replace(" , ", ", ")
-    # double brackets
-    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
-    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
-    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
-    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
-    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
-    # miscellaneous
-    string = string.replace("= = = =", "====")
-    string = string.replace("= = =", "===")
-    string = string.replace("= =", "==")
-    string = string.replace(" " + chr(176) + " ", chr(176))
-    string = string.replace(" \n", "\n")
-    string = string.replace("\n ", "\n")
-    string = string.replace(" N ", " 1 ")
-    string = string.replace(" 's", "'s")
-
-    return string
 
 # The LAMBADA evaluation code looks at the logits of each position just before an eos_token
 def lambada_input(params):
@@ -115,80 +88,19 @@ def lambada_input(params):
     tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
     bins_array = bin_pack(params, tokens_data)
     dataset = tf.data.Dataset.from_tensor_slices(bins_array)
-    def _get_output(bin):
-        bin = tf.cast(bin, dtype=tf.int32)
-        indexes = tf.range(n_ctx)
-        results = tf.gather(bin, (indexes+1)%n_ctx)
-        eos_next_positions = tf.math.equal(tf.gather(bin, (indexes+2)%n_ctx), eos_token)
-        output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
-        bin = tf.reshape(bin, [n_ctx])
-        bin = tf.cast(bin, dtype=tf.int32)
-        output = tf.reshape(output, [n_ctx])
-        output = tf.cast(output, dtype=tf.int32)
-        return bin, output
-    dataset = dataset.map(_get_output)
-    dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
-    dataset = dataset.repeat()
-    return dataset
 
-def wikitext_create_tokens_data(params, path, version):
-    assert version.lower() in ["wikitext2", "wikitext103"]
-    wikitext2_src = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip"
-    wikitext103_src = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip"
-    version_src = wikitext103_src if version.lower() == "wikitext103" else wikitext2_src
-    with open(path, 'w') as f:
-        wikitext_path = f"./{version}-raw-v1.zip"
-        os.system(f"wget {version_src} -O {wikitext_path}")
-        os.makedirs(f"{version}", exist_ok=True)
-        os.system(f"unzip {wikitext_path} -d {version}")
-        n = 103 if version.lower() == "wikitext103" else 2
-        with open(f"./{version}/wikitext-{n}-raw/wiki.test.raw", 'r') as wt:
-            text = ftfy.fix_text(wikitext_detokenizer(wt.read()))
-        enc = fetch_encoder(params)
-        encoded_text = encode(enc, text)
-        arrays = []
-        for i in range(0, len(encoded_text), params["n_ctx"] - 1):
-            arrays.append(encoded_text[i:i + params["n_ctx"] - 1])
-        json.dump(arrays, f)
-        return arrays
-
-def wikitext_read_or_create_tokens_data(params, path, version):
-    # if you tell me where the file should go, i will helpfully create it for you
-    if not os.path.exists(path):
-        return wikitext_create_tokens_data(params, path, version)
-    with open(path) as f:
-        return json.load(f)
-
-def wikitext_init(params, version):
-    wikitext_path = version + ".json"
-    tokens_data = wikitext_read_or_create_tokens_data(params, wikitext_path, version)
-    bins_array = bin_pack(params, tokens_data)
-    params['wikitext_path'] = wikitext_path
-    params['wikitext_n_steps'] = len(bins_array)//params['eval_batch_size']
-
-def wikitext_get_task_info(params, version):
-    return {
-        'n_steps': params['wikitext_n_steps'],
-    }
-
-def wikitext_input(params, version):
-    eos_token = 50256 if params['n_vocab'] >= 50257 else 0
-    n_ctx = params['n_ctx']
-    wt_path = params['wikitext_path']
-    tokens_data = wikitext_read_or_create_tokens_data(params, wt_path, version)
-    bins_array = bin_pack(params, tokens_data)
-    dataset = tf.data.Dataset.from_tensor_slices(bins_array)
     def _get_output(bin):
         bin = tf.cast(bin, dtype=tf.int32)
         indexes = tf.range(n_ctx)
-        results = tf.gather(bin, (indexes+1)%n_ctx)
-        eos_next_positions = tf.math.equal(tf.gather(bin, (indexes+2)%n_ctx), eos_token)
+        results = tf.gather(bin, (indexes + 1) % n_ctx)
+        eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2) % n_ctx), eos_token)
         output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
         bin = tf.reshape(bin, [n_ctx])
         bin = tf.cast(bin, dtype=tf.int32)
         output = tf.reshape(output, [n_ctx])
         output = tf.cast(output, dtype=tf.int32)
         return bin, output
+
     dataset = dataset.map(_get_output)
     dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
     dataset = dataset.repeat()
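In the _get_output kept above, the target at position i is the token at position i+1 whenever the token at position i+2 is the eos, i.e. whenever the next token is the final word of a packed example; every other position gets eos_token, which lets the eval ignore positions whose target is the eos. A NumPy restatement of that indexing, with a toy eos of 0 rather than the real eos id, purely for illustration:

import numpy as np

# NumPy version of _get_output's target construction (toy eos_token=0).
def lambada_targets(bin_row, eos_token=0):
    n_ctx = len(bin_row)
    idx = np.arange(n_ctx)
    next_tok = bin_row[(idx + 1) % n_ctx]               # token at the next position
    next_next_is_eos = bin_row[(idx + 2) % n_ctx] == eos_token
    # keep the next token only where the token after it is eos (the final word
    # of an example); emit eos_token everywhere else so those positions are ignored
    return np.where(next_next_is_eos, next_tok, eos_token)

row = np.array([5, 6, 7, 0, 8, 9, 0, 1])  # two packed examples, eos=0, dummy=1
print(lambada_targets(row))               # [0 7 0 0 9 0 0 0]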
@@ -200,15 +112,5 @@ def _get_output(bin):
         'init_fn': lambada_init,
         'get_task_info_fn': lambada_get_task_info,
         'input_fn': lambada_input,
-    },
-    'wikitext2': {
-        'init_fn': partial(wikitext_init, version='wikitext2'),
-        'get_task_info_fn': partial(wikitext_get_task_info, version='wikitext2'),
-        'input_fn': partial(wikitext_input, version='wikitext2'),
-    },
-    'wikitext103': {
-        'init_fn': partial(wikitext_init, version='wikitext103'),
-        'get_task_info_fn': partial(wikitext_get_task_info, version='wikitext103'),
-        'input_fn': partial(wikitext_input, version='wikitext103'),
-    },
+    }
 }

Comments (0)