
Commit cf6a109

Merge pull request #34 from naist-nlp/generate-dataclass
Use dataclass for generation arguments
2 parents: 0b254dc + b199ff5

File tree: 3 files changed (+83, -74 lines)

docs/cli_help.rst (1 addition, 1 deletion)

@@ -6,7 +6,7 @@ mbrs-generate
 
 .. argparse::
    :module: mbrs.cli.generate
-   :func: get_argparser
+   :func: format_argparser
    :prog: mbrs-generate
 
 mbrs-decode

mbrs/cli/decode.py (0 additions, 4 deletions)

@@ -37,10 +37,6 @@
 from mbrs.selectors import Selector, get_selector
 
 
-simple_parsing.parsing.logger.setLevel(logging.ERROR)
-dataclass_wrapper.logger.setLevel(logging.ERROR)
-
-
 class Format(enum.Enum):
     plain = "plain"
     json = "json"

mbrs/cli/generate.py (82 additions, 69 deletions)
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, FileType, Namespace
+import sys
+from argparse import FileType
 from dataclasses import dataclass
 from itertools import chain
-from typing import Any, Generator, Iterable, Optional
+from typing import Any, Generator, Iterable, Optional, Sequence
 
 import torch
+from simple_parsing import choice, field, flag
 from tabulate import tabulate, tabulate_formats
 from tqdm import tqdm
 from transformers import (
@@ -17,6 +19,7 @@
 from transformers.generation.utils import GenerateOutput, GenerationMixin
 
 from mbrs import timer
+from mbrs.args import ArgumentParser
 
 
 def buffer_lines(input_stream: Iterable[str], buffer_size: int = 64):
@@ -30,69 +33,79 @@ def buffer_lines(input_stream: Iterable[str], buffer_size: int = 64):
     yield buf
 
 
-def get_argparser() -> ArgumentParser:
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    # fmt: off
-    parser.add_argument("input", nargs="?", default="-",
-                        type=FileType("r", encoding="utf-8"),
-                        help="Input file. If not specified, read from stdin.")
-    parser.add_argument("--output", "-o", default="-", type=FileType("w"),
-                        help="Output file.")
-    parser.add_argument("--lprobs", default=None, type=FileType("w"),
-                        help="Reference log-probabilities file. "
-                             "This option is useful for the model-based estimation.")
-    parser.add_argument("--length_normalized_lprobs", default=None, type=FileType("w"),
-                        help="Length-normalized reference log-probabilities file. "
-                             "This option is useful for the model-based estimation.")
-    parser.add_argument("--model", "-m", type=str, default="facebook/m2m100_418M",
-                        help="Model name or path.")
-    parser.add_argument("--num_candidates", "-n", type=int, default=1,
-                        help="Number of candidates to be returned.")
-    parser.add_argument("--sampling", "-s", type=str, default="",
-                        choices=["eps"],
-                        help="Sampling method.")
-    parser.add_argument("--beam_size", type=int, default=5,
-                        help="Beam size.")
-    parser.add_argument("--epsilon", "--eps", "-e", type=float, default=0.02,
-                        help="Cutoff parameter for epsilon sampling.")
-    parser.add_argument("--lang_pair", "-l", type=str, default="en-de",
-                        help="Language name pair. Some models like M2M100 uses this information.")
-    parser.add_argument("--max_length", type=int, default=1024,
-                        help="Maximum length of an output sentence.")
-    parser.add_argument("--min_length", type=int, default=1,
-                        help="Minimum length of an output sentence.")
-    parser.add_argument("--length_penalty", type=float, default=None,
-                        help="Length penalty.")
-    parser.add_argument("--batch_size", "-b", type=int, default=8,
-                        help="Batch size.")
-    parser.add_argument("--sampling_size", type=int, default=8,
-                        help="Sampling size in a single inference. "
-                             "The model generates this number of samples at a time "
-                             "until the total number of samples reaches `--num_candidates`.")
-    parser.add_argument("--unique", action="store_true",
-                        help="Generate unique sentences for each input.")
-    parser.add_argument("--retry", type=int, default=100,
-                        help="Retry to do sampling N times when generate unique sentences. "
-                             "If no unique sentences are found after this number of attempts, "
-                             "non-unique sentences will be included in outputs. ")
-    parser.add_argument("--fp16", action="store_true",
-                        help="Use float16.")
-    parser.add_argument("--bf16", action="store_true",
-                        help="Use bfloat16.")
-    parser.add_argument("--cpu", action="store_true",
-                        help="Force to use CPU.")
-    parser.add_argument("--seed", type=int, default=0,
-                        help="Random number seed.")
-    parser.add_argument("--quiet", "-q", action="store_true",
-                        help="No report statistics.")
-    parser.add_argument("--report", default="-", type=FileType("w"),
-                        help="Report file.")
-    parser.add_argument("--report_format", type=str, default="rounded_outline",
-                        choices=tabulate_formats,
-                        help="Report runtime statistics.")
-    parser.add_argument("--width", "-w", type=int, default=1,
-                        help="Number of digits for values of float point.")
-    # fmt: on
+@dataclass
+class GenerationArguments:
+    """Generation arguments."""
+
+    # Input file. If not specified, read from stdin.
+    input: FileType("r", encoding="utf-8") = field(
+        default="-", nargs="?", positional=True
+    )
+    # Output file.
+    output: FileType("w") = field(default="-", alias=["-o"])
+    # Reference log-probabilities file.
+    # This option is typically used for the model-based estimation.
+    lprobs: FileType("w") = field(default=None)
+    # Length-normalized reference log-probabilities file.
+    # This option is typically used for the model-based estimation.
+    length_normalized_lprobs: FileType("w") = field(default=None)
+    # Model name or path.
+    model: str = field(default="facebook/m2m100_418M", alias=["-m"])
+    # Number of candidates to be returned.
+    num_candidates: int = field(default=1, alias=["-n"])
+    # Sampling method.
+    sampling: str = field(default="", alias=["-s"], metadata={"choices": ["", "eps"]})
+    # Beam size.
+    beam_size: int = field(default=5)
+    # Cutoff parameter for epsilon sampling.
+    epsilon: float = field(default=0.02, alias=["--eps", "-e"])
+    # Language code pair. Some models like M2M100 use this information.
+    lang_pair: str = field(default="en-de")
+    # Maximum length of an output sentence.
+    max_length: int = field(default=1024)
+    # Minimum length of an output sentence.
+    min_length: int = field(default=1)
+    # Length penalty.
+    length_penalty: float | None = field(default=None)
+    # Batch size.
+    batch_size: int = field(default=1)
+    # Sampling size in a single inference.
+    # The model generates this number of samples at a time
+    # until the total number of samples reaches `--num_candidates`.
+    sampling_size: int = field(default=8)
+    # Generate unique sentences for each input.
+    unique: bool = field(default=False)
+    # Retry to do sampling N times when generating unique sentences.
+    # If no unique sentences are found after this number of attempts,
+    # non-unique sentences will be included in outputs.
+    retry: int = field(default=100)
+    # Use float16.
+    fp16: bool = flag(default=False)
+    # Use bfloat16.
+    bf16: bool = flag(default=False)
+    # Force to use CPU.
+    cpu: bool = flag(default=False)
+    # Random number seed.
+    seed: int = field(default=0)
+    # No report statistics.
+    quiet: bool = flag(default=False, alias=["-q"])
+    # Report file.
+    report: FileType("w") = field(default="-")
+    # Report runtime statistics with the given format.
+    report_format: str = choice(*tabulate_formats, default="rounded_outline")
+    # Number of digits for values of float point.
+    width: int = field(default=1, alias=["-w"])
+
+
+def get_argparser(args: Sequence[str] | None = None) -> ArgumentParser:
+    parser = ArgumentParser(add_help=True, add_config_path_arg=True)
+    parser.add_arguments(GenerationArguments, "generation")
+    return parser
+
+
+def format_argparser() -> ArgumentParser:
+    parser = get_argparser()
+    parser.parse_known_args_preprocess(sys.argv[1:] + ["--help"])
     return parser
 
 
@@ -229,7 +242,7 @@ def memory_efficient_compute_transition_scores(
     return transition_scores
 
 
-def main(args: Namespace) -> None:
+def main(args: GenerationArguments) -> None:
     set_seed(args.seed)
 
     src_lang, tgt_lang = tuple(args.lang_pair.split("-"))
@@ -246,7 +259,7 @@ def main(args: Namespace) -> None:
             model.bfloat16()
         model.cuda()
 
-    generation_kwargs = {
+    generation_kwargs: dict[str, str | bool | int | float] = {
         "max_length": args.max_length,
        "min_length": args.min_length,
         "return_dict_in_generate": True,
@@ -338,7 +351,7 @@ def generate(inputs: list[str]) -> Generator[Sample, None, None]:
         batch_indices: list[int] = list(range(num_inputs))
         sampling_size: int = args.sampling_size
         num_retry = 0
-        while not all(finished) :
+        while not all(finished):
             shards = decode(inputs, sampling_size, generation_kwargs)
             num_retry += 1
             new_batch_indices: list[int] = []
@@ -404,7 +417,7 @@ def generate(inputs: list[str]) -> Generator[Sample, None, None]:
 
 def cli_main():
     args = get_argparser().parse_args()
-    main(args)
+    main(args.generation)
 
 
 if __name__ == "__main__":
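
For context, here is a minimal sketch of the pattern this commit adopts: simple_parsing turns dataclass fields into CLI options, with the comment above each field becoming its help text, alias= providing short flags, and flag() declaring booleans. The sketch uses the stock simple_parsing.ArgumentParser and a trimmed-down, hypothetical GenerationArgumentsSketch; the repository's own mbrs.args.ArgumentParser wrapper and its parse_known_args_preprocess / add_config_path_arg extras are project-specific and not reproduced here.

# Minimal sketch (not the project's actual entry point): how simple_parsing
# maps dataclass fields to command-line options.
from dataclasses import dataclass

from simple_parsing import ArgumentParser, field, flag


@dataclass
class GenerationArgumentsSketch:
    # Model name or path (exposed as --model, with short alias -m).
    model: str = field(default="facebook/m2m100_418M", alias=["-m"])
    # Number of candidates to be returned (exposed as --num_candidates / -n).
    num_candidates: int = field(default=1, alias=["-n"])
    # Use bfloat16 (boolean flag; passing --bf16 enables it).
    bf16: bool = flag(default=False)


parser = ArgumentParser()
parser.add_arguments(GenerationArgumentsSketch, dest="generation")
args = parser.parse_args(["-n", "4", "--bf16"])
print(args.generation.model, args.generation.num_candidates, args.generation.bf16)
# facebook/m2m100_418M 4 True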
