Skip to content

Commit b7524a6

Browse files
xuzhao9 and facebook-github-bot
authored and committed
Add torchao to PT2 Benchmark Runner
Summary: X-link: #2268 Support torchao performance and accuracy tests in PT2 Benchmark Runner, using the inductor backend as the baseline. X-link: pytorch/pytorch#126469 Reviewed By: jerryzh168 Differential Revision: D57463273 Pulled By: xuzhao9 fbshipit-source-id: 64520f18b63107ce5f07447ef7f4a8c841d9ff1f
1 parent ebbc77b commit b7524a6

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

userbenchmark/dynamo/dynamobench/common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3485,6 +3485,18 @@ def get_example_inputs(self):
34853485
action="store_true",
34863486
help="Measure speedup with TorchInductor",
34873487
)
3488+
group.add_argument(
3489+
"--quantization",
3490+
choices=[
3491+
"int8dynamic",
3492+
"int8weightonly",
3493+
"int4weightonly",
3494+
"autoquant",
3495+
"noquant",
3496+
],
3497+
default=None,
3498+
help="Measure speedup of torchao quantization with TorchInductor baseline",
3499+
)
34883500
group.add_argument(
34893501
"--export",
34903502
action="store_true",
@@ -3679,6 +3691,9 @@ def run(runner, args, original_dir=None):
36793691
if args.inductor:
36803692
assert args.backend is None
36813693
args.backend = "inductor"
3694+
if args.quantization:
3695+
assert args.backend is None
3696+
args.backend = "torchao"
36823697
if args.dynamic_batch_only:
36833698
args.dynamic_shapes = True
36843699
torch._dynamo.config.assume_static_by_default = True
@@ -3957,6 +3972,20 @@ def run(runner, args, original_dir=None):
39573972

39583973
# AOTInductor doesn't support control flow yet
39593974
runner.skip_models.update(runner.skip_models_due_to_control_flow)
3975+
elif args.backend == "torchao":
3976+
assert "cuda" in args.devices, "Quantization requires CUDA device."
3977+
assert args.bfloat16, "Quantization requires dtype bfloat16."
3978+
from .torchao import setup_baseline, torchao_optimize_ctx
3979+
3980+
setup_baseline()
3981+
baseline_ctx = functools.partial(
3982+
torch.compile,
3983+
backend="inductor",
3984+
fullgraph=args.nopython,
3985+
mode=args.inductor_compile_mode,
3986+
)
3987+
runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
3988+
optimize_ctx = torchao_optimize_ctx(args.quantization)
39603989
else:
39613990
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
39623991
experiment = speedup_experiment

0 commit comments

Comments (0)