You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add jagged_sum operator for padded nested tensors to TritonBench (#2305)
Summary:
Pull Request resolved: #2305
Add a `jagged_sum` reduction operator for padded nested tensors, based on the PyTorch `sum` operator, to TritonBench. This diff uses the PyTorch function [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), hosted at this [GitHub pull request](pytorch/pytorch#125968), to pad each 2-dimensional tensor in a nested tensor of shape `(B, *, M)`, then reduce across the `N`-th dimension (`dim == 1`) to a `(B, M)` output tensor.
Measure accuracy of padded implementation against unpadded baseline implementation via `accuracy` TritonBench metric.
Reviewed By: davidberard98
Differential Revision: D58423489
fbshipit-source-id: d2f6095f8af1cb188bb979e2f5605ad80db50a46
# greater sparsity --> shorter sequence lengths on ragged dimension
99
-
seqlen_avg=math.floor(
100
-
self.seqlen* (1-self.sparsity)
101
-
) # average sequence length across all tensors in nested tensor
102
-
seqlen_margin=math.floor(
103
-
self.seqlen*RANDOM_CHOICE_MARGIN
104
-
) # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
105
-
106
-
for_inrange(B):
107
-
seqlen_randint=random.randint(
108
-
max(seqlen_avg-seqlen_margin, 1),
109
-
min(seqlen_avg+seqlen_margin, self.seqlen),
110
-
)
111
-
tensor_2d=torch.randn(
112
-
(seqlen_randint, M), device=self.device, dtype=self.dtype
113
-
)
114
-
tensors.append(tensor_2d)
115
-
116
-
nt=torch.nested.nested_tensor(
117
-
tensors,
118
-
layout=torch.jagged,
119
-
device=self.device,
120
-
dtype=self.dtype,
110
+
B_M_vals=itertools.product(B_vals, M_vals)
111
+
112
+
forB, MinB_M_vals:
113
+
tensors= []
114
+
115
+
# greater sparsity --> shorter sequence lengths on ragged dimension
116
+
seqlen_avg=math.floor(
117
+
self.seqlen* (1-self.sparsity)
118
+
) # average sequence length across all tensors in nested tensor
119
+
seqlen_margin=math.floor(
120
+
self.seqlen*RANDOM_CHOICE_MARGIN
121
+
) # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
122
+
123
+
for_inrange(B):
124
+
seqlen_randint=random.randint(
125
+
max(
126
+
seqlen_avg-seqlen_margin, 1
127
+
), # seqlen_randint must be at least 1
128
+
min(
129
+
seqlen_avg+seqlen_margin, self.seqlen
130
+
), # seqlen_randint must not exceed self.seqlen
121
131
)
132
+
tensor_2d=torch.randn(
133
+
(seqlen_randint, M), device=self.device, dtype=self.dtype
0 commit comments