@@ -2005,33 +2005,6 @@ def get_dynamo_stats():
     )
 
 
-def maybe_fresh_cache(fn, is_cold_start):
-    def inner(*args, **kwargs):
-        cache_minder = contextlib.nullcontext()
-        if is_cold_start:
-            cache_entries = {}
-            cache_minder = fresh_inductor_cache(cache_entries)
-
-        try:
-            with cache_minder:
-                return fn(*args, **kwargs)
-        finally:
-            dump_cache = False
-            if dump_cache and is_cold_start:
-                output_csv(
-                    output_filename[:-4] + "_triton_cache.csv",
-                    ["dev", "name", "batch_size", "triton_cache"],
-                    [
-                        current_device,
-                        current_name,
-                        current_batch_size,
-                        cache_entries,
-                    ],
-                )
-
-    return inner
-
-
 @contextmanager
 def maybe_init_distributed(should_init_distributed, rank, world_size, port="6789"):
     try:
@@ -3297,12 +3270,6 @@ def get_example_inputs(self):
         action="store_true",
         help="print dataframe result used for calculating accuracy",
     )
-    parser.add_argument(
-        "--cold-start-latency",
-        "--cold_start_latency",
-        action="store_true",
-        help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
-    )
     parser.add_argument(
         "--disable-cudagraphs",
         action="store_true",
@@ -3415,6 +3382,19 @@ def get_example_inputs(self):
         help="Enables Memory Snapshot tool for memory deep dives: https://pytorch.org/blog/understanding-gpu-memory-1/",
     )
 
+    group_latency = parser.add_mutually_exclusive_group()
+    group_latency.add_argument(
+        "--cold-start-latency",
+        "--cold_start_latency",
+        action="store_true",
+        help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
+    )
+    group_latency.add_argument(
+        "--warm-start-latency",
+        action="store_true",
+        help="Run model(s) twice and preserve caches in between to enable a 'warm start' on the 2nd run",
+    )
+
     group_fuser = parser.add_mutually_exclusive_group()
     # --nvfuser is now the default, keep the option to not break scripts
     group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
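
(Note: because both flags are registered in the same mutually exclusive group, argparse will reject an invocation that passes --cold-start-latency together with --warm-start-latency; at most one of them can be supplied per run.)
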
@@ -3571,9 +3551,17 @@ def process_entry(rank, runner, original_dir, args):
         world_size=args.world_size,
         port=args.distributed_master_port,
     ):
-        return maybe_fresh_cache(
-            run, (args.cold_start_latency and args.only) or args.ci
-        )(runner, args, original_dir)
+        return run(runner, args, original_dir)
+
+
+def maybe_fresh_cache(args):
+    cache_dir_assigned = "TORCHINDUCTOR_CACHE_DIR" in os.environ
+    if not cache_dir_assigned and (
+        args.cold_start_latency or args.warm_start_latency or args.ci
+    ):
+        return fresh_inductor_cache()
+    else:
+        return contextlib.nullcontext()
 
 
 def main(runner, original_dir=None, args=None):
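
(Note: the new maybe_fresh_cache only creates a fresh cache when TORCHINDUCTOR_CACHE_DIR is not already set, so a caller, including the warm-start parent process added in the next hunk, can pin a cache directory up front and have subsequent runs reuse it instead of starting cold.)
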
@@ -3598,23 +3586,39 @@ def main(runner, original_dir=None, args=None):
             f"--diff-branch: current branch is same as {args.diff_branch} branch, what are you diffing?"
         )
 
-    args.init_distributed = args.only and args.multiprocess
-    if args.init_distributed:
-        # NB: Do NOT query device count before CUDA initialization; we're
-        # going to overwrite CUDA_VISIBLE_DEVICES and this will result in
-        # https://github.com/pytorch/pytorch/issues/107300
-        device_count = torch.cuda.device_count()
-        if device_count <= 1:
-            log.warning(
-                "The use multiprocess flag is set but there are <= 1 devices available."
+    with maybe_fresh_cache(args):
+        args.init_distributed = args.only and args.multiprocess
+        if args.init_distributed:
+            # NB: Do NOT query device count before CUDA initialization; we're
+            # going to overwrite CUDA_VISIBLE_DEVICES and this will result in
+            # https://github.com/pytorch/pytorch/issues/107300
+            device_count = torch.cuda.device_count()
+            if device_count <= 1:
+                log.warning(
+                    "The use multiprocess flag is set but there are <= 1 devices available."
+                )
+            # multiprocess path
+            args.world_size = device_count
+            mp.spawn(
+                process_entry, args=(runner, original_dir, args), nprocs=device_count
             )
-        # multiprocess path
-        args.world_size = device_count
-        mp.spawn(process_entry, args=(runner, original_dir, args), nprocs=device_count)
-    else:
-        # single process path just uses the main process
-        args.world_size = 1
-        process_entry(0, runner, original_dir, args)
+        elif args.only and args.warm_start_latency:
+            # Warm start mode. Enable FX graph caching and perform back-to-back runs in
+            # separate processes (but ensure the inductor cache is preserved across runs).
+            env = os.environ.copy()
+            env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+            cmd = [sys.executable] + sys.argv
+            cmd.remove("--warm-start-latency")
+
+            print(f"Executing cold-start run for {args.only}")
+            subprocess.check_call(cmd, timeout=args.timeout, env=env)
+
+            print(f"Executing warm-start run for {args.only}")
+            subprocess.check_call(cmd, timeout=args.timeout, env=env)
+        else:
+            # single process path just uses the main process
+            args.world_size = 1
+            process_entry(0, runner, original_dir, args)
 
 
 def write_csv_when_exception(args, name: str, status: str, device=None):
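
For context, below is a minimal, self-contained sketch (not part of the PR) of the warm-start flow the hunks above implement. The helper names pinned_cache_dir and run_warm_start are hypothetical, and a tempfile-based stand-in replaces the real fresh_inductor_cache(); the point is only the pattern: pin one cache directory, enable TORCHINDUCTOR_FX_GRAPH_CACHE, and re-run the same benchmark command twice so the second run hits warm caches.

import contextlib
import os
import subprocess
import sys
import tempfile


@contextlib.contextmanager
def pinned_cache_dir():
    # Simplified stand-in for fresh_inductor_cache(): point the Inductor cache
    # at a brand-new temporary directory for the duration of the block.
    with tempfile.TemporaryDirectory() as cache_dir:
        previous = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
        try:
            yield cache_dir
        finally:
            if previous is None:
                os.environ.pop("TORCHINDUCTOR_CACHE_DIR", None)
            else:
                os.environ["TORCHINDUCTOR_CACHE_DIR"] = previous


def run_warm_start(benchmark_cmd):
    # The parent process pins one cache dir, enables the FX graph cache, and
    # runs the same command twice: the first run populates the caches (cold
    # start), the second run reuses them (warm start).
    with pinned_cache_dir():
        env = os.environ.copy()
        env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
        for label in ("cold-start", "warm-start"):
            print(f"Executing {label} run")
            subprocess.check_call(benchmark_cmd, env=env)


if __name__ == "__main__":
    # Placeholder command; the real harness rebuilds cmd from sys.argv with
    # "--warm-start-latency" removed before re-executing itself.
    run_warm_start([sys.executable, "-c", "print('benchmark placeholder')"])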