#!/usr/bin/env python3

- from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, FileType, Namespace
+ import sys
+ from argparse import FileType
from dataclasses import dataclass
from itertools import chain
- from typing import Any, Generator, Iterable, Optional
+ from typing import Any, Generator, Iterable, Optional, Sequence

import torch
+ from simple_parsing import choice, field, flag
from tabulate import tabulate, tabulate_formats
from tqdm import tqdm
from transformers import (
from transformers.generation.utils import GenerateOutput, GenerationMixin

from mbrs import timer
+ from mbrs.args import ArgumentParser


def buffer_lines(input_stream: Iterable[str], buffer_size: int = 64):
@@ -30,69 +33,79 @@ def buffer_lines(input_stream: Iterable[str], buffer_size: int = 64):
        yield buf


- def get_argparser() -> ArgumentParser:
-     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-     # fmt: off
-     parser.add_argument("input", nargs="?", default="-",
-                         type=FileType("r", encoding="utf-8"),
-                         help="Input file. If not specified, read from stdin.")
-     parser.add_argument("--output", "-o", default="-", type=FileType("w"),
-                         help="Output file.")
-     parser.add_argument("--lprobs", default=None, type=FileType("w"),
-                         help="Reference log-probabilities file. "
-                              "This option is useful for the model-based estimation.")
-     parser.add_argument("--length_normalized_lprobs", default=None, type=FileType("w"),
-                         help="Length-normalized reference log-probabilities file. "
-                              "This option is useful for the model-based estimation.")
-     parser.add_argument("--model", "-m", type=str, default="facebook/m2m100_418M",
-                         help="Model name or path.")
-     parser.add_argument("--num_candidates", "-n", type=int, default=1,
-                         help="Number of candidates to be returned.")
-     parser.add_argument("--sampling", "-s", type=str, default="",
-                         choices=["eps"],
-                         help="Sampling method.")
-     parser.add_argument("--beam_size", type=int, default=5,
-                         help="Beam size.")
-     parser.add_argument("--epsilon", "--eps", "-e", type=float, default=0.02,
-                         help="Cutoff parameter for epsilon sampling.")
-     parser.add_argument("--lang_pair", "-l", type=str, default="en-de",
-                         help="Language name pair. Some models like M2M100 uses this information.")
-     parser.add_argument("--max_length", type=int, default=1024,
-                         help="Maximum length of an output sentence.")
-     parser.add_argument("--min_length", type=int, default=1,
-                         help="Minimum length of an output sentence.")
-     parser.add_argument("--length_penalty", type=float, default=None,
-                         help="Length penalty.")
-     parser.add_argument("--batch_size", "-b", type=int, default=8,
-                         help="Batch size.")
-     parser.add_argument("--sampling_size", type=int, default=8,
-                         help="Sampling size in a single inference. "
-                              "The model generates this number of samples at a time "
-                              "until the total number of samples reaches `--num_candidates`.")
-     parser.add_argument("--unique", action="store_true",
-                         help="Generate unique sentences for each input.")
-     parser.add_argument("--retry", type=int, default=100,
-                         help="Retry to do sampling N times when generate unique sentences. "
-                              "If no unique sentences are found after this number of attempts, "
-                              "non-unique sentences will be included in outputs. ")
-     parser.add_argument("--fp16", action="store_true",
-                         help="Use float16.")
-     parser.add_argument("--bf16", action="store_true",
-                         help="Use bfloat16.")
-     parser.add_argument("--cpu", action="store_true",
-                         help="Force to use CPU.")
-     parser.add_argument("--seed", type=int, default=0,
-                         help="Random number seed.")
-     parser.add_argument("--quiet", "-q", action="store_true",
-                         help="No report statistics.")
-     parser.add_argument("--report", default="-", type=FileType("w"),
-                         help="Report file.")
-     parser.add_argument("--report_format", type=str, default="rounded_outline",
-                         choices=tabulate_formats,
-                         help="Report runtime statistics.")
-     parser.add_argument("--width", "-w", type=int, default=1,
-                         help="Number of digits for values of float point.")
-     # fmt: on
+ @dataclass
+ class GenerationArguments:
+     """Generation arguments."""
+
+     # Input file. If not specified, read from stdin.
+     input: FileType("r", encoding="utf-8") = field(
+         default="-", nargs="?", positional=True
+     )
+     # Output file.
+     output: FileType("w") = field(default="-", alias=["-o"])
+     # Reference log-probabilities file.
+     # This option is typically used for the model-based estimation.
+     lprobs: FileType("w") = field(default=None)
+     # Length-normalized reference log-probabilities file.
+     # This option is typically used for the model-based estimation.
+     length_normalized_lprobs: FileType("w") = field(default=None)
+     # Model name or path.
+     model: str = field(default="facebook/m2m100_418M", alias=["-m"])
+     # Number of candidates to be returned.
+     num_candidates: int = field(default=1, alias=["-n"])
+     # Sampling method.
+     sampling: str = field(default="", alias=["-s"], metadata={"choices": ["", "eps"]})
+     # Beam size.
+     beam_size: int = field(default=5)
+     # Cutoff parameter for epsilon sampling.
+     epsilon: float = field(default=0.02, alias=["--eps", "-e"])
+     # Language code pair. Some models like M2M100 use this information.
+     lang_pair: str = field(default="en-de")
+     # Maximum length of an output sentence.
+     max_length: int = field(default=1024)
+     # Minimum length of an output sentence.
+     min_length: int = field(default=1)
+     # Length penalty.
+     length_penalty: float | None = field(default=None)
+     # Batch size.
+     batch_size: int = field(default=1)
+     # Sampling size in a single inference.
+     # The model generates this number of samples at a time
+     # until the total number of samples reaches `--num_candidates`.
+     sampling_size: int = field(default=8)
+     # Generate unique sentences for each input.
+     unique: bool = field(default=False)
+     # Retry sampling up to N times when generating unique sentences.
+     # If no unique sentences are found after this number of attempts,
+     # non-unique sentences will be included in outputs.
+     retry: int = field(default=100)
+     # Use float16.
+     fp16: bool = flag(default=False)
+     # Use bfloat16.
+     bf16: bool = flag(default=False)
+     # Force to use CPU.
+     cpu: bool = flag(default=False)
+     # Random number seed.
+     seed: int = field(default=0)
+     # Do not report statistics.
+     quiet: bool = flag(default=False, alias=["-q"])
+     # Report file.
+     report: FileType("w") = field(default="-")
+     # Report runtime statistics with the given format.
+     report_format: str = choice(*tabulate_formats, default="rounded_outline")
+     # Number of digits for floating-point values.
+     width: int = field(default=1, alias=["-w"])
+
+
+ def get_argparser(args: Sequence[str] | None = None) -> ArgumentParser:
+     parser = ArgumentParser(add_help=True, add_config_path_arg=True)
+     parser.add_arguments(GenerationArguments, "generation")
+     return parser
+
+
+ def format_argparser() -> ArgumentParser:
+     parser = get_argparser()
+     parser.parse_known_args_preprocess(sys.argv[1:] + ["--help"])
    return parser


@@ -229,7 +242,7 @@ def memory_efficient_compute_transition_scores(
    return transition_scores


- def main(args: Namespace) -> None:
+ def main(args: GenerationArguments) -> None:
    set_seed(args.seed)

    src_lang, tgt_lang = tuple(args.lang_pair.split("-"))
@@ -246,7 +259,7 @@ def main(args: Namespace) -> None:
            model.bfloat16()
        model.cuda()

-     generation_kwargs = {
+     generation_kwargs: dict[str, str | bool | int | float] = {
        "max_length": args.max_length,
        "min_length": args.min_length,
        "return_dict_in_generate": True,
@@ -338,7 +351,7 @@ def generate(inputs: list[str]) -> Generator[Sample, None, None]:
        batch_indices: list[int] = list(range(num_inputs))
        sampling_size: int = args.sampling_size
        num_retry = 0
-         while not all(finished) :
+         while not all(finished):
            shards = decode(inputs, sampling_size, generation_kwargs)
            num_retry += 1
            new_batch_indices: list[int] = []
@@ -404,7 +417,7 @@ def generate(inputs: list[str]) -> Generator[Sample, None, None]:


def cli_main():
    args = get_argparser().parse_args()
-     main(args)
+     main(args.generation)


if __name__ == "__main__":
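
For readers unfamiliar with simple_parsing, the snippet below is a minimal, self-contained sketch of the dataclass-to-CLI pattern that the new GenerationArguments class and cli_main() rely on. It assumes the stock simple_parsing.ArgumentParser; the mbrs.args.ArgumentParser subclass used in this diff (with add_config_path_arg and parse_known_args_preprocess) may add config-file handling on top. DemoArguments is a hypothetical stand-in, not part of the commit.

# Minimal sketch, assuming plain simple_parsing (not the mbrs.args subclass).
from dataclasses import dataclass

from simple_parsing import ArgumentParser, field, flag


@dataclass
class DemoArguments:
    """Toy stand-in for GenerationArguments (hypothetical)."""

    # Model name or path.
    model: str = field(default="facebook/m2m100_418M", alias=["-m"])
    # Number of candidates to be returned.
    num_candidates: int = field(default=1, alias=["-n"])
    # Do not report statistics.
    quiet: bool = flag(default=False, alias=["-q"])


parser = ArgumentParser()
# Register the dataclass under the dest name "demo".
parser.add_arguments(DemoArguments, "demo")
args = parser.parse_args(["-m", "facebook/m2m100_1.2B", "-n", "4", "--quiet"])
# The parsed dataclass lives under the dest name, mirroring main(args.generation) above.
print(args.demo.model, args.demo.num_candidates, args.demo.quiet)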