Skip to content

Commit c553e5a

Browse files
Added model_path argument to support sampling from earlier saved models.
1 parent f5901a6 commit c553e5a

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

sample.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,22 @@ def main():
1717
parser.add_argument('--init_dir', type=str, default='',
1818
help='continue from the outputs in the given directory')
1919

20+
# Parameters for picking which model to use.
21+
parser.add_argument('--model_path', type=str, default='',
22+
help='path to the model file like output/best_model/model-40.')
23+
2024
# Parameters for sampling.
2125
parser.add_argument('--temperature', type=float,
2226
default=1.0,
2327
help=('Temperature for sampling from softmax: '
2428
'higher temperature, more random; '
2529
'lower temperature, more greedy.'))
26-
30+
2731
parser.add_argument('--max_prob', dest='max_prob', action='store_true',
2832
help='always pick the most probable next character in sampling')
2933

3034
parser.set_defaults(max_prob=False)
31-
35+
3236
parser.add_argument('--start_text', type=str,
3337
default='The meaning of life is ',
3438
help='the text to start with')
@@ -61,7 +65,12 @@ def main():
6165
with open(os.path.join(args.init_dir, 'result.json'), 'r') as f:
6266
result = json.load(f)
6367
params = result['params']
64-
best_model = result['best_model']
68+
69+
if args.model_path:
70+
best_model = args.model_path
71+
else:
72+
best_model = result['best_model']
73+
6574
best_valid_ppl = result['best_valid_ppl']
6675
if 'encoding' in result:
6776
args.encoding = result['encoding']

0 commit comments

Comments
 (0)