@@ -1383,12 +1383,11 @@ def _produce_dynamic_shapes_for_export(path, x):
 
 
 class AOTInductorModelCache:
-    cache = {}
+    cache: dict[weakref.ref, tuple[Any, float]] = {}
 
     @classmethod
     def load(cls, model, example_inputs, mode):
         import torch._inductor
-        import torch.export._trace
         from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
 
         key = weakref.ref(model)
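Reviewer note (not part of the diff): the cache is keyed by weakref.ref(model) so that holding a compiled artifact never keeps the eager model alive. Two weakref.ref objects created from the same live object compare and hash equal, which is what lets get_excess_memory (added further down) look entries up with a fresh ref. A minimal sketch of that behavior:

    import weakref

    class Model:  # hypothetical stand-in for an nn.Module
        pass

    m = Model()
    cache = {}
    cache[weakref.ref(m)] = ("compiled-artifact", 0.25)

    # A fresh ref to the same live object hashes equal to the stored key,
    # so the lookup succeeds without a strong reference to the model.
    assert weakref.ref(m) in cache
    assert cache[weakref.ref(m)][1] == 0.25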
@@ -1419,16 +1418,40 @@ def load(cls, model, example_inputs, mode):
             # delete example_outputs and reset memory stats here
             del example_outputs
             if current_device == "cuda":
-                torch.cuda.reset_peak_memory_stats()
                 empty_gpu_cache(current_device)
+                torch.cuda.reset_peak_memory_stats()
+                pre_clone_memory_used = torch.cuda.max_memory_allocated()
             elif current_device == "hpu":
                 torch.hpu.reset_peak_memory_stats()
+                pre_clone_memory_used = torch.hpu.max_memory_allocated()
+
+            # Clone the model pre-exporting. This prevents scenarios observed in a few
+            # models, where the forward pass modifies model state while exporting, and
+            # FakeTensors are thus saved as model data members. This invalidates model
+            # reuse in eager mode, so it's safest to export a model clone.
+            model_clone = copy.deepcopy(model)
+
+            # Since CPU doesn't monitor max memory allocation, anything measuring peak
+            # memory will miss our transient model clone on CPU anyway.
+            #
+            # The justification for tracking this value (in order to remove it from the
+            # AOTInductor memory measurements) is that normal usage of AOTInductor would
+            # not clone the model, since the eager model would be unused post-export.
+            clone_memory_used = 0.0
+            if current_device == "cuda":
+                clone_memory_used = (
+                    torch.cuda.max_memory_allocated() - pre_clone_memory_used
+                ) / 1e9
+            elif current_device == "hpu":
+                clone_memory_used = (
+                    torch.hpu.max_memory_allocated() - pre_clone_memory_used
+                ) / 1e9
 
             inductor_configs = {}
             if mode == "max-autotune":
                 inductor_configs["max_autotune"] = True
             ep = torch.export.export(
-                model,
+                model_clone,
                 example_args,
                 example_kwargs,
                 dynamic_shapes=dynamic_shapes,
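Reviewer note (illustrative, not part of the diff): the clone cost is measured as the growth of the allocator's peak-allocation counter across the deepcopy, divided by 1e9 to convert bytes to GB. A minimal sketch of the same pattern in isolation, assuming a CUDA device is available:

    import copy

    import torch

    model = torch.nn.Linear(4096, 4096).cuda()  # hypothetical eager model

    torch.cuda.reset_peak_memory_stats()
    pre_clone = torch.cuda.max_memory_allocated()

    model_clone = copy.deepcopy(model)  # transient copy, used only for export

    # The peak counter grew by roughly the size of the cloned parameters; this
    # delta is what the benchmark later subtracts from its peak-memory numbers.
    clone_gb = (torch.cuda.max_memory_allocated() - pre_clone) / 1e9
    print(f"clone cost: {clone_gb:.3f} GB")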
@@ -1439,9 +1462,16 @@ def load(cls, model, example_inputs, mode):
                     ep, inductor_configs=inductor_configs
                 )  # type: ignore[arg-type]
 
-            cls.cache[key] = torch._inductor.aoti_load_package(package_path)
+            cls.cache[key] = (
+                torch._inductor.aoti_load_package(package_path),
+                clone_memory_used,
+            )
 
-        return cls.cache[key]
+        return cls.cache[key][0]
+
+    @classmethod
+    def get_excess_memory(cls, model) -> float:
+        return cls.cache.get(weakref.ref(model), (None, 0.0))[1]
 
 
 def export(model, example_inputs):
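Reviewer note (hypothetical usage, not part of the diff): load() now returns only the compiled runner (element 0 of the cached tuple), while get_excess_memory() exposes element 1 and falls back to 0.0 for models that were never loaded, so callers can subtract it unconditionally. A simplified mirror of just that lookup logic:

    import weakref
    from typing import Any

    class CacheSketch:  # simplified stand-in for AOTInductorModelCache
        cache: dict[weakref.ref, tuple[Any, float]] = {}

        @classmethod
        def get_excess_memory(cls, model) -> float:
            # The (None, 0.0) default means "no correction" for unseen models.
            return cls.cache.get(weakref.ref(model), (None, 0.0))[1]

    class Dummy:
        pass

    m = Dummy()
    assert CacheSketch.get_excess_memory(m) == 0.0
    CacheSketch.cache[weakref.ref(m)] = (object(), 0.75)
    assert CacheSketch.get_excess_memory(m) == 0.75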
@@ -1456,6 +1486,9 @@ def export(model, example_inputs):
         _produce_dynamic_shapes_for_export, combined_args
     )
 
+    # NOTE: if args.export is ever enabled for --performance mode (rather than solely
+    # --accuracy), we'll need to clone the model and subtract out extra memory usage, as
+    # done in AOTInductorModelCache.
     ep = torch.export.export(
         model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes, strict=True
     )
@@ -2468,6 +2501,11 @@ def warmup(fn, model, example_inputs, mode, niters=10):
                     "dynamo",
                     niters=1,
                 )
+            # If we use warm peak memory, the AOT model loading transient memory
+            # won't be present on the warm measurement. We only have to account for
+            # it when using cold memory.
+            elif self.args.export_aot_inductor:
+                dynamo_peak_mem -= AOTInductorModelCache.get_excess_memory(model)
 
         if self.args.profile_dynamo_cache_lookup:
             with torch.profiler.profile(
@@ -2616,6 +2654,11 @@ def warmup(fn, model, example_inputs, mode, niters=5):
                     "dynamo",
                     niters=1,
                 )
+            # If we use warm peak memory, the AOT model loading transient memory
+            # won't be present on the warm measurement. We only have to account for
+            # it when using cold memory.
+            elif self.args.export_aot_inductor:
+                dynamo_peak_mem -= AOTInductorModelCache.get_excess_memory(model)
 
         if self.args.profile_dynamo_cache_lookup:
             with torch.profiler.profile(
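Reviewer note (hypothetical numbers, not part of the diff): the correction applies only to cold peak-memory measurements; with use_warm_peak_memory the peak is re-measured after loading, so the transient clone never shows up in it. The arithmetic is just:

    # Values in GB; both names here are illustrative.
    cold_peak_gb = 12.40   # measured peak, includes the deepcopy made in load()
    clone_cost_gb = 0.75   # AOTInductorModelCache.get_excess_memory(model)

    corrected_gb = cold_peak_gb - clone_cost_gb
    print(corrected_gb)  # 11.65, comparable to a run that never cloned the model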