Skip to content

Commit 1708afa

Browse files
manman-ren and facebook-github-bot
authored and committed
template attention from PT2
Summary: Based on Inductor-generated code, but modified to use Triton's tuning. PyTorch GitHub: pytorch/pytorch#124369. The base variant is prior to OSS pytorch/pytorch#124356; that PR improves performance for template attention. The second variant is after the PR. Reviewed By: bertmaher. Differential Revision: D56372010. fbshipit-source-id: 4439113a92fd41b81269af1227deaf5ec52c65dc
1 parent 02d3328 commit 1708afa

File tree

3 files changed

+439
-0
lines changed

3 files changed

+439
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .operator import Operator
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
2+
import csv
3+
import os
4+
import statistics
5+
from typing import Any, Callable, Generator, List, Optional
6+
7+
import numpy
8+
import torch
9+
import triton
10+
11+
12+
from torchbenchmark.util.triton_op import (
13+
BenchmarkOperator,
14+
BenchmarkOperatorMetrics,
15+
register_benchmark,
16+
register_metric,
17+
)
18+
19+
from .triton_attention import triton_attention_no_exp2 as triton_test_no_exp2
20+
from .triton_attention import triton_attention_with_exp2 as triton_test_with_exp2
21+
from torch._dynamo.testing import rand_strided
22+
23+
24+
# Input shapes the benchmark iterates over. Each entry is unpacked by
# Operator.get_input_iter as (batch_size, num_heads, num_queries, m).
BUILDIN_SHAPES = [
    (16, 16, 4096, 64),
]
27+
28+
29+
class Operator(BenchmarkOperator):
    """Benchmark two Triton template-attention kernels against each other.

    ``test_no_exp2`` is registered as the baseline and ``test_with_exp2`` is
    the variant compared with it. Both receive the same three float16 input
    tensors produced by :meth:`get_input_iter`.
    """

    DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
        """Initialise the base operator and select the built-in shapes.

        Args:
            mode: Forwarded unchanged to ``BenchmarkOperator``.
            device: Forwarded unchanged to ``BenchmarkOperator``.
            extra_args: Extra command-line style arguments forwarded to
                ``BenchmarkOperator``; ``None`` is treated as an empty list.
        """
        # Build a fresh list per instance instead of sharing one mutable
        # default object across every call.
        super().__init__(
            mode=mode,
            device=device,
            extra_args=[] if extra_args is None else extra_args,
        )
        self.shapes = BUILDIN_SHAPES

    @register_benchmark(baseline=True)
    def test_no_exp2(self, p1, p2, p3) -> Callable:
        """Return a zero-argument callable running the no-exp2 kernel."""
        return lambda: triton_test_no_exp2(p1, p2, p3)

    @register_benchmark()
    def test_with_exp2(self, p1, p2, p3) -> Callable:
        """Return a zero-argument callable running the with-exp2 kernel."""
        return lambda: triton_test_with_exp2(p1, p2, p3)

    def get_x_val(self, example_inputs) -> float:
        """Return the x-axis value for one input set.

        The value is the third dimension (``num_queries``) of the third
        input tensor. ``example_inputs`` must contain exactly three tensors
        and the third must be 4-dimensional.
        """
        _, _, p3 = example_inputs
        _, _, num_queries, _ = p3.size()
        return num_queries

    def get_input_iter(self) -> Generator:
        """Yield one tuple of three random float16 tensors per shape.

        Each tensor has size ``shape`` and the contiguous (row-major) strides
        for that size, so every entry of ``self.shapes`` is honoured rather
        than only one hard-coded size.
        """
        for shape in self.shapes:
            batch_size, num_heads, num_queries, m = shape
            # Contiguous strides; for (16, 16, 4096, 64) this evaluates to
            # (4194304, 262144, 64, 1).
            strides = (num_heads * num_queries * m, num_queries * m, m, 1)
            # NOTE(review): the device is pinned to 'cuda:0' and the `device`
            # constructor argument is not consulted here - confirm intended.
            arg0_1 = rand_strided(shape, strides, device='cuda:0', dtype=torch.float16)
            arg1_1 = rand_strided(shape, strides, device='cuda:0', dtype=torch.float16)
            arg2_1 = rand_strided(shape, strides, device='cuda:0', dtype=torch.float16)
            yield arg0_1, arg1_1, arg2_1

    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
        """Return whether ``fn()`` and ``baseline_fn()`` outputs are close.

        Both callables are invoked once and their results are compared with
        ``torch.allclose`` using its default tolerances.
        """
        # NOTE(review): the default tolerances may be strict for float16
        # outputs - confirm whether explicit rtol/atol are wanted.
        output = fn()
        baseline_output = baseline_fn()
        return torch.allclose(output, baseline_output)
61+

0 commit comments

Comments
 (0)