Commit 2d8999b

jananisriram authored and facebook-github-bot committed
Add sum reduction operator to TritonBench (#2282)
Summary:
Pull Request resolved: #2282

Add a Triton reduction kernel for the `sum` operator with `dim=None` to TritonBench, following the [TritonBench guide](https://fb.workplace.com/notes/953949486404240). The implementation reduces an input tensor of any shape to a single scalar value.

To measure the accuracy of the Triton reduction kernel, add an accuracy metric to the sum operator in TritonBench that tests the Triton implementation against the baseline PyTorch implementation, referencing [`torchbenchmark/operators/gemm/operator.py`](https://www.internalfb.com/code/fbsource/[767bb6faa353685b84f08a39f36fdcf6ca170c85]/fbcode/pytorch/benchmark/torchbenchmark/operators/gemm/operator.py?lines=236). The output buffer is reset on each run of the Triton kernel so that the Triton output stays accurate.

To measure the performance of the Triton reduction kernel against PyTorch, add a gbps metric, referencing [`torchbenchmark/operators/vector_add/operator.py`](https://www.internalfb.com/code/fbsource/[858eda681c7618f9427ba55cef8d4aba712cb26e]/fbcode/pytorch/benchmark/torchbenchmark/operators/vector_add/operator.py?lines=19).

The existing [vector_add](https://www.internalfb.com/code/fbsource/fbcode/pytorch/benchmark/torchbenchmark/operators/vector_add/) and [grouped_gemm](https://www.internalfb.com/code/fbsource/fbcode/pytorch/benchmark/torchbenchmark/operators/grouped_gemm/) TritonBench operators served as templates for the implementation. See the [TritonBench Operator Coverage Tracker](https://docs.google.com/spreadsheets/d/1091POOPSPsUnlNVEKaz2X_DQXdIwFv-fGOH_g9by-Zo/edit#gid=0) for current operator coverage in TritonBench.

Reviewed By: xuzhao9, davidberard98

Differential Revision: D58048782

fbshipit-source-id: 73fdf075527733a4c56b306909d4cf4bda121971
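As a rough illustration of what the new operator exercises (not part of the diff: the import path, input sizes, and the availability of a CUDA device with Triton are assumed), the kernel added below can be launched directly and checked against the PyTorch baseline:

import torch
import triton

from torchbenchmark.operators.sum.kernels import triton_sum_kernel_scalar  # module path assumed

x = torch.randn(64, 64, device="cuda")  # any shape; reduced to a single scalar
x_1d = x.view(-1)
M = x_1d.numel()
out = torch.zeros(1, device=x.device, dtype=x.dtype)

# Launch a single program covering all M elements: the kernel stores its scalar
# result without atomics, so more than one program would race on the output.
triton_sum_kernel_scalar[(1,)](x_1d, out, M=M, BLOCK_SIZE_M=triton.next_power_of_2(M))

# Mirrors the operator's accuracy check against the PyTorch baseline.
assert torch.allclose(out, torch.sum(x), atol=1e-4)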
1 parent cfae89c commit 2d8999b

File tree

3 files changed, +125 -0 lines changed

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .operator import Operator
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import torch
import triton
import triton.language as tl


@triton.jit
def triton_sum_kernel_scalar(
    input_ptr,
    output_ptr,
    M,  # number of elements in the input
    BLOCK_SIZE_M: tl.constexpr,  # number of elements per block
):
    pid = tl.program_id(axis=0)  # i-th block of the input

    block_start = pid * BLOCK_SIZE_M
    # offsets is a 1D vector covering this program's block of the input
    offsets = block_start + tl.arange(0, BLOCK_SIZE_M)

    # mask out offsets that are out of bounds for the input
    mask = offsets < M

    # load this program's block; masked-out lanes are filled with 0 so they do not affect the sum
    x = tl.load(input_ptr + offsets, mask=mask, other=0)

    # reduce the block to a single scalar
    output = tl.sum(x)

    # offsets for the scalar output pointer (output shape == (1,))
    output_offsets = tl.arange(0, 1)

    # store the scalar result
    tl.store(output_ptr + output_offsets, output)
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import argparse
from typing import Callable, Generator, List, Optional, Tuple

import torch
import triton
import triton.language as tl
from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
)

from .kernels import triton_sum_kernel_scalar


class Operator(BenchmarkOperator):

    DEFAULT_METRICS = ["latency", "accuracy"]

    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
        super().__init__(mode=mode, device=device, extra_args=extra_args)
        self.sizes = range(1, 17)

    @register_benchmark()
    def triton_sum(self, x: torch.Tensor):
        x_1d = x.view(-1)
        M = x_1d.shape[0]
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
        # Cover all M elements with a single block: the kernel stores its scalar
        # result without atomics, so multiple blocks would race on the output.
        BLOCK_SIZE_M = triton.next_power_of_2(M)

        def _inner():
            # reset the output buffer on every run so stale results are not reused
            output = torch.zeros(1, device=x.device, dtype=x.dtype)

            triton_sum_kernel_scalar[grid](
                x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M,
            )

            return output

        return _inner

    @register_benchmark(baseline=True)
    def torch_sum(self, x: torch.Tensor):
        # compute inside the returned callable so the baseline is actually timed
        return lambda: torch.sum(x)

    def get_x_val(self, example_inputs):
        return len(example_inputs[0])

    def get_x_vals(self) -> List[int]:
        x_vals = []

        x_vals.extend([2**n for n in self.sizes])
        x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0])

        return x_vals

    def get_input_iter(self) -> Generator:
        # every input is reduced to a single scalar value
        for size in self.get_x_vals():  # 1D input
            input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
            yield (input_1d,)

        for size in self.get_x_vals():  # 2D input
            if size < pow(2, 8):  # ensure we don't exceed floating-point limits
                input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype)
                yield (input_2d,)

        for size in self.get_x_vals():  # 3D input
            if size < pow(2, 4):  # ensure we don't exceed floating-point limits
                input_3d = torch.randn((size, size, size), device=self.device, dtype=self.dtype)
                yield (input_3d,)

    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
        output = fn()
        baseline_output = baseline_fn()
        return torch.allclose(output, baseline_output, atol=1e-4)

    @register_metric(skip_baseline=True)
    def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
        return [ex.dim() for ex in example_inputs]

    @register_metric()
    def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
        # bytes moved per millisecond, with bytes -> GB and ms -> s folded into one 1e-6 factor;
        # the factor of 3 follows the vector_add formula referenced in the summary
        gbps = (
            lambda ms: 3
            * example_inputs[0].element_size()
            * example_inputs[0].numel()
            / ms
            * 1e-6
        )
        return list(map(gbps, metrics.latency if metrics.latency else [0]))
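For reference, the gbps metric reports bytes moved per unit time in GB/s: bytes scale by 1e-9 and milliseconds by 1e-3, so a single 1e-6 factor folds both conversions, and the leading factor of 3 follows the vector_add formula referenced in the summary. A standalone sketch (the helper name is illustrative, not part of the diff):

def gbps_from_ms(ms: float, element_size: int, numel: int, traffic_factor: int = 3) -> float:
    # (traffic_factor * element_size * numel) bytes over (ms * 1e-3) seconds, in GB/s
    return traffic_factor * element_size * numel / ms * 1e-6

# e.g. a 2**20-element fp32 input measured at 0.05 ms:
# gbps_from_ms(0.05, 4, 2**20) -> about 252 GB/s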

0 commit comments
