Skip to content

Commit 23c6648

Browse files
bertmaherfacebook-github-bot
authored andcommitted
Add some split-k shapes to the benchmark harness
Summary: Available via the arg `--splitk` although I'm not sure if that's the best way to do things... Reviewed By: xuzhao9, chenyang78 Differential Revision: D56197655 fbshipit-source-id: d30303b7402faa057722197481e185580199aba0
1 parent a0665be commit 23c6648

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

torchbenchmark/operators/gemm/data_io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
1212
parser.add_argument("--k", default=8, type=int)
1313
parser.add_argument("--n", default=8, type=int)
1414
parser.add_argument("--input", default=None, type=str)
15+
parser.add_argument("--splitk", action="store_true", default=False)
1516
args = parser.parse_args(args)
1617
return args
1718

torchbenchmark/operators/gemm/operator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@
6767
(4096, 4096, 4096, None),
6868
]
6969

70+
SPLIT_K_SHAPES = [
71+
(m, k, m, None)
72+
for m in [128 * i for i in range(1, 5)]
73+
for k in [2048 * i for i in range(1, 9)]
74+
]
7075

7176
class Operator(BenchmarkOperator):
7277
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
@@ -81,6 +86,8 @@ def __init__(self, mode: str, device: str, extra_args: List[str] = []):
8186
self.tbargs = parse_args(self.extra_args)
8287
if self.tbargs.input:
8388
self.shapes = read_shapes_from_csv(self.tbargs.input)
89+
elif self.tbargs.splitk:
90+
self.shapes = SPLIT_K_SHAPES
8491
else:
8592
self.shapes = [(self.tb_args.m, self.tbargs.k, self.tbargs.n)]
8693
self.DEFAULT_NUM_BATCH = len(self.shapes)

0 commit comments

Comments
 (0)