|
20 | 20 | from .data_io import parse_args, read_shapes_from_csv
|
21 | 21 | from .triton_matmul import matmul as triton_matmul
|
22 | 22 | from .triton_matmul import matmul_kernel as triton_matmul_kernel
|
| 23 | +import torch._inductor.config as inductor_config |
23 | 24 |
|
24 | 25 | import inspect
|
25 | 26 | try:
|
@@ -128,6 +129,39 @@ def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
|
128 | 129 | else:
|
129 | 130 | return lambda: colfax_gemm(a, b, alpha=1.0, beta=1.0)
|
130 | 131 |
|
@register_benchmark()
def pt2_triton_matmul(self, a, b, bias) -> Callable:
    """Return a benchmark callable for a ``torch.compile``-d matmul (TRITON).

    The function being compiled computes ``a.matmul(b)``, adding ``bias``
    when it is not ``None``. Dynamo state is reset first, then the function
    is compiled with ``dynamic=False`` and invoked once while the Inductor
    config is patched with max-autotune enabled, the GEMM backend list set
    to ``"TRITON"`` and the ATen fallback disabled.

    Returns:
        A zero-argument callable that re-runs the compiled function on the
        same ``a`` and ``b``.
    """
    torch._dynamo.reset()
    autotune_opts = dict(
        max_autotune=True,
        max_autotune_gemm_backends="TRITON",
        autotune_fallback_to_aten=False,
    )
    with inductor_config.patch(**autotune_opts):
        if bias is None:
            gemm = lambda lhs, rhs: lhs.matmul(rhs)
        else:
            gemm = lambda lhs, rhs: lhs.matmul(rhs) + bias
        compiled_gemm = torch.compile(gemm, dynamic=False)
        # One call inside the patch; presumably this is what triggers
        # compilation/autotuning under the patched options (confirm
        # against torch.compile's lazy-compilation behaviour).
        compiled_gemm(a, b)
    return lambda: compiled_gemm(a, b)
| 147 | + |
@register_benchmark()
def pt2_cutlass_matmul(self, a, b, bias) -> Callable:
    """Return a benchmark callable for a ``torch.compile``-d matmul (CUTLASS).

    The function being compiled computes ``a.matmul(b)``, adding ``bias``
    when it is not ``None``. Dynamo state is reset first, then the function
    is compiled with ``dynamic=False`` and invoked once while the Inductor
    config is patched with max-autotune enabled, the GEMM backend list set
    to ``"CUTLASS"`` and the ATen fallback disabled.

    Returns:
        A zero-argument callable that re-runs the compiled function on the
        same ``a`` and ``b``.
    """
    torch._dynamo.reset()
    autotune_opts = dict(
        max_autotune=True,
        max_autotune_gemm_backends="CUTLASS",
        autotune_fallback_to_aten=False,
    )
    with inductor_config.patch(**autotune_opts):
        if bias is None:
            gemm = lambda lhs, rhs: lhs.matmul(rhs)
        else:
            gemm = lambda lhs, rhs: lhs.matmul(rhs) + bias
        # cutlass needs to know the static shape, so set dynamic to False
        compiled_gemm = torch.compile(gemm, dynamic=False)
        # One call inside the patch; presumably this is what triggers
        # compilation/autotuning under the patched options (confirm
        # against torch.compile's lazy-compilation behaviour).
        compiled_gemm(a, b)
    return lambda: compiled_gemm(a, b)
| 164 | + |
131 | 165 | @register_x_val(label="(M, N, K)")
|
132 | 166 | def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
|
133 | 167 | # x-value: computation intensity
|
|
0 commit comments