
Commit 661eb61

Merge pull request #29 from naist-nlp/bertscore
Implement BERTScore
2 parents 95aa3de + eda8ade

8 files changed: +350 additions, -110 deletions


.github/workflows/ci.yaml

Lines changed: 9 additions & 108 deletions
@@ -16,6 +16,14 @@ jobs:
       matrix:
         platform: ["ubuntu-latest", "windows-latest"]
         python-version: ["3.10", "3.11"]
+        pytest_marker:
+          - null
+          - "metrics_bertscore"
+          - "metrics_bleurt"
+          - "metrics_xcometlite"
+          - "metrics_metricx24"
+          - "metrics_metricx23"
+          - "metrics_metricx23qe"
     runs-on: ${{ matrix.platform }}
     steps:
       - uses: actions/checkout@v4
@@ -30,111 +38,4 @@ jobs:
       - name: Test with pytest
         run: |
           uv run huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
-          uv run pytest
-
-  metrics_bleurt:
-    strategy:
-      matrix:
-        platform: ["ubuntu-latest", "windows-latest"]
-        python-version: ["3.10", "3.11"]
-    runs-on: ${{ matrix.platform }}
-    steps:
-      - uses: actions/checkout@v4
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Set up Python ${{ matrix.python-version }}
-        run: uv python install
-      - name: Install the project
-        run: uv sync --all-extras --dev
-      - name: Test with pytest
-        run: |
-          uv run huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
-          uv run pytest -m "metrics_bleurt"
-
-  metrics_xcometlite:
-    strategy:
-      matrix:
-        platform: ["ubuntu-latest", "windows-latest"]
-        python-version: ["3.10", "3.11"]
-
-    runs-on: ${{ matrix.platform }}
-
-    steps:
-      - uses: actions/checkout@v4
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Set up Python ${{ matrix.python-version }}
-        run: uv python install
-      - name: Install the project
-        run: uv sync --all-extras --dev
-      - name: Test with pytest
-        run: |
-          uv run huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
-          uv run pytest -m "metrics_xcometlite"
-
-  metrics_metricx24:
-    strategy:
-      matrix:
-        platform: ["ubuntu-latest", "windows-latest"]
-        python-version: ["3.10", "3.11"]
-    runs-on: ${{ matrix.platform }}
-    steps:
-      - uses: actions/checkout@v4
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Set up Python ${{ matrix.python-version }}
-        run: uv python install
-      - name: Install the project
-        run: uv sync --all-extras --dev
-      - name: Test with pytest
-        run: |
-          uv run huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
-          uv run pytest -m "metrics_metricx24"
-
-  metrics_metricx23:
-    strategy:
-      matrix:
-        platform: ["ubuntu-latest", "windows-latest"]
-        python-version: ["3.10", "3.11"]
-    runs-on: ${{ matrix.platform }}
-    steps:
-      - uses: actions/checkout@v4
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Set up Python ${{ matrix.python-version }}
-        run: uv python install
-      - name: Install the project
-        run: uv sync --all-extras --dev
-      - name: Test with pytest
-        run: |
-          uv run huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
-          uv run pytest -m "metrics_metricx23"
-
-  metrics_metricx23qe:
-    strategy:
-      matrix:
-        platform: ["ubuntu-latest", "windows-latest"]
-        python-version: ["3.10", "3.11"]
-    runs-on: ${{ matrix.platform }}
-    steps:
-      - uses: actions/checkout@v4
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Set up Python ${{ matrix.python-version }}
-        run: uv python install
-      - name: Install the project
-        run: uv sync --all-extras --dev
-      - name: Test with pytest
-        run: |
-          uv run huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
-          uv run pytest -m "metrics_metricx23qe"
+          uv run pytest ${{ matrix.pytest_marker && format('-m {0}', matrix.pytest_marker) || '' }}
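This collapses the six near-identical per-metric jobs into one matrix dimension. When `pytest_marker` is `null`, the expression `${{ matrix.pytest_marker && format('-m {0}', matrix.pytest_marker) || '' }}` evaluates to an empty string and the job runs plain `uv run pytest`, like the old first job; otherwise it expands to `-m <marker>`. A minimal sketch of how a test would opt into one of these shards (the test name and body are hypothetical; only the marker names come from this workflow):

```python
import pytest

# Hypothetical test, not part of this commit: tagging it with the marker
# lets the matrix entry `pytest_marker: "metrics_bertscore"` select it via
# `uv run pytest -m metrics_bertscore`; the null-marker entry runs plain
# `uv run pytest` with no filter.
@pytest.mark.metrics_bertscore
def test_bertscore_smoke():
    ...
```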

.readthedocs.yaml

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,8 @@ build:
     - asdf plugin add uv
     - asdf install uv latest
     - asdf global uv latest
-    - uv sync --extra docs --frozen
+    - uv python install 3.10
+    - uv sync --all-extras --all-groups
     - uv run sphinx-apidoc --remove-old -d1 -Tfe -o docs/source ./ "$READTHEDOCS_REPOSITORY_PATH/**/*_test.py" "$READTHEDOCS_REPOSITORY_PATH/**/conftest.py"
     - uv run -m sphinx -T -b html -d docs/_build/doctrees -D language=en docs $READTHEDOCS_OUTPUT/html

README.md

Lines changed: 10 additions & 1 deletion
@@ -15,7 +15,8 @@
     <b>
       <a href="https://aclanthology.org/2024.emnlp-demo.37">Paper</a> |
       <a href="https://mbrs.readthedocs.io">Reference docs</a> |
-      <a href="https://github.com/naist-nlp/mbrs#citation">Citation</a>
+      <a href="https://github.com/naist-nlp/mbrs#citation">Citation</a> |
+      <a href="https://github.com/naist-nlp/mbrs/releases">Release notes</a>
     </b>
   </p>

@@ -35,6 +36,13 @@ cd mbrs/
 pip install ./
 ```

+For uv users:
+``` bash
+git clone https://github.com/naist-nlp/mbrs.git
+cd mbrs/
+uv sync
+```
+
 ## Quick start

 mbrs provides two interfaces: command-line interface (CLI) and Python
@@ -155,6 +163,7 @@ Currently, the following metrics are supported:
   to [\@lucadiliello](https://github.com/lucadiliello/bleurt-pytorch))
 - MetricX ([Juraska et al., 2023](https://aclanthology.org/2023.wmt-1.63);
   [Juraska et al., 2024](https://aclanthology.org/2024.wmt-1.35)): `metricx`
+- BERTScore [(Zhang et al., 2020)](https://openreview.net/forum?id=SkeHuCVFDr): `bertscore`

 ### Decoders
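Alongside the metric listing, a sketch of how the newly registered `bertscore` metric could be used from the Python interface may help; the `DecoderMBR` import and `decode()` call follow the quick-start pattern the README refers to (an assumption, since that part of the README is not shown in this diff), and the candidate sentences are made up:

```python
from mbrs.decoders.mbr import DecoderMBR
from mbrs.metrics.bertscore import MetricBERTScore

# Made-up candidate pool; in practice these come from sampling a model.
hypotheses = ["this is a test", "this is tests", "this was a test"]

# lang="en" is required here because model_type is left unset.
metric = MetricBERTScore(MetricBERTScore.Config(lang="en"))
decoder = DecoderMBR(DecoderMBR.Config(), metric)

# MBR decoding: candidates are scored against each other as pseudo-references.
output = decoder.decode(hypotheses, hypotheses, nbest=1)
print(output.sentence[0])
```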

docs/list_metrics.rst

Lines changed: 4 additions & 0 deletions
@@ -51,3 +51,7 @@ Supported metrics are listed below.
      - :code:`metricx`
      - :doc:`MetricMetricX <./source/mbrs.metrics.metricx>`
      - `(Juraska et al., 2023) <https://aclanthology.org/2023.wmt-1.63>`_ `(Juraska et al., 2024) <https://aclanthology.org/2024.wmt-1.35>`_
+   * - BERTScore
+     - :code:`bertscore`
+     - :doc:`MetricBERTScore <./source/mbrs.metrics.bertscore>`
+     - `(Zhang et al., 2020) <https://openreview.net/forum?id=SkeHuCVFDr>`_

mbrs/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@

 register, get_metric = registry.setup("metric")

+from .bertscore import MetricBERTScore
 from .bleu import MetricBLEU
 from .bleurt import MetricBLEURT
 from .chrf import MetricChrF
@@ -29,6 +30,7 @@
     "MetricAggregatable",
     "MetricCacheable",
     "MetricReferenceless",
+    "MetricBERTScore",
     "MetricBLEU",
     "MetricChrF",
     "MetricCOMET",

mbrs/metrics/bertscore.py

Lines changed: 192 additions & 0 deletions
from __future__ import annotations

import enum
import itertools
from dataclasses import dataclass
from typing import Optional

import torch
import transformers
from bert_score import BERTScorer
from simple_parsing.helpers.fields import choice
from torch import Tensor

from mbrs import timer

from . import Metric, register

transformers.logging.set_verbosity_error()


class BERTScoreScoreType(int, enum.Enum):
    precision = 0
    recall = 1
    f1 = 2


@register("bertscore")
class MetricBERTScore(Metric):
    """BERTScore metric class."""

    scorer: BERTScorer

    @dataclass
    class Config(Metric.Config):
        """BERTScore metric configuration.

        - score_type (BERTScoreScoreType): The output score type, i.e.,
          precision, recall, or f1.
        - model_type (str): Contextual embedding model specification; defaults to
          the suggested model for the target language. At least one of
          `model_type` or `lang` must be specified.
        - num_layers (int): The layer of representations to use. Defaults to the
          layer tuned on WMT16 correlation data.
        - idf (bool): Whether to use idf weighting. (This should be True even if
          `idf_sents` is given.)
        - idf_sents (list[str]): List of sentences used to compute the idf weights.
        - batch_size (int): BERTScore processing batch size.
        - nthreads (int): Number of threads.
        - all_layers (bool): Compute scores for all layers instead of the tuned layer.
        - lang (str): Language of the sentences; at least one of `model_type` or
          `lang` must be specified. `lang` needs to be specified when
          `rescale_with_baseline` is True.
        - rescale_with_baseline (bool): Rescale BERTScore with the pre-computed baseline.
        - baseline_path (str): Customized baseline file.
        - use_fast_tokenizer (bool): `use_fast` parameter passed to the HF tokenizer.
        - fp16 (bool): Use float16 for the forward computation.
        - bf16 (bool): Use bfloat16 for the forward computation.
        - cpu (bool): Use CPU for the forward computation.
        """

        score_type: BERTScoreScoreType = choice(
            BERTScoreScoreType, default=BERTScoreScoreType.f1
        )
        model_type: Optional[str] = None
        num_layers: Optional[int] = None
        batch_size: int = 64
        nthreads: int = 4
        all_layers: bool = False
        idf: bool = False
        idf_sents: Optional[list[str]] = None
        lang: Optional[str] = None
        rescale_with_baseline: bool = False
        baseline_path: Optional[str] = None
        use_fast_tokenizer: bool = False
        fp16: bool = False
        bf16: bool = False
        cpu: bool = False

    def __init__(self, cfg: MetricBERTScore.Config):
        self.cfg = cfg
        self.scorer = BERTScorer(
            model_type=cfg.model_type,
            num_layers=cfg.num_layers,
            batch_size=cfg.batch_size,
            nthreads=cfg.nthreads,
            all_layers=cfg.all_layers,
            idf=cfg.idf,
            idf_sents=cfg.idf_sents,
            device="cpu" if cfg.cpu else None,
            lang=cfg.lang,
            rescale_with_baseline=cfg.rescale_with_baseline,
            baseline_path=cfg.baseline_path,
            use_fast_tokenizer=cfg.use_fast_tokenizer,
        )
        # The metric is inference-only: freeze the underlying model.
        self.scorer._model.eval()
        for param in self.scorer._model.parameters():
            param.requires_grad = False

        # Optionally cast to a half-precision dtype and move to GPU when available.
        if not cfg.cpu and torch.cuda.is_available():
            if cfg.fp16:
                self.scorer._model = self.scorer._model.half()
            elif cfg.bf16:
                self.scorer._model = self.scorer._model.bfloat16()
            self.scorer._model = self.scorer._model.cuda()

    @property
    def device(self) -> torch.device:
        """Returns the device of the model."""
        return self.scorer._model.device

    def _choose_output_score(self, triplet: tuple[Tensor, Tensor, Tensor]) -> Tensor:
        """Choose the output score from the triplet of precision, recall, and f1 scores.

        Args:
            triplet (tuple[Tensor, Tensor, Tensor]): A triplet of precision, recall,
              and f1 scores.

        Returns:
            Tensor: Output score.
        """
        return triplet[self.cfg.score_type]

    def score(self, hypothesis: str, reference: str, *_, **__) -> float:
        """Calculate the score of the given hypothesis.

        Args:
            hypothesis (str): A hypothesis.
            reference (str): A reference.

        Returns:
            float: The score of the given hypothesis.
        """
        return self._choose_output_score(
            self.scorer.score(
                [hypothesis],
                [reference],
                batch_size=self.cfg.batch_size,
            )
        ).item()

    def scores(self, hypotheses: list[str], references: list[str], *_, **__) -> Tensor:
        """Calculate the scores of the given hypotheses.

        Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.

        Returns:
            Tensor: The N scores of the given hypotheses.
        """

        with timer.measure("score") as t:
            t.set_delta_ncalls(len(hypotheses))
            return self._choose_output_score(
                self.scorer.score(
                    hypotheses,
                    references,
                    batch_size=self.cfg.batch_size,
                )
            ).view(len(hypotheses))

    def pairwise_scores(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> Tensor:
        """Calculate the pairwise scores.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            Tensor: Score matrix of shape `(H, R)`, where `H` is the number
              of hypotheses and `R` is the number of references.
        """
        # Score every hypothesis/reference combination in one flat batch,
        # then reshape the flat scores back into an (H, R) matrix.
        hyps, refs = tuple(zip(*itertools.product(hypotheses, references)))
        with timer.measure("score") as t:
            t.set_delta_ncalls(len(hypotheses) * len(references))
            return self._choose_output_score(
                self.scorer.score(hyps, refs, batch_size=self.cfg.batch_size)
            ).view(len(hypotheses), len(references))

    def corpus_score(
        self, hypotheses: list[str], references: list[str], *_, **__
    ) -> float:
        """Calculate the corpus-level score.

        Args:
            hypotheses (list[str]): Hypotheses.
            references (list[str]): References.

        Returns:
            float: The corpus score.
        """
        return self.scores(hypotheses, references).mean().item()
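A short usage sketch exercising the new class's scoring entry points; per the Config docstring, `lang` must be given here because `model_type` is left unset, and all sentences below are made up:

```python
from mbrs.metrics.bertscore import MetricBERTScore

metric = MetricBERTScore(MetricBERTScore.Config(lang="en", batch_size=8))

# Sentence-level score (F1 by default, per score_type).
print(metric.score("the cat sat on the mat", "a cat sat on a mat"))

hyps = ["hello world", "goodbye world"]
refs = ["hello there, world", "farewell, world"]

# Elementwise scores: a tensor of shape (2,).
print(metric.scores(hyps, refs))

# Every hypothesis against every reference: a (2, 2) matrix,
# the quantity MBR decoding consumes.
print(metric.pairwise_scores(hyps, refs))
```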
