Skip to content

Commit ee027de

Browse files
mlazosfacebook-github-bot
authored andcommitted
Use universal flatten APIs (#152505)
Summary: X-link: pytorch/pytorch#152505 Approved by: https://github.com/anijain2305 ghstack dependencies: #152389 Reviewed By: huydhn Differential Revision: D74531500 fbshipit-source-id: ea17d89d5febe607ece2f43d053f5f533dbd1bd5
1 parent 5c048d1 commit ee027de

File tree

1 file changed

+5
-1
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+5
-1
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292
from torch.utils._triton import has_triton, has_triton_package
9393
from torch.utils.hooks import RemovableHandle
9494

95+
from .graph_utils import _get_flat_args
96+
9597

9698
if typing.TYPE_CHECKING:
9799
from collections.abc import (
@@ -3150,7 +3152,9 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
31503152
args, kwargs = get_fake_values_from_nodes(
31513153
tx, (node.args, node.kwargs), allow_non_graph_fake
31523154
)
3153-
flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
3155+
flat_args_kwargs = get_fake_values_from_nodes(
3156+
tx, _get_flat_args(node, {}), allow_non_graph_fake
3157+
)
31543158
id_to_initial_version = {
31553159
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
31563160
}

0 commit comments

Comments
 (0)