"""
Compute a bf16 (activation) x int4 (weight) GEMM.
Inspired by [gpt-fast](https://github.com/pytorch-labs/gpt-fast).
The ATen baseline uses kernels from tinygemm.
Triton implementation by @jlebar: https://gist.github.com/jlebar/3435b2c00deea53258887ce37231e5e2
"""
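
# Rough sketch of the computation being benchmarked (shapes follow
# `get_input_iter` below; `pack_2xint4` and `matmul` come from the sibling
# kernel module, so their exact packed layouts are assumed here, not verified):
#
#   x                : (B, L, K) bf16 activations, flattened to (B * L, K)
#   w                : (K, N) int4 values stored one per int32, packed before
#                      the matmul (tinygemm's int4pack layout for ATen,
#                      two values per int8 for Triton)
#   scales_and_zeros : (K // group_size, N, 2) per-group scale/zero pairs
#
#   y = x @ dequantize(w, scales_and_zeros)  ->  (B * L, N) bf16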

import os
from typing import Any, List, Tuple

import torch
import triton

from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
)

from .kernel import pack_2xint4, matmul, matmul_kernel


class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["tflops", "gbps", "latency"]

    def __init__(self, mode, device, extra_args):
        super().__init__(mode=mode, device=device, extra_args=extra_args)
        # `group_size` and `inner_k_tiles` defaults follow gpt-fast.
        self.group_size = 32
        self.inner_k_tiles = 8

    def get_input_iter(self):
        def args(B, L, Dout, Din):
            x = torch.randn(B, L, Din, device=self.device, dtype=torch.bfloat16)
            # Signed int4 values in [-8, 7] (randint's upper bound is exclusive),
            # stored unpacked as one value per int32.
            w = torch.randint(-8, 8, (Din, Dout), device=self.device, dtype=torch.int32)
            # One (scale, zero) pair per quantization group and output channel.
            scales_and_zeros = torch.randn(
                Din // self.group_size,
                Dout,
                2,
                device=self.device,
                dtype=torch.bfloat16,
            )
            return (x, w, scales_and_zeros)

        # Llama-2 70B shapes with 8-way tensor parallelism.
        name_to_shapes_70b = {
            "attn.wqkv": (8192, 1280),
            "attn.w0": (1024, 8192),
            "ffn.w13": (8192, 7168),
            "ffn.w2": (3584, 8192),
        }
        for seq_len in (1, 4096):
            for bsz in (1, 4, 16, 64):
                for name, (k, n) in name_to_shapes_70b.items():
                    yield args(bsz, seq_len, n, k)

    def get_x_val(self, example_inputs) -> Tuple[int, int, int, int]:
        x, w, scales_and_zeros = example_inputs
        B, m, k = x.size()
        _, n = w.size()
        return (B, m, n, k)

    @register_benchmark(baseline=True)
    def tinygemm(self, x, w, scales_and_zeros):
        x = x.reshape(-1, x.size(-1))
        # Convert the weight to tinygemm's packed int4 layout up front so that
        # only the matmul itself is timed.
        w_int4 = torch.ops.aten._convert_weight_to_int4pack(
            w.T.contiguous(), self.inner_k_tiles
        )
        return lambda: torch.ops.aten._weight_int4pack_mm(
            x, w_int4, self.group_size, scales_and_zeros
        )

    @register_benchmark()
    def triton(self, x, w, scales_and_zeros):
        x = x.reshape(-1, x.size(-1))
        # Pack two int4 values into each int8, then force a column-major layout
        # while keeping the logical shape (the `.T.contiguous().T` idiom).
        w_int4 = pack_2xint4(w).T.contiguous().T
        return lambda: matmul(x, w_int4)

    @register_metric()
    def best_config(self, fn, inputs, metrics):
        # Report the autotuned config only for the Triton kernel; the ATen
        # baseline has no Triton autotuner.
        if "triton" in str(fn):
            return str(matmul_kernel.best_config)
        return ""

    @register_metric()
    def gbps(
        self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics
    ) -> List[float]:
        def nbytes(t):
            return t.numel() * t.element_size()

        x, w, scale_and_zero = example_inputs
        c = fn()

        # `w` holds int4 values in int32 storage, so its real memory traffic is
        # nbytes(w) // 8 (half a byte per element).
        gb = (sum(nbytes(t) for t in (x, scale_and_zero, c)) + nbytes(w) // 8) / 1e9
        return [gb / ms * 1e3 for ms in metrics.latency]

    @register_metric()
    def tflops(
        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
    ) -> List[float]:
        a, b, _ = example_inputs
        B, m, k = a.size()
        m = B * m
        _, n = b.size()
        # 2 * M * N * K flops for a matmul (one multiply and one add per MAC).
        flops = 2 * m * n * k
        return [flops / ms / 1e12 * 1e3 for ms in metrics.latency]

    def plot(self):
        @triton.testing.perf_report(
            triton.testing.Benchmark(
                x_names=[
                    "B",
                    "m",
                    "n",
                    "k",
                ],  # argument names to use as an x-axis for the plot
                x_vals=self.output.x_vals,  # different possible values for `x_names`
                line_arg="provider",  # argument name whose value corresponds to a different line in the plot
                line_vals=[
                    "tinygemm",
                    "triton",
                ],  # possible values for `line_arg`
                line_names=[
                    "tinygemm",
                    "triton",
                ],  # label names for the lines
                styles=[("blue", "-"), ("green", "-")],
                ylabel="tflops",  # label name for the y-axis
                plot_name="int4-gemm-performance",  # name for the plot; also used as the file name when saving
                args={},  # values for function arguments not in `x_names` and `y_name`
            )
        )
        def _plot(B, m, n, k, provider):
            tflops = self.output.get_y_vals((B, m, n, k), provider, "tflops")
            return tflops

        save_path = "/tmp/int4_gemm"
        os.makedirs(save_path, exist_ok=True)

        _plot.run(show_plots=True, print_data=True, save_path=save_path)