
Commit e53141f

exclamaforte authored and facebook-github-bot committed
Enable max autotune for AOTInductor benchmark (#149309)
Summary: With this PR, AOTInductor can choose to run in max-autotune mode when benchmarking.

X-link: pytorch/pytorch#149309
Approved by: https://github.com/desertfire
Reviewed By: wdvr
Differential Revision: D73784003
fbshipit-source-id: 106380f3b974cb537d66f55276c140e38f47ede3
Co-authored-by: Gabriel Ferns <[email protected]>
1 parent 34f7cc0 commit e53141f
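
For context, the change threads the benchmark's compile-mode string down into AOTInductor's config dict. Below is a minimal sketch of that compile path outside the harness; the toy module M, example_args, and the shapes are placeholders rather than benchmark code, while the inductor_configs dict mirrors the diff below:

import torch
import torch._inductor

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

model = M().eval()
example_args = (torch.randn(8, 64),)

# Export, then compile ahead of time with Inductor max-autotune enabled,
# mirroring what AOTInductorModelCache.load does in the diff below.
ep = torch.export.export(model, example_args, strict=False)
package_path = torch._inductor.aoti_compile_and_package(
    ep, inductor_configs={"max_autotune": True}
)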

File tree

1 file changed: +15 -6 lines changed

userbenchmark/dynamo/dynamobench/common.py

Lines changed: 15 additions & 6 deletions
@@ -1050,7 +1050,9 @@ def maybe_mark_profile(*args, **kwargs):

     with maybe_profile(args.export_profiler_trace, **args.profile_details) as p:
         if args.export_aot_inductor:
-            frozen_model_iter_fn = export_aot_inductor(model, example_inputs)
+            frozen_model_iter_fn = export_aot_inductor(
+                model, example_inputs, args.inductor_compile_mode
+            )
         else:
             frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)

@@ -1384,7 +1386,7 @@ class AOTInductorModelCache:
     cache = {}

     @classmethod
-    def load(cls, model, example_inputs):
+    def load(cls, model, example_inputs, mode):
         import torch._inductor
         import torch.export._trace
         from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
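
The load classmethod sits behind a class-level dict, so each model/input combination is compiled at most once per process. A stripped-down sketch of that memoization pattern, with hypothetical names (CompileCache, compile_fn) standing in for the real ones:

class CompileCache:
    cache = {}

    @classmethod
    def load(cls, key, compile_fn):
        # Compile on the first request for a key, then reuse the artifact.
        if key not in cls.cache:
            cls.cache[key] = compile_fn()
        return cls.cache[key]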
@@ -1422,6 +1424,9 @@ def load(cls, model, example_inputs):
         elif current_device == "hpu":
             torch.hpu.reset_peak_memory_stats()

+        inductor_configs = {}
+        if mode == "max-autotune":
+            inductor_configs["max_autotune"] = True
         ep = torch.export.export(
             model,
             example_args,
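
The mapping above is deliberately narrow: only the "max-autotune" mode string flips a config, and it corresponds roughly to the autotuning that the JIT path enables through torch.compile's mode argument (which may also toggle extra options such as CUDA graphs). A rough equivalence, assuming a CUDA-capable setup and placeholder model/input:

import torch

model = torch.nn.Linear(64, 64).cuda()
x = torch.randn(8, 64, device="cuda")

# JIT Inductor: mode="max-autotune" turns on the autotuning that this
# commit enables for the AOT path via inductor_configs["max_autotune"].
compiled = torch.compile(model, mode="max-autotune")
out = compiled(x)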
@@ -1430,7 +1435,9 @@ def load(cls, model, example_inputs):
             strict=False,
         )
         with torch.no_grad():
-            package_path = torch._inductor.aoti_compile_and_package(ep)  # type: ignore[arg-type]
+            package_path = torch._inductor.aoti_compile_and_package(
+                ep, inductor_configs=inductor_configs
+            )  # type: ignore[arg-type]

         cls.cache[key] = torch._inductor.aoti_load_package(package_path)

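Once compiled, the package is loaded back with aoti_load_package, and the loaded object is directly callable. A minimal sketch, assuming "model.pt2" is a package produced by aoti_compile_and_package and that x matches the exported input signature:

import torch
import torch._inductor

compiled = torch._inductor.aoti_load_package("model.pt2")
x = torch.randn(8, 64)  # assumed to match the exported example inputs
out = compiled(x)
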
@@ -1460,8 +1467,8 @@ def opt_export(_, example_inputs):
     return opt_export


-def export_aot_inductor(model, example_inputs):
-    optimized = AOTInductorModelCache.load(model, example_inputs)
+def export_aot_inductor(model, example_inputs, mode):
+    optimized = AOTInductorModelCache.load(model, example_inputs, mode)

     def opt_aot_inductor(_, example_inputs, collect_outputs=False):
         example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
@@ -3752,7 +3759,9 @@ def run(runner, args, original_dir=None):
     elif args.backend or args.export_aot_inductor:
         if args.export_aot_inductor:
             assert not args.training, "AOTInductor only supports inference"
-            optimize_ctx = functools.partial(export_aot_inductor)
+            optimize_ctx = functools.partial(
+                export_aot_inductor, mode=args.inductor_compile_mode
+            )

         # AOTInductor doesn't support control flow yet
         runner.skip_models.update(runner.skip_models_due_to_control_flow)
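
The functools.partial call pre-binds only the mode keyword, so later call sites supply just the model and example inputs. A small illustration with a stub function (the stub body, the "resnet50" string, and the empty inputs tuple are placeholders, assuming the harness later invokes optimize_ctx with the remaining arguments):

import functools

def export_aot_inductor(model, example_inputs, mode):
    # Stand-in for the real function above; just report what it received.
    return f"compiling {model} with mode={mode}"

optimize_ctx = functools.partial(export_aot_inductor, mode="max-autotune")

# Callers pass only the remaining arguments; mode is already bound.
print(optimize_ctx("resnet50", example_inputs=()))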
