Commit de0cb2f

desertfire authored and facebook-github-bot committed
Fix mis-calculated memory compression ratio (#150695)
Summary:
pytorch/pytorch#149817 introduced an extra warmup run to compute the AOTI memory compression ratio. But since weights are loaded only once in the AOTI run, the peak memory seen during that extra warmup does not include the weights, which yields an artificially high memory compression ratio. This PR removes the extra warmup run and instead calls reset_peak_memory_stats in the proper place.

X-link: pytorch/pytorch#150695
Approved by: https://github.com/yushangdi

Reviewed By: atalman

Differential Revision: D72570207

fbshipit-source-id: 421bde2de6ebde4ca795871c5920a66a6b77073f
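The measurement problem the patch addresses can be illustrated with a toy peak-memory tracker. This is a minimal sketch, not the benchmark's code: the `PeakTracker` class and the byte counts are hypothetical stand-ins for torch.cuda's allocator statistics, chosen to show why the peak stats must be reset after the eager reference run so the AOTI peak still counts the resident weights.

```python
class PeakTracker:
    """Toy stand-in for torch.cuda memory stats (hypothetical, for illustration)."""

    def __init__(self):
        self.current = 0  # bytes currently allocated
        self.peak = 0     # high-water mark since the last reset

    def alloc(self, n):
        self.current += n
        self.peak = max(self.peak, self.current)

    def free(self, n):
        self.current -= n

    def reset_peak(self):
        # Analogous to torch.cuda.reset_peak_memory_stats(): the high-water
        # mark restarts from whatever is still resident (e.g. the weights).
        self.peak = self.current


tracker = PeakTracker()

# Eager reference run: weights (100) plus larger eager activations (50).
tracker.alloc(100)         # weights, loaded once and kept resident
tracker.alloc(50)          # eager activations
tracker.free(50)
eager_peak = tracker.peak  # 150

# The fix: reset the peak before measuring the AOTI run. Without this,
# aoti_peak would just inherit the eager high-water mark of 150.
tracker.reset_peak()
tracker.alloc(20)          # AOTI activations; weights are already resident
aoti_peak = tracker.peak   # 120: correctly includes the resident weights

print(eager_peak / aoti_peak)  # 1.25
```

Under these made-up numbers, measuring the AOTI peak in an extra warmup where the weights were never (re)loaded would report a peak near 20 and a compression ratio near 7.5x, which is the artificially high figure the PR eliminates.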
1 parent 4dbceea commit de0cb2f


userbenchmark/dynamo/dynamobench/common.py

Lines changed: 10 additions & 4 deletions
@@ -1395,6 +1395,8 @@ def load(cls, model, example_inputs):
         with torch.no_grad():
             # copy.deepcopy is required to prevent any surprising side-effect,
             # see https://github.com/pytorch/pytorch/issues/113029
+            # This will cause memory stats to be overshadowed by this eager run.
+            # To fix that, memory stats will be reset later.
             example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)

         if pytree.is_namedtuple_instance(example_outputs):
@@ -1411,6 +1413,14 @@ def load(cls, model, example_inputs):
             _produce_dynamic_shapes_for_export, combined_args
         )

+        # delete example_outputs and reset memory stats here
+        del example_outputs
+        if current_device == "cuda":
+            torch.cuda.reset_peak_memory_stats()
+            empty_gpu_cache(current_device)
+        elif current_device == "hpu":
+            torch.hpu.reset_peak_memory_stats()
+
         ep = torch.export.export(
             model,
             example_args,
@@ -3735,10 +3745,6 @@ def run(runner, args, original_dir=None):
         # AOTInductor doesn't support control flow yet
         runner.skip_models.update(runner.skip_models_due_to_control_flow)
         runner.skip_models.update(runner.skip_models_due_to_export_not_supported)
-
-        # For AOTI, we only measure the memory compression ratio at the run time
-        # instead of the compile time, so use a warmup run to trigger AOTI compilation.
-        args.use_warm_peak_memory = True
     elif args.backend == "torchao":
         assert "cuda" in args.devices, "Quantization requires CUDA device."
         assert args.bfloat16, "Quantization requires dtype bfloat16."
