
Commit f188e24

masnesral authored and facebook-github-bot committed
Add --warm-start-latency to benchmark harness (#125353)
Add --warm-start-latency to benchmark harness (#125353)

Summary: This change introduces a new flag to perform a "warm start" test from the benchmark harness. The idea is to test a model twice: first with a fresh inductor cache (i.e., a "cold start"), and then a second run in a fresh process with the cache available (i.e., a "warm start"). We can later add this mode to CI runs to collect compile times for warm start.

X-link: pytorch/pytorch#125353
Approved by: https://github.com/eellison, https://github.com/desertfire
Reviewed By: izaitsevfb
Differential Revision: D57216102
Pulled By: masnesral
fbshipit-source-id: 2cb5751260e05844ad3324572ada4caca5109f23
1 parent 73dd57e commit f188e24
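
To make the mechanism concrete: warm start works because both runs share one inductor cache directory, so the second process finds the artifacts the first one compiled. Below is a minimal sketch of that idea outside the harness; the script path, model name, and cache location are hypothetical, while the TORCHINDUCTOR_* environment variables are the ones the commit itself uses.

import os
import subprocess
import sys

# Hypothetical fixed cache location; fresh_inductor_cache() would normally
# pick a temporary directory and export it the same way.
env = os.environ.copy()
env["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/inductor_cache_demo"
env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"  # enable FX graph caching

# Hypothetical harness invocation; "--only" limits the run to one model.
cmd = [sys.executable, "benchmarks/dynamo/torchbench.py", "--only", "resnet50"]

subprocess.check_call(cmd, env=env)  # run 1: cache empty -> cold-start compile
subprocess.check_call(cmd, env=env)  # run 2: cache populated -> warm start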


userbenchmark/dynamo/dynamobench/common.py

Lines changed: 56 additions & 52 deletions
@@ -2005,33 +2005,6 @@ def get_dynamo_stats():
     )
 
 
-def maybe_fresh_cache(fn, is_cold_start):
-    def inner(*args, **kwargs):
-        cache_minder = contextlib.nullcontext()
-        if is_cold_start:
-            cache_entries = {}
-            cache_minder = fresh_inductor_cache(cache_entries)
-
-        try:
-            with cache_minder:
-                return fn(*args, **kwargs)
-        finally:
-            dump_cache = False
-            if dump_cache and is_cold_start:
-                output_csv(
-                    output_filename[:-4] + "_triton_cache.csv",
-                    ["dev", "name", "batch_size", "triton_cache"],
-                    [
-                        current_device,
-                        current_name,
-                        current_batch_size,
-                        cache_entries,
-                    ],
-                )
-
-    return inner
-
-
 @contextmanager
 def maybe_init_distributed(should_init_distributed, rank, world_size, port="6789"):
     try:
@@ -3297,12 +3270,6 @@ def get_example_inputs(self):
         action="store_true",
         help="print dataframe result used for calculating accuracy",
     )
-    parser.add_argument(
-        "--cold-start-latency",
-        "--cold_start_latency",
-        action="store_true",
-        help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
-    )
     parser.add_argument(
         "--disable-cudagraphs",
         action="store_true",
@@ -3415,6 +3382,19 @@ def get_example_inputs(self):
         help="Enables Memory Snapshot tool for memory deep dives: https://pytorch.org/blog/understanding-gpu-memory-1/",
     )
 
+    group_latency = parser.add_mutually_exclusive_group()
+    group_latency.add_argument(
+        "--cold-start-latency",
+        "--cold_start_latency",
+        action="store_true",
+        help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
+    )
+    group_latency.add_argument(
+        "--warm-start-latency",
+        action="store_true",
+        help="Run model(s) twice and preserve caches in between to enable a 'warm start' on the 2nd run",
+    )
+
     group_fuser = parser.add_mutually_exclusive_group()
     # --nvfuser is now the default, keep the option to not break scripts
     group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
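
Note that --cold-start-latency and --warm-start-latency now live in an argparse mutually exclusive group, so passing both aborts argument parsing. A standalone sketch of that behavior (the flag names are copied from the diff; the rest is illustrative):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--cold-start-latency", action="store_true")
group.add_argument("--warm-start-latency", action="store_true")

args = parser.parse_args(["--warm-start-latency"])
print(args.warm_start_latency)  # True

# Passing both flags makes argparse print an error and exit:
#   error: argument --warm-start-latency: not allowed with argument --cold-start-latency
# parser.parse_args(["--cold-start-latency", "--warm-start-latency"])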
@@ -3571,9 +3551,17 @@ def process_entry(rank, runner, original_dir, args):
         world_size=args.world_size,
         port=args.distributed_master_port,
     ):
-        return maybe_fresh_cache(
-            run, (args.cold_start_latency and args.only) or args.ci
-        )(runner, args, original_dir)
+        return run(runner, args, original_dir)
+
+
+def maybe_fresh_cache(args):
+    cache_dir_assigned = "TORCHINDUCTOR_CACHE_DIR" in os.environ
+    if not cache_dir_assigned and (
+        args.cold_start_latency or args.warm_start_latency or args.ci
+    ):
+        return fresh_inductor_cache()
+    else:
+        return contextlib.nullcontext()
 
 
 def main(runner, original_dir=None, args=None):
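
The rewritten maybe_fresh_cache() is a conditional context-manager factory: it returns either a real fresh_inductor_cache() or a no-op contextlib.nullcontext(), and it skips the fresh cache whenever TORCHINDUCTOR_CACHE_DIR is already set, which is exactly the state the warm-start child processes inherit. A minimal self-contained sketch of the same pattern, with tempfile.TemporaryDirectory standing in for fresh_inductor_cache:

import contextlib
import tempfile

def maybe_scratch_dir(enabled: bool):
    # Real context manager when requested, no-op otherwise; the caller's
    # "with" block is identical either way.
    return tempfile.TemporaryDirectory() if enabled else contextlib.nullcontext()

with maybe_scratch_dir(True) as path:
    print(path)  # a fresh temporary directory, removed on exit
with maybe_scratch_dir(False) as path:
    print(path)  # None; nothing was created or cleaned up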
@@ -3598,23 +3586,39 @@ def main(runner, original_dir=None, args=None):
             f"--diff-branch: current branch is same as {args.diff_branch} branch, what are you diffing?"
         )
 
-    args.init_distributed = args.only and args.multiprocess
-    if args.init_distributed:
-        # NB: Do NOT query device count before CUDA initialization; we're
-        # going to overwrite CUDA_VISIBLE_DEVICES and this will result in
-        # https://github.com/pytorch/pytorch/issues/107300
-        device_count = torch.cuda.device_count()
-        if device_count <= 1:
-            log.warning(
-                "The use multiprocess flag is set but there are <= 1 devices available."
+    with maybe_fresh_cache(args):
+        args.init_distributed = args.only and args.multiprocess
+        if args.init_distributed:
+            # NB: Do NOT query device count before CUDA initialization; we're
+            # going to overwrite CUDA_VISIBLE_DEVICES and this will result in
+            # https://github.com/pytorch/pytorch/issues/107300
+            device_count = torch.cuda.device_count()
+            if device_count <= 1:
+                log.warning(
+                    "The use multiprocess flag is set but there are <= 1 devices available."
+                )
+            # multiprocess path
+            args.world_size = device_count
+            mp.spawn(
+                process_entry, args=(runner, original_dir, args), nprocs=device_count
             )
-        # multiprocess path
-        args.world_size = device_count
-        mp.spawn(process_entry, args=(runner, original_dir, args), nprocs=device_count)
-    else:
-        # single process path just uses the main process
-        args.world_size = 1
-        process_entry(0, runner, original_dir, args)
+        elif args.only and args.warm_start_latency:
+            # Warm start mode. Enable FX graph caching and perform back-to-back runs in
+            # separate processes (but ensure the inductor cache is preserved across runs).
+            env = os.environ.copy()
+            env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+            cmd = [sys.executable] + sys.argv
+            cmd.remove("--warm-start-latency")
+
+            print(f"Executing cold-start run for {args.only}")
+            subprocess.check_call(cmd, timeout=args.timeout, env=env)
+
+            print(f"Executing warm-start run for {args.only}")
+            subprocess.check_call(cmd, timeout=args.timeout, env=env)
+        else:
+            # single process path just uses the main process
+            args.world_size = 1
+            process_entry(0, runner, original_dir, args)
 
 
 def write_csv_when_exception(args, name: str, status: str, device=None):
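
The warm-start branch in main() is a self-re-exec: the parent strips its own mode flag from sys.argv and launches the remaining command twice, so each child takes the ordinary single-run path while inheriting the shared cache directory from the parent's maybe_fresh_cache() context. A condensed sketch of just that control flow, grounded in the diff above (error handling and the timeout argument omitted):

import os
import subprocess
import sys

if "--warm-start-latency" in sys.argv:
    env = os.environ.copy()
    env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"  # cache FX graphs across both runs
    cmd = [sys.executable] + sys.argv
    cmd.remove("--warm-start-latency")         # children must not re-enter this branch
    for label in ("cold-start", "warm-start"):
        print(f"Executing {label} run")
        subprocess.check_call(cmd, env=env)
    sys.exit(0)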
