Skip to content

Commit 40b376d

Browse files
jananisriram authored and facebook-github-bot committed
Add jagged_sum operator for padded nested tensors to TritonBench (#2305)
Summary: Pull Request resolved: #2305

Add a `jagged_sum` reduction operator for padded nested tensors, based on the PyTorch `sum` operator, to TritonBench. This diff uses the PyTorch function [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), hosted at this [GitHub pull request](https://github.com/pytorch/pytorch/pull/125968), to pad each 2-dimensional tensor in a nested tensor of shape `(B, *, M)`, then reduce across the ragged dimension (`dim == 1`) to produce a `(B, M)` output tensor. The accuracy of the padded implementation is measured against the unpadded baseline implementation via the `accuracy` TritonBench metric.

Reviewed By: davidberard98

Differential Revision: D58423489

fbshipit-source-id: d2f6095f8af1cb188bb979e2f5605ad80db50a46
1 parent b949580 commit 40b376d

File tree

1 file changed

+60
-41
lines changed

1 file changed

+60
-41
lines changed

torchbenchmark/operators/jagged_sum/operator.py

Lines changed: 60 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import itertools
23
import math
34
import random
45
from typing import Callable, Generator, List, Optional, Tuple
@@ -14,19 +15,21 @@
1415
register_metric,
1516
)
1617

17-
random.seed(16)
18-
torch.manual_seed(16)
18+
seed = 16
19+
random.seed(seed)
20+
torch.manual_seed(seed)
1921

2022
GIGABYTES_PER_BYTE = 1e-6
2123
RANDOM_CHOICE_MARGIN = 0.3
24+
ABSOLUTE_TOLERANCE = 1e-3
2225

2326

2427
def parse_op_args(args: List[str]):
2528
parser = argparse.ArgumentParser()
2629
parser.add_argument(
2730
"--seqlen",
2831
type=int,
29-
default=100,
32+
default=500,
3033
help="Maximum sequence length on ragged dimension (integer)",
3134
)
3235
parser.add_argument(
@@ -40,6 +43,9 @@ def parse_op_args(args: List[str]):
4043

4144
class Operator(BenchmarkOperator):
4245

46+
DEFAULT_METRICS = ["latency", "accuracy"]
47+
use_cuda_graphs = False # enables GPU/CPU sync (for methods like NestedTensor unbind)
48+
4349
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
4450
super().__init__(mode=mode, device=device, extra_args=extra_args)
4551
self.sizes = range(4, 10, 2)
@@ -58,6 +64,17 @@ def torch_jagged_sum_no_pad(self, x: torch.Tensor):
5864
dtype=self.dtype,
5965
)
6066

67+
@register_benchmark()
68+
def torch_jagged_sum_pad(self, x: torch.Tensor):
69+
return lambda: torch.sum(
70+
torch.ops.aten._jagged_to_padded_dense_forward(
71+
x.values(),
72+
[x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
73+
max_lengths=[self.seqlen], # max length of ragged dimension
74+
),
75+
dim=1,
76+
) # sum along ragged dimension (dim == 1)
77+
6178
def get_x_val(self, example_inputs):
6279
return len(example_inputs[0])
6380

@@ -90,50 +107,52 @@ def get_input_iter(self) -> Generator:
90107
"""
91108

92109
B_vals, M_vals = self.get_x_vals()
93-
94-
for B in B_vals:
95-
for M in M_vals:
96-
tensors = []
97-
98-
# greater sparsity --> shorter sequence lengths on ragged dimension
99-
seqlen_avg = math.floor(
100-
self.seqlen * (1 - self.sparsity)
101-
) # average sequence length across all tensors in nested tensor
102-
seqlen_margin = math.floor(
103-
self.seqlen * RANDOM_CHOICE_MARGIN
104-
) # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
105-
106-
for _ in range(B):
107-
seqlen_randint = random.randint(
108-
max(seqlen_avg - seqlen_margin, 1),
109-
min(seqlen_avg + seqlen_margin, self.seqlen),
110-
)
111-
tensor_2d = torch.randn(
112-
(seqlen_randint, M), device=self.device, dtype=self.dtype
113-
)
114-
tensors.append(tensor_2d)
115-
116-
nt = torch.nested.nested_tensor(
117-
tensors,
118-
layout=torch.jagged,
119-
device=self.device,
120-
dtype=self.dtype,
110+
B_M_vals = itertools.product(B_vals, M_vals)
111+
112+
for B, M in B_M_vals:
113+
tensors = []
114+
115+
# greater sparsity --> shorter sequence lengths on ragged dimension
116+
seqlen_avg = math.floor(
117+
self.seqlen * (1 - self.sparsity)
118+
) # average sequence length across all tensors in nested tensor
119+
seqlen_margin = math.floor(
120+
self.seqlen * RANDOM_CHOICE_MARGIN
121+
) # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
122+
123+
for _ in range(B):
124+
seqlen_randint = random.randint(
125+
max(
126+
seqlen_avg - seqlen_margin, 1
127+
), # seqlen_randint must be at least 1
128+
min(
129+
seqlen_avg + seqlen_margin, self.seqlen
130+
), # seqlen_randint must not exceed self.seqlen
121131
)
132+
tensor_2d = torch.randn(
133+
(seqlen_randint, M), device=self.device, dtype=self.dtype
134+
)
135+
tensors.append(tensor_2d)
122136

123-
yield (nt,)
137+
nt = torch.nested.nested_tensor(
138+
tensors,
139+
layout=torch.jagged,
140+
device=self.device,
141+
dtype=self.dtype,
142+
)
124143

125-
@register_metric()
126-
def B_M(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
127-
return tuple([(ex.size(0), ex.size(2)) for ex in example_inputs])[
128-
0
129-
] # return (B, M) for each example input
144+
yield (nt,)
145+
146+
def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
147+
output = fn()
148+
baseline_output = baseline_fn()
149+
return torch.allclose(output, baseline_output, atol=ABSOLUTE_TOLERANCE)
130150

131151
@register_metric()
132152
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
133-
gbps = (
134-
lambda ms: example_inputs[0].element_size()
153+
return (
154+
example_inputs[0].element_size()
135155
* example_inputs[0].numel()
136-
/ ms
156+
/ metrics.latency
137157
* GIGABYTES_PER_BYTE
138158
)
139-
return list(map(gbps, metrics.latency if metrics.latency else [0]))

0 commit comments

Comments
 (0)