@@ -1408,7 +1408,7 @@ class AOTInductorModelCache:
1408
1408
def load (cls , model , example_inputs ):
1409
1409
import torch ._inductor
1410
1410
import torch .export ._trace
1411
- from torch .export .dynamic_shapes import _tree_map_with_path
1411
+ from torch .export .dynamic_shapes import _combine_args , _tree_map_with_path
1412
1412
1413
1413
key = weakref .ref (model )
1414
1414
if key not in cls .cache :
@@ -1428,7 +1428,7 @@ def load(cls, model, example_inputs):
1428
1428
else :
1429
1429
_register_dataclass_output_as_pytree (example_outputs )
1430
1430
1431
- combined_args = tuple ( example_args ) + tuple ( example_kwargs . values () )
1431
+ combined_args = _combine_args ( model , example_args , example_kwargs )
1432
1432
dynamic_shapes = _tree_map_with_path (
1433
1433
_produce_dynamic_shapes_for_export , combined_args
1434
1434
)
@@ -1449,13 +1449,13 @@ def load(cls, model, example_inputs):
1449
1449
1450
1450
1451
1451
def export (model , example_inputs ):
1452
- from torch .export .dynamic_shapes import _tree_map_with_path
1452
+ from torch .export .dynamic_shapes import _combine_args , _tree_map_with_path
1453
1453
1454
1454
example_args , example_kwargs = _normalize_bench_inputs (example_inputs )
1455
1455
example_outputs = model (* example_args , ** example_kwargs )
1456
1456
_register_dataclass_output_as_pytree (example_outputs )
1457
1457
1458
- combined_args = tuple ( example_args ) + tuple ( example_kwargs . values () )
1458
+ combined_args = _combine_args ( model , example_args , example_kwargs )
1459
1459
dynamic_shapes = _tree_map_with_path (
1460
1460
_produce_dynamic_shapes_for_export , combined_args
1461
1461
)
0 commit comments