Commit d36d477

generatedunixname499836121 authored and facebook-github-bot committed
Fix misleadingly high AOT Inductor dashboard performance (#153060)
Summary: Fixes misleadingly high AOTInductor performance benchmark numbers in scenarios where a model updates internal parameters during `torch.export.export`. Since `FakeTensorMode` is enabled during export, all such parameters become `FakeTensor`s, substantially slowing down future eager-mode runs of that model. This, in turn, produces misleading performance stats, where the slowness of eager mode makes `AOTInductor` look _very_ good.

An [example benchmark](https://hud.pytorch.org/benchmark/timm_models/inductor_aot_inductor?dashboard=torchinductor&startTime=Wed%2C%2030%20Apr%202025%2015%3A54%3A04%20GMT&stopTime=Wed%2C%2007%20May%202025%2015%3A54%3A04%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=main&lCommit=1dd36ad2d440a4f3faf724b3a8e13925e3180c24&rBranch=main&rCommit=cc7346bf19c019255dcb4484694a75850ed74d5a&model=convit_base) with this issue. The equivalent `cpp_wrapper` benchmark run shows a 2x performance gain, not 20x. Only two benchmarks we regularly run are affected by this, both in the TIMM set.

X-link: pytorch/pytorch#153060

Approved by: https://github.com/desertfire

Reviewed By: jeanschmidt

Differential Revision: D74729281

fbshipit-source-id: bf25cd22933d9670018d935747b0604dec4178aa
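The fix below boils down to exporting a deep copy of the model so that any state the forward pass writes back onto the module during export-time tracing lands on a throwaway clone instead of the eager model. A minimal sketch of that pattern, independent of the benchmark harness (the function name and toy model here are illustrative, not part of the patched code):

```python
import copy

import torch


def export_without_touching_eager_model(model: torch.nn.Module, example_args: tuple):
    # Export a deep copy so that any state the forward pass stores on the module
    # during tracing (which runs under FakeTensorMode, and so may leave FakeTensors
    # behind as data members) lands on the clone rather than on `model`.
    model_clone = copy.deepcopy(model)
    return torch.export.export(model_clone, example_args)


# Usage: `model` remains safe for eager-mode baseline timing after export.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
ep = export_without_touching_eager_model(model, (torch.randn(4, 8),))
eager_out = model(torch.randn(4, 8))  # still a real-tensor, eager run
```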
1 parent ecf479d commit d36d477


userbenchmark/dynamo/dynamobench/common.py

Lines changed: 49 additions & 6 deletions
@@ -1383,12 +1383,11 @@ def _produce_dynamic_shapes_for_export(path, x):
 
 
 class AOTInductorModelCache:
-    cache = {}
+    cache: dict[weakref.ref, tuple[Any, float]] = {}
 
     @classmethod
     def load(cls, model, example_inputs, mode):
         import torch._inductor
-        import torch.export._trace
         from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
 
         key = weakref.ref(model)
@@ -1419,16 +1418,40 @@ def load(cls, model, example_inputs, mode):
         # delete example_outputs and reset memory stats here
         del example_outputs
         if current_device == "cuda":
-            torch.cuda.reset_peak_memory_stats()
             empty_gpu_cache(current_device)
+            torch.cuda.reset_peak_memory_stats()
+            pre_clone_memory_used = torch.cuda.max_memory_allocated()
         elif current_device == "hpu":
             torch.hpu.reset_peak_memory_stats()
+            pre_clone_memory_used = torch.hpu.max_memory_allocated()
+
+        # Clone the model pre-exporting. This prevents scenarios observed in a few
+        # models, where the forward pass modifies model state while exporting, and
+        # FakeTensors are thus saved as model data members. This invalidates model
+        # reuse in eager mode, so it's safest to export a model clone.
+        model_clone = copy.deepcopy(model)
+
+        # Since CPU doesn't monitor max memory allocation, anything measuring peak
+        # memory will miss our transient model clone on CPU anyway.
+        #
+        # The justification for tracking this value (in order to remove it from the
+        # AOTInductor memory measurements) is that normal usage of AOTInductor would
+        # not clone the model, since the eager model would be unused post-export.
+        clone_memory_used = 0.0
+        if current_device == "cuda":
+            clone_memory_used = (
+                torch.cuda.max_memory_allocated() - pre_clone_memory_used
+            ) / 1e9
+        elif current_device == "hpu":
+            clone_memory_used = (
+                torch.hpu.max_memory_allocated() - pre_clone_memory_used
+            ) / 1e9
 
         inductor_configs = {}
         if mode == "max-autotune":
             inductor_configs["max_autotune"] = True
         ep = torch.export.export(
-            model,
+            model_clone,
             example_args,
             example_kwargs,
             dynamic_shapes=dynamic_shapes,
@@ -1439,9 +1462,16 @@ def load(cls, model, example_inputs, mode):
             ep, inductor_configs=inductor_configs
         )  # type: ignore[arg-type]
 
-        cls.cache[key] = torch._inductor.aoti_load_package(package_path)
+        cls.cache[key] = (
+            torch._inductor.aoti_load_package(package_path),
+            clone_memory_used,
+        )
 
-        return cls.cache[key]
+        return cls.cache[key][0]
+
+    @classmethod
+    def get_excess_memory(cls, model) -> float:
+        return cls.cache.get(weakref.ref(model), (None, 0.0))[1]
 
 
 def export(model, example_inputs):
@@ -1456,6 +1486,9 @@ def export(model, example_inputs):
         _produce_dynamic_shapes_for_export, combined_args
     )
 
+    # NOTE: if args.export is ever enabled for --performance mode (rather than solely
+    # --accuracy), we'll need to clone the model and subtract out extra memory usage, as
+    # done in AOTInductorModelCache.
     ep = torch.export.export(
         model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes, strict=True
     )
@@ -2468,6 +2501,11 @@ def warmup(fn, model, example_inputs, mode, niters=10):
                 "dynamo",
                 niters=1,
             )
+            # If we use warm peak memory, the AOT model loading transient memory
+            # won't be present on the warm measurement. We only have to account for
+            # it when using cold memory.
+            elif self.args.export_aot_inductor:
+                dynamo_peak_mem -= AOTInductorModelCache.get_excess_memory(model)
 
         if self.args.profile_dynamo_cache_lookup:
             with torch.profiler.profile(
@@ -2616,6 +2654,11 @@ def warmup(fn, model, example_inputs, mode, niters=5):
                 "dynamo",
                 niters=1,
             )
+            # If we use warm peak memory, the AOT model loading transient memory
+            # won't be present on the warm measurement. We only have to account for
+            # it when using cold memory.
+            elif self.args.export_aot_inductor:
+                dynamo_peak_mem -= AOTInductorModelCache.get_excess_memory(model)
 
         if self.args.profile_dynamo_cache_lookup:
             with torch.profiler.profile(
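The memory accounting added in this diff can be read as one small pattern: bracket the model clone with peak-memory queries, stash the delta alongside the compiled artifact, and subtract it from cold peak-memory measurements later (warm measurements never see the transient clone). A condensed, self-contained sketch of that pattern, assuming a CUDA device is available; the names here are illustrative and not the harness's API:

```python
import copy
import weakref

import torch

_cache: dict[weakref.ref, tuple[object, float]] = {}


def load_with_clone_accounting(model: torch.nn.Module, example_args: tuple):
    """Export a clone of `model` and record how much transient GPU memory the
    clone itself consumed (in GB), so benchmarks can subtract it out later."""
    key = weakref.ref(model)

    torch.cuda.reset_peak_memory_stats()
    pre_clone = torch.cuda.max_memory_allocated()

    # The clone exists only so export-time state mutation lands on a throwaway copy.
    model_clone = copy.deepcopy(model)
    clone_memory_gb = (torch.cuda.max_memory_allocated() - pre_clone) / 1e9

    ep = torch.export.export(model_clone, example_args)
    _cache[key] = (ep, clone_memory_gb)
    return ep


def get_excess_memory(model: torch.nn.Module) -> float:
    # Returns 0.0 for models that were never loaded through this path.
    return _cache.get(weakref.ref(model), (None, 0.0))[1]


# Cold peak-memory measurements include the transient clone, so remove it:
# corrected_peak_gb = measured_peak_gb - get_excess_memory(model)
```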
