@@ -1050,7 +1050,9 @@ def maybe_mark_profile(*args, **kwargs):
1050
1050
1051
1051
with maybe_profile (args .export_profiler_trace , ** args .profile_details ) as p :
1052
1052
if args .export_aot_inductor :
1053
- frozen_model_iter_fn = export_aot_inductor (model , example_inputs )
1053
+ frozen_model_iter_fn = export_aot_inductor (
1054
+ model , example_inputs , args .inductor_compile_mode
1055
+ )
1054
1056
else :
1055
1057
frozen_model_iter_fn = torch ._dynamo .run (model_iter_fn )
1056
1058
@@ -1384,7 +1386,7 @@ class AOTInductorModelCache:
1384
1386
cache = {}
1385
1387
1386
1388
@classmethod
1387
- def load (cls , model , example_inputs ):
1389
+ def load (cls , model , example_inputs , mode ):
1388
1390
import torch ._inductor
1389
1391
import torch .export ._trace
1390
1392
from torch .export .dynamic_shapes import _combine_args , _tree_map_with_path
@@ -1422,6 +1424,9 @@ def load(cls, model, example_inputs):
1422
1424
elif current_device == "hpu" :
1423
1425
torch .hpu .reset_peak_memory_stats ()
1424
1426
1427
+ inductor_configs = {}
1428
+ if mode == "max-autotune" :
1429
+ inductor_configs ["max_autotune" ] = True
1425
1430
ep = torch .export .export (
1426
1431
model ,
1427
1432
example_args ,
@@ -1430,7 +1435,9 @@ def load(cls, model, example_inputs):
1430
1435
strict = False ,
1431
1436
)
1432
1437
with torch .no_grad ():
1433
- package_path = torch ._inductor .aoti_compile_and_package (ep ) # type: ignore[arg-type]
1438
+ package_path = torch ._inductor .aoti_compile_and_package (
1439
+ ep , inductor_configs = inductor_configs
1440
+ ) # type: ignore[arg-type]
1434
1441
1435
1442
cls .cache [key ] = torch ._inductor .aoti_load_package (package_path )
1436
1443
@@ -1460,8 +1467,8 @@ def opt_export(_, example_inputs):
1460
1467
return opt_export
1461
1468
1462
1469
1463
- def export_aot_inductor (model , example_inputs ):
1464
- optimized = AOTInductorModelCache .load (model , example_inputs )
1470
+ def export_aot_inductor (model , example_inputs , mode ):
1471
+ optimized = AOTInductorModelCache .load (model , example_inputs , mode )
1465
1472
1466
1473
def opt_aot_inductor (_ , example_inputs , collect_outputs = False ):
1467
1474
example_args , example_kwargs = _normalize_bench_inputs (example_inputs )
@@ -3752,7 +3759,9 @@ def run(runner, args, original_dir=None):
3752
3759
elif args .backend or args .export_aot_inductor :
3753
3760
if args .export_aot_inductor :
3754
3761
assert not args .training , "AOTInductor only supports inference"
3755
- optimize_ctx = functools .partial (export_aot_inductor )
3762
+ optimize_ctx = functools .partial (
3763
+ export_aot_inductor , mode = args .inductor_compile_mode
3764
+ )
3756
3765
3757
3766
# AOTInductor doesn't support control flow yet
3758
3767
runner .skip_models .update (runner .skip_models_due_to_control_flow )
0 commit comments