Skip to content

Commit ce25191

Browse files
authored
Merge pull request #30 from naist-nlp/plugin
Implement plug-in module loader
2 parents 661eb61 + b4094ab commit ce25191

File tree

8 files changed

+186
-29
lines changed

8 files changed

+186
-29
lines changed

docs/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ mbrs: A library for MBR decoding
3030
- :doc:`design`
3131
- :doc:`custom_metric`
3232
- :doc:`custom_decoder`
33+
- :doc:`plugin`
3334
- :doc:`timer`
3435

3536
.. grid-item-card:: :material-regular:`library_books;2em` References
@@ -66,6 +67,7 @@ mbrs: A library for MBR decoding
6667
design
6768
custom_metric
6869
custom_decoder
70+
plugin
6971
timer
7072

7173
.. toctree::

docs/plugin.rst

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
Plug-in loader
2+
==============
3+
4+
:code:`mbrs-decode` and :code:`mbrs-score` load plug-in modules via the :code:`--plugin_dir` option.
5+
6+
Examples
7+
~~~~~~~~
8+
9+
This tutorial explains how to load a user defined modules.
10+
11+
.. seealso::
12+
13+
:doc:`How to define a new metric <./custom_metric>`
14+
Detailed documentation of the metric customization.
15+
16+
:doc:`How to define a new decoder <./custom_decoder>`
17+
Detailed documentation of the decoder customization.
18+
19+
1. Define a new metric, decoder, or selector with :code:`@register` decorator.
20+
21+
.. code-block:: python
22+
:emphasize-lines: 4
23+
24+
from mbrs.metrics import register, Metric, MetricBLEU
25+
26+
27+
@register("my_bleu")
28+
class MetricMyBLEU(MetricBLEU):
29+
...
30+
31+
2. Prepare :code:`__init__.py` to specify classes to be loaded.
32+
33+
.. code-block:: python
34+
:emphasize-lines: 1
35+
36+
from .new import MetricNew
37+
38+
3. Then, load the modules with :code:`--plugin_dir` option with a path to the directory containing the above :code:`__init__.py`.
39+
40+
.. code-block:: bash
41+
:emphasize-lines: 2,6
42+
43+
mbrs-decode \
44+
--plugin_dir path/to/plugins/ \
45+
hypotheses.txt \
46+
--num_candidates 1024 \
47+
--decoder mbr \
48+
--metric my_bleu
49+
50+
:code:`mbrs-score` also supports plug-in loading.
51+
52+
.. code-block:: bash
53+
:emphasize-lines: 2,5
54+
55+
mbrs-score \
56+
--plugin_dir path/to/plugins/ \
57+
hypotheses.txt \
58+
-r references.txt \
59+
--metric my_bleu

mbrs/args.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
from __future__ import annotations
2+
13
import argparse
4+
import importlib
25
import logging
6+
import os
37
import sys
48
from argparse import Namespace
59
from pathlib import Path
@@ -12,6 +16,47 @@
1216

1317

1418
class ArgumentParser(simple_parsing.ArgumentParser):
19+
IMPORTED_MODULES: set[Path] = set()
20+
21+
def add_plugin_argumnets(self, parser: ArgumentParser) -> None:
22+
"""Add arguments for plugins.
23+
24+
Args:
25+
parser (ArgumentParser): Argument parser.
26+
"""
27+
parser.add_argument(
28+
"--plugin_dir",
29+
type=Path,
30+
default=None,
31+
help="Path to a directory containing user defined plugins.",
32+
)
33+
34+
def import_plugin(self, plugin_dir: Path) -> None:
35+
"""Import plugin modules.
36+
37+
Args:
38+
plugin_dir (pathlib.Path): A directory containing user defined plugins.
39+
"""
40+
plugin_dir = plugin_dir.absolute()
41+
if not os.path.exists(plugin_dir) or not os.path.isdir(
42+
os.path.dirname(plugin_dir)
43+
):
44+
raise FileNotFoundError(plugin_dir)
45+
46+
module_parent, module_name = os.path.split(plugin_dir)
47+
if plugin_dir not in ArgumentParser.IMPORTED_MODULES:
48+
if module_name not in sys.modules:
49+
sys.path.insert(0, module_parent)
50+
importlib.import_module(module_name)
51+
elif plugin_dir in sys.modules[module_name].__path__:
52+
logger.info(f"--plugin_dir={plugin_dir} has already been imported.")
53+
else:
54+
raise ImportError(
55+
f"Failed to import --plugin_dir={plugin_dir} because the module name "
56+
f"({module_name}) is not globally unique."
57+
)
58+
self.IMPORTED_MODULES.add(plugin_dir)
59+
1560
def preprocess_parser(self) -> None:
1661
"""Preprocess ArgumentParser."""
1762
self.parse_known_args_preprocess(sys.argv[1:])
@@ -67,6 +112,14 @@ def parse_known_args_preprocess(
67112
help="Path to a config file containing default values to use.",
68113
)
69114

115+
# Plugin loader
116+
self.add_plugin_argumnets(temp_parser)
117+
args_with_plugin_dir, args = temp_parser.parse_known_args(args)
118+
plugin_dir: Path | None = args_with_plugin_dir.plugin_dir
119+
if plugin_dir is not None:
120+
self.import_plugin(plugin_dir)
121+
self.add_plugin_argumnets(self)
122+
70123
assert isinstance(args, list)
71124
self._preprocessing(args=args, namespace=namespace)
72125

mbrs/args_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,44 @@
1+
import os
12
import pathlib
23

4+
import pytest
35
import yaml
46

57
from mbrs.cli.decode import get_argparser
8+
from mbrs.metrics import get_metric
69

710

811
class TestArgumentParser:
12+
def test_plugin_load(self, tmp_path: pathlib.Path):
13+
config_path = tmp_path / "config.yaml"
14+
hyps_path = tmp_path / "hyps.txt"
15+
cfg_dict = {
16+
"common": {
17+
"metric": "my_bleu",
18+
"hypotheses": str(hyps_path),
19+
"num_candidates": 2,
20+
},
21+
}
22+
23+
with open(config_path, mode="w") as f:
24+
yaml.dump(cfg_dict, f)
25+
26+
with open(hyps_path, mode="w") as f:
27+
f.writelines(["tests", "a test"])
28+
29+
cmd_args = ["--config_path", str(config_path)]
30+
with pytest.raises(NotImplementedError):
31+
parser = get_argparser(cmd_args)
32+
33+
plugin_dir = os.path.join(
34+
os.path.dirname(os.path.abspath(__file__)), "tests", "plugins"
35+
)
36+
cmd_args += ["--plugin_dir", plugin_dir]
37+
parser = get_argparser(cmd_args)
38+
args = parser.parse_args(args=cmd_args)
39+
assert args.common.metric == "my_bleu"
40+
assert get_metric(args.common.metric).__name__ == "MetricMyBLEU"
41+
942
def test_config_load(self, tmp_path: pathlib.Path):
1043
config_path = tmp_path / "config.yaml"
1144
hyps_path = tmp_path / "hyps.txt"

mbrs/cli/decode.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
1515
datefmt="%Y-%m-%d %H:%M:%S",
1616
level=os.environ.get("LOGLEVEL", "INFO").upper(),
17-
stream=sys.stdout,
17+
stream=sys.stderr,
1818
)
1919
logger = logging.getLogger(__name__)
2020

@@ -36,6 +36,7 @@
3636
from mbrs.metrics import Metric, MetricEnum, get_metric
3737
from mbrs.selectors import Selector, get_selector
3838

39+
3940
simple_parsing.parsing.logger.setLevel(logging.ERROR)
4041
dataclass_wrapper.logger.setLevel(logging.ERROR)
4142

@@ -66,11 +67,17 @@ class CommonArguments:
6667
# Number of references for each sentence.
6768
num_references: int | None = field(default=None)
6869
# Type of the decoder.
69-
decoder: str = choice(*registry.get_registry("decoder").keys(), default="mbr")
70+
decoder: str = field(
71+
default="mbr", metadata={"choices": registry.get_registry("decoder")}
72+
)
7073
# Type of the metric.
71-
metric: str = choice(*registry.get_registry("metric").keys(), default="bleu")
74+
metric: str = field(
75+
default="bleu", metadata={"choices": registry.get_registry("metric")}
76+
)
7277
# Type of the selector.
73-
selector: str = choice(*registry.get_registry("selector").keys(), default="nbest")
78+
selector: str = field(
79+
default="nbest", metadata={"choices": registry.get_registry("selector")}
80+
)
7481
# Return the n-best hypotheses.
7582
nbest: int = field(default=1)
7683
# No verbose information and report.
@@ -174,16 +181,11 @@ def main(args: Namespace) -> None:
174181
reference_lprobs = f.readlines()
175182
assert len(references) == len(reference_lprobs)
176183

177-
metric_type = get_metric(args.common.metric)
178-
metric: Metric = metric_type(args.metric)
179-
180-
selector_type = get_selector(args.common.selector)
181-
selector: Selector = selector_type(args.selector)
182-
183-
decoder_type = get_decoder(args.common.decoder)
184-
decoder: DecoderReferenceBased | DecoderReferenceless = decoder_type(
185-
args.decoder, metric, selector
186-
)
184+
metric: Metric = get_metric(args.common.metric)(args.metric)
185+
selector: Selector = get_selector(args.common.selector)(args.selector)
186+
decoder: DecoderReferenceBased | DecoderReferenceless = get_decoder(
187+
args.common.decoder
188+
)(args.decoder, metric, selector)
187189

188190
num_cands = args.common.num_candidates
189191
num_refs = args.common.num_references or num_cands

mbrs/cli/score.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,25 @@
99
from dataclasses import asdict, dataclass
1010
from typing import Sequence
1111

12+
logging.basicConfig(
13+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
14+
datefmt="%Y-%m-%d %H:%M:%S",
15+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
16+
stream=sys.stderr,
17+
)
18+
logger = logging.getLogger(__name__)
19+
1220
import simple_parsing
1321
from simple_parsing import choice, field, flag
1422
from simple_parsing.wrappers import dataclass_wrapper
1523

1624
from mbrs import registry
1725
from mbrs.args import ArgumentParser
18-
from mbrs.metrics import Metric, get_metric
19-
from mbrs.metrics.base import MetricReferenceless
26+
from mbrs.metrics import Metric, MetricReferenceless, get_metric
2027

2128
simple_parsing.parsing.logger.setLevel(logging.ERROR)
2229
dataclass_wrapper.logger.setLevel(logging.ERROR)
2330

24-
logging.basicConfig(
25-
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
26-
datefmt="%Y-%m-%d %H:%M:%S",
27-
level=os.environ.get("LOGLEVEL", "INFO").upper(),
28-
stream=sys.stdout,
29-
)
30-
logger = logging.getLogger(__name__)
31-
3231

3332
class Format(enum.Enum):
3433
plain = "plain"
@@ -48,7 +47,9 @@ class CommonArguments:
4847
# Output format.
4948
format: Format = choice(Format, default=Format.json)
5049
# Type of the metric.
51-
metric: str = choice(*registry.get_registry("metric").keys(), default="bleu")
50+
metric: str = field(
51+
default="bleu", metadata={"choices": registry.get_registry("metric")}
52+
)
5253
# No verbose information and report.
5354
quiet: bool = flag(default=False)
5455
# Number of digits for values of float point.
@@ -61,11 +62,12 @@ def get_argparser(args: Sequence[str] | None = None) -> ArgumentParser:
6162
for _field in meta_parser._wrappers[0].fields:
6263
_field.required = False
6364
known_args, _ = meta_parser.parse_known_args(args=args)
64-
metric_type = get_metric(known_args.common.metric)
6565

6666
parser = ArgumentParser(add_help=True, add_config_path_arg=True)
6767
parser.add_arguments(CommonArguments, "common")
68-
parser.add_arguments(metric_type.Config, "metric", prefix="metric.")
68+
parser.add_arguments(
69+
get_metric(known_args.common.metric).Config, "metric", prefix="metric."
70+
)
6971
return parser
7072

7173

@@ -92,8 +94,7 @@ def main(args: Namespace) -> None:
9294
references = f.readlines()
9395
assert num_sents == len(references)
9496

95-
metric_type = get_metric(args.common.metric)
96-
metric: Metric = metric_type(args.metric)
97+
metric: Metric | MetricReferenceless = get_metric(args.common.metric)(args.metric)
9798

9899
if isinstance(metric, MetricReferenceless):
99100
assert sources is not None

mbrs/tests/plugins/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .my_bleu import MetricMyBLEU

mbrs/tests/plugins/my_bleu.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from mbrs.metrics import MetricBLEU, register
2+
3+
4+
@register("my_bleu")
5+
class MetricMyBLEU(MetricBLEU):
6+
"""My customized metric class."""

0 commit comments

Comments
 (0)