Skip to content

Commit f985e9d

Browse files
jananisriramfacebook-github-bot
authored andcommitted
Add unit tests on CPU for TritonBench features (#2323)
Summary: Pull Request resolved: #2323 Add unit tests that run on the CPU to verify the behavior of the following: - `x_only = True` for metric registration in [`register_metric()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=337) - custom `label` argument for benchmark registration in [`register_benchmark()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=316) Reviewed By: xuzhao9 Differential Revision: D58558868
1 parent caa76d8 commit f985e9d

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .operator import Operator
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Generator, List, Optional
2+
3+
import torch
4+
5+
from torchbenchmark.util.triton_op import (
6+
BenchmarkOperator,
7+
BenchmarkOperatorMetrics,
8+
register_benchmark,
9+
register_metric,
10+
)
11+
12+
13+
class Operator(BenchmarkOperator):
14+
15+
DEFAULT_METRICS = ["test_metric"]
16+
17+
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
18+
super().__init__(mode=mode, device=device, extra_args=extra_args)
19+
20+
@register_benchmark(label="new_op_label")
21+
def test_op(self, x: torch.Tensor):
22+
return lambda: x
23+
24+
def get_x_val(self, example_inputs):
25+
return example_inputs[0].shape
26+
27+
def get_x_vals(self) -> List[int]:
28+
return [2**n for n in [1, 2, 3]]
29+
30+
def get_input_iter(self) -> Generator:
31+
for x in self.get_x_vals():
32+
yield (torch.Tensor(torch.randn(x, device=self.device, dtype=self.dtype)),)
33+
34+
@register_metric(x_only=True)
35+
def test_metric(
36+
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
37+
):
38+
return [ex.shape[0] + 2 for ex in example_inputs]
39+
40+
@register_metric()
41+
def test_metric_per_benchmark(
42+
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
43+
):
44+
return [ex.shape[0] + 3 for ex in example_inputs]

torchbenchmark/util/triton_op.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None)
468468
self._only = _split_params_by_comma(self.tb_args.only)
469469
self._input_id = self.tb_args.input_id
470470
self._num_inputs = self.tb_args.num_inputs
471+
self.device = device
471472

472473
# Run the post initialization
473474
def __post__init__(self):

0 commit comments

Comments
 (0)