
Commit 728540d

jananisriram authored and facebook-github-bot committed
Add variable seqlen and sparsity parameters to jagged_sum benchmark
Summary: Modify the existing `jagged_sum` operator benchmark to optionally accept any of the following parameters: `B` (dimension 0 of the nested tensor), `M` (dimension 2 of the nested tensor), `seqlen` (maximum sequence length on the ragged dimension), or `sparsity` (average sparsity on the ragged dimension). This diff holds the provided command-line parameters fixed and sweeps all remaining parameters above, enabling testing of all combinations of multiple parameters in parallel.

The following errors persist with sufficiently large inputs:

- `RuntimeError: numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64` (when running `buck2 run mode/{opt,inplace} //pytorch/benchmark:triton -- --op jagged_sum --B 1024 --M 1024 --sparsity 0.3`)
- `torch.OutOfMemoryError: CUDA out of memory.`

Reviewed By: davidberard98

Differential Revision: D58772201
1 parent 53faa0a commit 728540d
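The fix-and-sweep behavior can be sketched in a few lines (hypothetical values; the real grid is built by `get_x_vals` in the diff below): each parameter supplied on the command line collapses to a single-element list, and `itertools.product` then enumerates every combination of whatever is left to sweep.

```python
import itertools

# Hypothetical ranges; the actual sweep ranges are defined in get_x_vals below.
B_vals = [1024]                 # fixed via --B 1024
M_vals = [256, 512, 1024]       # swept because --M was not provided
seqlen_vals = [100, 500, 1000]  # swept because --seqlen was not provided
sparsity_vals = [0.3]           # fixed via --sparsity 0.3

for B, M, seqlen, sparsity in itertools.product(
    B_vals, M_vals, seqlen_vals, sparsity_vals
):
    print(B, M, seqlen, sparsity)  # 1 * 3 * 3 * 1 = 9 input configurations
```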

2 files changed: +83 −44 lines changed

torchbenchmark/operators/jagged_sum/kernels.py

Lines changed: 6 additions & 2 deletions
@@ -59,7 +59,9 @@ def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end

@@ -132,7 +134,9 @@ def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end

torchbenchmark/operators/jagged_sum/operator.py

Lines changed: 77 additions & 42 deletions
@@ -31,17 +31,25 @@
 
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--B",
+        type=int,
+        help="[Optional] Size of dimension 0 in shape (B, *, M) (integer)",
+    )
+    parser.add_argument(
+        "--M",
+        type=int,
+        help="[Optional] Size of dimension 2 in shape (B, *, M) (integer)",
+    )
     parser.add_argument(
         "--seqlen",
         type=int,
-        default=500,
-        help="Maximum sequence length on ragged dimension (integer)",
+        help="[Optional] Maximum sequence length on ragged dimension (integer)",
     )
     parser.add_argument(
         "--sparsity",
         type=float,
-        default=0.5,
-        help="Average sparsity for nested tensor (float, (0.0-1.0))",
+        help="[Optional] Average sparsity for nested tensor (float, (0.0-1.0))",
     )
     parser.add_argument(
         "--sum-then-buffer",
@@ -91,12 +99,16 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non
         ) # bias towards larger sizes, which are more representative of real-world shapes
 
         args = parse_op_args(self.extra_args)
-        self.seqlen = args.seqlen
-        self.sparsity = args.sparsity
+        self.B = args.B if args.B is not None else None
+        self.M = args.M if args.M is not None else None
+        self.seqlen = args.seqlen if args.seqlen is not None else None
+        self.sparsity = args.sparsity if args.sparsity is not None else None
         self.sum_then_buffer = args.sum_then_buffer
 
     @register_benchmark(baseline=True)
-    def torch_jagged_sum_no_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.tensor(
             [
                 torch.sum(t, dim=0).tolist() for t in x.unbind()
@@ -106,66 +118,87 @@ def torch_jagged_sum_no_pad(self, x: torch.Tensor):
         )
 
     @register_benchmark()
-    def torch_jagged_sum_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.sum(
             torch.ops.aten._jagged_to_padded_dense_forward(
                 x.values(),
                 [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
-                max_lengths=[self.seqlen],  # max length of ragged dimension
+                max_lengths=[seqlen],  # max length of ragged dimension
             ),
             dim=1,
         )  # sum along ragged dimension (dim == 1)
 
     @register_benchmark()
-    def triton_jagged_sum_no_pad(self, x: torch.Tensor):
+    def triton_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         def _inner():
-            return execute_kernel_simple_fused(x, self.seqlen, self.sum_then_buffer)
+            return execute_kernel_simple_fused(x, seqlen, self.sum_then_buffer)
 
         return _inner
 
     def get_x_val(self, example_inputs):
         return len(example_inputs[0])
 
-    def get_x_vals(self) -> Tuple[List[int], List[int]]:
-        B_vals, M_vals = [], []
-
-        B_vals.extend([2**n for n in self.sizes])
-        B_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in B_vals
-            ]
-        )
+    def get_x_vals(self) -> Tuple[List[int], List[int], List[int], List[float]]:
+        B_vals, M_vals, seqlen_vals, sparsity_vals = [], [], [], []
+
+        def get_dim_vals():
+            vals = []
+            vals.extend([2**n for n in self.sizes])
+            vals.extend(
+                [
+                    (n - 1) * (n + 1)
+                    for n in self.sizes
+                    if n - 1 > 0 and (n - 1) * (n + 1) not in vals
+                ]
+            )
+            return vals
+
+        if self.B is None:
+            B_vals.extend(get_dim_vals())
+        else:
+            B_vals.extend([self.B])
+
+        if self.M is None:
+            M_vals.extend(get_dim_vals())
+        else:
+            M_vals.extend([self.M])
+
+        if self.seqlen is None:
+            seqlen_vals.extend(
+                list(range(100, 1000, 100))
+                + list(range(1000, 10000, 1000))
+            )
+        else:
+            seqlen_vals.extend([self.seqlen])
 
-        M_vals.extend([2**n for n in self.sizes])
-        M_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in M_vals
-            ]
-        )
+        if self.sparsity is None:
+            sparsity_vals.extend([n / 10 for n in range(1, 10)])
+        else:
+            sparsity_vals.extend([self.sparsity])
 
-        return B_vals, M_vals
+        return B_vals, M_vals, seqlen_vals, sparsity_vals
 
     def get_input_iter(self) -> Generator:
         """
         Generate random nested tensors of shape (B, *, M), where * is the ragged dimension
         """
 
-        B_vals, M_vals = self.get_x_vals()
-        B_M_vals = itertools.product(B_vals, M_vals)
+        B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
+        vals = itertools.product(B_vals, M_vals, seqlen_vals, sparsity_vals)
 
-        for B, M in B_M_vals:
+        for B, M, seqlen, sparsity in vals:
             tensors = []
 
             # greater sparsity --> shorter sequence lengths on ragged dimension
             seqlen_avg = math.floor(
-                self.seqlen * (1 - self.sparsity)
+                seqlen * (1 - sparsity)
             )  # average sequence length across all tensors in nested tensor
             seqlen_margin = math.floor(
-                self.seqlen * RANDOM_CHOICE_MARGIN
+                seqlen * RANDOM_CHOICE_MARGIN
             )  # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
 
             for _ in range(B):
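The sparsity-to-length arithmetic above is straightforward: with, say, `seqlen=500` and `sparsity=0.3`, the average row length is `floor(500 * 0.7) = 350`, and each row's actual length is drawn uniformly from a margin around that average. A hedged sketch of the sampling (the value of `RANDOM_CHOICE_MARGIN` below is assumed for illustration; the real constant is defined elsewhere in `operator.py`):

```python
import math
import random

RANDOM_CHOICE_MARGIN = 0.3  # assumed value; the actual constant lives in operator.py

seqlen, sparsity = 500, 0.3
seqlen_avg = math.floor(seqlen * (1 - sparsity))           # 350: target mean row length
seqlen_margin = math.floor(seqlen * RANDOM_CHOICE_MARGIN)  # 150: allowed spread

seqlen_randint = random.randint(
    max(seqlen_avg - seqlen_margin, 1),       # lower bound, at least 1
    min(seqlen_avg + seqlen_margin, seqlen),  # upper bound, capped at seqlen
)  # uniform in [200, 500] for these example values
```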
@@ -174,7 +207,7 @@ def get_input_iter(self) -> Generator:
                     seqlen_avg - seqlen_margin, 1
                 ),  # seqlen_randint must be at least 1
                 min(
-                    seqlen_avg + seqlen_margin, self.seqlen
+                    seqlen_avg + seqlen_margin, seqlen
                 ),  # seqlen_randint must not exceed self.seqlen
             )
             tensor_2d = torch.randn(
@@ -189,7 +222,7 @@ def get_input_iter(self) -> Generator:
                 dtype=self.dtype,
             )
 
-            yield (nt,)
+            yield (nt, B, M, seqlen, sparsity)
 
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
@@ -205,15 +238,17 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             * GIGABYTES_PER_BYTE
         )
 
-    @register_metric(x_only=True)
+    @register_metric(x_only=True)  # TODO modify!!!!
     def input_shape(
         self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
     ):
         return (
-            example_inputs[0].shape[0],
+            f"B: {example_inputs[1]}",  # B
             "*",
-            example_inputs[0].shape[2],
-        )  # return (B, '*', M) for each example input
+            f"M: {example_inputs[2]}",  # M
+            f"max seqlen: {example_inputs[3]}",  # seqlen
+            f"sparsity: {example_inputs[4]}",  # sparsity
+        )  # return (B, '*', M, max seqlen, sparsity) for each example input
 
     @register_metric(skip_baseline=True)
     def best_config(
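With the extra tuple elements now flowing through `example_inputs`, the `input_shape` metric labels each configuration explicitly instead of reading `.shape` off the nested tensor. For instance, for the failing configuration from the summary (`--B 1024 --M 1024 --sparsity 0.3`) at a hypothetical swept seqlen of 500, the metric would report:

```python
# example_inputs = (nt, 1024, 1024, 500, 0.3) would yield:
("B: 1024", "*", "M: 1024", "max seqlen: 500", "sparsity: 0.3")
```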
