Skip to content

Commit d1b2abb

Browse files
pianpwkfacebook-github-bot
authored andcommitted
fix dynamic_shapes spec for moco (#148772) (#2601)
Summary: Pull Request resolved: #2601 Fixes pytorch/pytorch#148333 X-link: pytorch/pytorch#148772 Approved by: https://github.com/yushangdi, https://github.com/desertfire Reviewed By: yushangdi, angelayi Differential Revision: D71412041 fbshipit-source-id: c5d2bf63539534d0e660da9ea882be1984f693ec
1 parent 50e2f74 commit d1b2abb

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

userbenchmark/dynamo/dynamobench/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,7 +1408,7 @@ class AOTInductorModelCache:
14081408
def load(cls, model, example_inputs):
14091409
import torch._inductor
14101410
import torch.export._trace
1411-
from torch.export.dynamic_shapes import _tree_map_with_path
1411+
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
14121412

14131413
key = weakref.ref(model)
14141414
if key not in cls.cache:
@@ -1428,7 +1428,7 @@ def load(cls, model, example_inputs):
14281428
else:
14291429
_register_dataclass_output_as_pytree(example_outputs)
14301430

1431-
combined_args = tuple(example_args) + tuple(example_kwargs.values())
1431+
combined_args = _combine_args(model, example_args, example_kwargs)
14321432
dynamic_shapes = _tree_map_with_path(
14331433
_produce_dynamic_shapes_for_export, combined_args
14341434
)
@@ -1449,13 +1449,13 @@ def load(cls, model, example_inputs):
14491449

14501450

14511451
def export(model, example_inputs):
1452-
from torch.export.dynamic_shapes import _tree_map_with_path
1452+
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
14531453

14541454
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
14551455
example_outputs = model(*example_args, **example_kwargs)
14561456
_register_dataclass_output_as_pytree(example_outputs)
14571457

1458-
combined_args = tuple(example_args) + tuple(example_kwargs.values())
1458+
combined_args = _combine_args(model, example_args, example_kwargs)
14591459
dynamic_shapes = _tree_map_with_path(
14601460
_produce_dynamic_shapes_for_export, combined_args
14611461
)

0 commit comments

Comments
 (0)