@@ -17,18 +17,22 @@ def main():
17
17
parser .add_argument ('--init_dir' , type = str , default = '' ,
18
18
help = 'continue from the outputs in the given directory' )
19
19
20
+ # Parameters for picking which model to use.
21
+ parser .add_argument ('--model_path' , type = str , default = '' ,
22
+ help = 'path to the model file like output/best_model/model-40.' )
23
+
20
24
# Parameters for sampling.
21
25
parser .add_argument ('--temperature' , type = float ,
22
26
default = 1.0 ,
23
27
help = ('Temperature for sampling from softmax: '
24
28
'higher temperature, more random; '
25
29
'lower temperature, more greedy.' ))
26
-
30
+
27
31
parser .add_argument ('--max_prob' , dest = 'max_prob' , action = 'store_true' ,
28
32
help = 'always pick the most probable next character in sampling' )
29
33
30
34
parser .set_defaults (max_prob = False )
31
-
35
+
32
36
parser .add_argument ('--start_text' , type = str ,
33
37
default = 'The meaning of life is ' ,
34
38
help = 'the text to start with' )
@@ -61,7 +65,12 @@ def main():
61
65
with open (os .path .join (args .init_dir , 'result.json' ), 'r' ) as f :
62
66
result = json .load (f )
63
67
params = result ['params' ]
64
- best_model = result ['best_model' ]
68
+
69
+ if args .model_path :
70
+ best_model = args .model_path
71
+ else :
72
+ best_model = result ['best_model' ]
73
+
65
74
best_valid_ppl = result ['best_valid_ppl' ]
66
75
if 'encoding' in result :
67
76
args .encoding = result ['encoding' ]
0 commit comments