From ebad72bd4a3c466d2df06903c67752dd4d3ca24d Mon Sep 17 00:00:00 2001
From: Haesun Park
Date: Tue, 17 Sep 2019 11:23:41 +0900
Subject: [PATCH] fix conflicting glove.6B.100d.trimmed.txt

---
 utils/write.py | 45 ++++++++++++++++++++++++---------------------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/utils/write.py b/utils/write.py
index fedcece02f..9cf475f1a1 100644
--- a/utils/write.py
+++ b/utils/write.py
@@ -8,6 +8,8 @@
 
 import numpy as np
 
+import os.path
+
 _MAX_BATCH_SIZE = 128
 _MAX_DOC_LENGTH = 200
 
@@ -36,28 +38,29 @@ def _add_word(word):
 
 embeddings_path = './data/glove/glove.6B.100d.trimmed.txt'
 
-with open(embeddings_path) as f:
-    line = f.readline()
-    chunks = line.split(" ")
-    dimensions = len(chunks) - 1
-    f.seek(0)
-
-    vocab_size = sum(1 for line in f)
-    vocab_size += 4 #3
-    f.seek(0)
-
-    glove = np.ndarray((vocab_size, dimensions), dtype=np.float32)
-    glove[PADDING_TOKEN] = np.random.normal(0, 0.02, dimensions)
-    glove[UNKNOWN_TOKEN] = np.random.normal(0, 0.02, dimensions)
-    glove[START_TOKEN] = np.random.normal(0, 0.02, dimensions)
-    glove[END_TOKEN] = np.random.normal(0, 0.02, dimensions)
-
-    for line in f:
+if os.path.exists(embeddings_path):
+    with open(embeddings_path) as f:
+        line = f.readline()
         chunks = line.split(" ")
-        idx = _add_word(chunks[0])
-        glove[idx] = [float(chunk) for chunk in chunks[1:]]
-        if len(_idx_to_word) >= vocab_size:
-            break
+        dimensions = len(chunks) - 1
+        f.seek(0)
+
+        vocab_size = sum(1 for line in f)
+        vocab_size += 4 #3
+        f.seek(0)
+
+        glove = np.ndarray((vocab_size, dimensions), dtype=np.float32)
+        glove[PADDING_TOKEN] = np.random.normal(0, 0.02, dimensions)
+        glove[UNKNOWN_TOKEN] = np.random.normal(0, 0.02, dimensions)
+        glove[START_TOKEN] = np.random.normal(0, 0.02, dimensions)
+        glove[END_TOKEN] = np.random.normal(0, 0.02, dimensions)
+
+        for line in f:
+            chunks = line.split(" ")
+            idx = _add_word(chunks[0])
+            glove[idx] = [float(chunk) for chunk in chunks[1:]]
+            if len(_idx_to_word) >= vocab_size:
+                break
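
Note (editorial sketch, not part of the patch): the change above wraps the
module-level GloVe loading in an os.path.exists guard so that importing
utils/write.py no longer fails when the trimmed embedding file is absent.
The sketch below restates the guarded loader as a function to make the
two-pass structure easier to see: the first pass infers the vector width and
row count, the second fills the matrix. load_glove, the token index values,
and the placeholder token strings are illustrative assumptions; _add_word
mirrors the helper visible in the second hunk's header.

import os.path

import numpy as np

# Indices of the four special tokens; the real values live in
# utils/write.py and are assumed here.
PADDING_TOKEN, UNKNOWN_TOKEN, START_TOKEN, END_TOKEN = 0, 1, 2, 3

# Vocabulary tables seeded with placeholder strings (assumed names).
_idx_to_word = ['<pad>', '<unk>', '<start>', '<end>']
_word_to_idx = {w: i for i, w in enumerate(_idx_to_word)}

def _add_word(word):
    # Mirrors the helper shown in the hunk header of the patch.
    idx = len(_idx_to_word)
    _word_to_idx[word] = idx
    _idx_to_word.append(word)
    return idx

def load_glove(path):
    # Hypothetical wrapper; returns None when the file is absent,
    # matching the patch's os.path.exists guard.
    if not os.path.exists(path):
        return None
    with open(path) as f:
        # Pass 1: the vector width is the first line's field count
        # minus the leading word; the line count sizes the matrix.
        dimensions = len(f.readline().split(" ")) - 1
        f.seek(0)
        vocab_size = sum(1 for _ in f) + 4  # +4 rows for the special tokens
        f.seek(0)

        glove = np.ndarray((vocab_size, dimensions), dtype=np.float32)
        # Special tokens have no pretrained vectors, so they get small
        # random ones, exactly as in the patched code.
        for token in (PADDING_TOKEN, UNKNOWN_TOKEN, START_TOKEN, END_TOKEN):
            glove[token] = np.random.normal(0, 0.02, dimensions)

        # Pass 2: each line is a word followed by its vector.
        for line in f:
            chunks = line.split(" ")
            idx = _add_word(chunks[0])
            glove[idx] = [float(chunk) for chunk in chunks[1:]]
            if len(_idx_to_word) >= vocab_size:
                break
    return glove

glove = load_glove('./data/glove/glove.6B.100d.trimmed.txt')

Returning None rather than raising keeps the caller in charge of deciding
whether a missing file is an error, which matches the patch's intent of
tolerating an absent glove.6B.100d.trimmed.txt at import time.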