Commit a5f1710

int3 authored and facebook-github-bot committed
Add CUTLASS + PT2-Triton kernels to gemm benchmark
Summary: This works by simply setting the max_autotune GEMM backend to only CUTLASS or TRITON, as needed. I also modified the baseline benchmark to explicitly disable autotuning, so that we can be more confident that it is invoking the ATen kernel.

Reviewed By: bertmaher, xuzhao9, chenyang78

Differential Revision: D56685216

fbshipit-source-id: 1638266254690b929f8c5591a194127c6a7c7be8
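The baseline change the summary mentions is not part of the diff shown below. As a rough sketch of the mechanism, assuming the baseline also runs under torch.compile (the config flag is a real torch._inductor.config option; the wrapper function itself is hypothetical and not from this commit):

import torch
import torch._inductor.config as inductor_config

# Hypothetical sketch, not from this commit: compile the baseline with
# autotuning disabled, so Inductor lowers the matmul to the ATen kernel
# rather than to an autotuned Triton/CUTLASS template.
def aten_matmul_baseline(a, b):
    torch._dynamo.reset()
    with inductor_config.patch(max_autotune=False):
        compiled = torch.compile(lambda x, y: x.matmul(y), dynamic=False)
        compiled(a, b)  # warm-up call compiles under the patched config
    return lambda: compiled(a, b)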
1 parent d21607d commit a5f1710

File tree

1 file changed (+34, -0)


torchbenchmark/operators/gemm/operator.py

Lines changed: 34 additions & 0 deletions
@@ -20,6 +20,7 @@
 from .data_io import parse_args, read_shapes_from_csv
 from .triton_matmul import matmul as triton_matmul
 from .triton_matmul import matmul_kernel as triton_matmul_kernel
+import torch._inductor.config as inductor_config
 
 import inspect
 try:
@@ -128,6 +129,39 @@ def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: colfax_gemm(a, b, alpha=1.0, beta=1.0)
 
+    @register_benchmark()
+    def pt2_triton_matmul(self, a, b, bias) -> Callable:
+        torch._dynamo.reset()
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+            autotune_fallback_to_aten=False,
+        ):
+            if bias is not None:
+                f = lambda a, b: a.matmul(b) + bias
+            else:
+                f = lambda a, b: a.matmul(b)
+            compiled = torch.compile(f, dynamic=False)
+            compiled(a, b)
+            return lambda: compiled(a, b)
+
+    @register_benchmark()
+    def pt2_cutlass_matmul(self, a, b, bias) -> Callable:
+        torch._dynamo.reset()
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="CUTLASS",
+            autotune_fallback_to_aten=False,
+        ):
+            if bias is not None:
+                f = lambda a, b: a.matmul(b) + bias
+            else:
+                f = lambda a, b: a.matmul(b)
+            # cutlass needs to know the static shape, so set dynamic to False
+            compiled = torch.compile(f, dynamic=False)
+            compiled(a, b)
+            return lambda: compiled(a, b)
+
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
         # x-value: computation intensity
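Outside the harness, the same backend-pinning pattern can be exercised standalone. A minimal sketch, assuming a CUDA device with Triton available (the shapes and dtype are arbitrary; autotune_fallback_to_aten is taken from the diff above and its availability depends on the PyTorch version):

import torch
import torch._inductor.config as inductor_config

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

torch._dynamo.reset()
with inductor_config.patch(
    max_autotune=True,
    max_autotune_gemm_backends="TRITON",  # or "CUTLASS"
    autotune_fallback_to_aten=False,  # error out rather than silently using ATen
):
    # The first call compiles and autotunes over the pinned backend only.
    compiled = torch.compile(lambda x, y: x.matmul(y), dynamic=False)
    compiled(a, b)

# Steady-state callable, mirroring what the new benchmark methods return.
run = lambda: compiled(a, b)
run()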
