Skip to content

Commit e935426

Browse files
anijain2305facebook-github-bot
authored andcommitted
Support list subclasses and fix dict subclasses mutation bugs (#146819)
Summary: This PR adds support for list subclasses. Among other things are 1) Tracking the mutations on internal vts like `_dict_vt` and `_list_vt` using sources. This helps identify if there was a mutation in the underlying data structures, and we need to reconstruct it. 2) `UserDefinedObjectVariable` now has a new method - `is_modified` which `side_effect` infra relies upon to check mutations in the underlying vts (like `_dict_vt`). 3) `reconstruction` logic ensures that we use `dict.__getitem__` and `list.__getitem__` methods. This is super important because we don't want to call the overridden `__getitem__` methods. If this PR is hard to review, please let me know. I can break it into several small PRs. X-link: pytorch/pytorch#146819 Approved by: https://github.com/StrongerXi, https://github.com/jansel Reviewed By: huydhn Differential Revision: D69537369 fbshipit-source-id: 9c20f4ee84c91639c320a3a04a1a153859623ab6
1 parent 43851fb commit e935426

File tree

1 file changed

+10
-1
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+10
-1
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2356,6 +2356,8 @@ def check_numpy_ndarray_args(args, kwargs):
23562356

23572357
tuple_new = tuple.__new__
23582358
tuple_methods = {method for method in tuple.__dict__.values() if callable(method)}
2359+
list_methods = {method for method in list.__dict__.values() if callable(method)}
2360+
list_getitem = list.__getitem__
23592361

23602362

23612363
def builtin_dict_keys(d):
@@ -2407,8 +2409,15 @@ def to_subclass(t, cls):
24072409
return t.as_subclass(cls)
24082410

24092411

2412+
dict_getitem = dict.__getitem__
2413+
2414+
24102415
def dict_keys_getitem(d, n):
2411-
return next(itertools.islice(iter(d), n, n + 1))
2416+
# Call dict(d) to prevent calling overridden __iter__/keys
2417+
dict_class = dict
2418+
if isinstance(d, OrderedDict):
2419+
dict_class = OrderedDict
2420+
return next(itertools.islice(dict_class.keys(d), n, n + 1))
24122421

24132422

24142423
def enum_repr(value, local):

0 commit comments

Comments
 (0)