
Commit 961a768

bobrenjc93 authored and facebook-github-bot committed
Vary batch size when running dynamic shapes benchmarks (#154805)
Summary: This better measures the actual runtime performance of dynamic shapes, where we aren't guaranteed to get shapes similar to the original hint.

X-link: pytorch/pytorch#154805
Approved by: https://github.com/Skylion007
ghstack dependencies: #154802, #154826, #154822, #154823
Reviewed By: seemethere
Differential Revision: D75811385
fbshipit-source-id: 3deccb5c569ce3a99e681c1f5e6f632d7875131b
1 parent a1de07b commit 961a768

File tree

1 file changed (+55, -2)


userbenchmark/dynamo/dynamobench/common.py

Lines changed: 55 additions & 2 deletions
@@ -14,6 +14,7 @@
 import json
 import logging
 import os
+import random
 import shutil
 import signal
 import subprocess
@@ -690,17 +691,52 @@ def timed(
     times=1,
     return_result=False,
     collect_outputs=False,
+    batch_size=None,
 ):
     use_xla = tensor_is_on_xla(example_inputs)
     synchronize()

+    if batch_size:
+        patch_torch_manual_seed()
+
     if use_xla:
         xm.mark_step()
         xm.wait_device_ops()

+    def vary_batch(t: torch.Tensor, new_batch_size) -> torch.Tensor:
+        for i, s in enumerate(t.size()):
+            if s == batch_size:
+                # If new batch is smaller, we truncate
+                if new_batch_size < batch_size:
+                    indexer = [slice(None)] * t.ndim
+                    indexer[i] = slice(0, new_batch_size)
+                    t = t[tuple(indexer)]
+                # If new batch is greater, we just duplicate the last row
+                # over and over until we hit the desired batch size
+                elif new_batch_size > batch_size:
+                    indexer = [slice(None)] * t.ndim
+                    indexer[i] = -1
+                    last_slice = t[tuple(indexer)].unsqueeze(i)
+                    repeat_shape = list(t.shape)
+                    repeat_shape[i] = new_batch_size - batch_size
+                    padding = last_slice.expand(*repeat_shape)
+                    t = torch.cat([t, padding], dim=i)
+                break
+        return t
+
     time_total = 0
     # Dont collect outputs to correctly measure timing
     for _ in range(times):
+        # If batch_size is 1, it too often collides with other non batch size
+        # dimensions resulting in errors.
+        if batch_size and batch_size > 1:
+            # Calculate new batch size by varying the original batch size by up to 20%
+            # Ensure it's at least greater than 1
+            variation = random.uniform(0.8, 1.2)
+            new_batch_size = max(2, int(batch_size * variation))
+            example_inputs = tree_map_only(
+                torch.Tensor, lambda x: vary_batch(x, new_batch_size), example_inputs
+            )
         # Put this call inside the loop to reset the seed for each iteration.
         # Don't include reset_rng_state() to correctly measure timing
         reset_rng_state(use_xla)
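For reference, the new vary_batch helper can be exercised on its own. The sketch below reproduces its logic from the hunk above and drives it with a small tensor and an assumed original batch size of 4, showing both the truncate path and the duplicate-the-last-slice padding path; the driver code and chosen sizes are illustrative only, not part of the change.

import torch

batch_size = 4  # assumed original (hinted) batch size

def vary_batch(t: torch.Tensor, new_batch_size) -> torch.Tensor:
    for i, s in enumerate(t.size()):
        if s == batch_size:
            # If the new batch is smaller, truncate along that dimension.
            if new_batch_size < batch_size:
                indexer = [slice(None)] * t.ndim
                indexer[i] = slice(0, new_batch_size)
                t = t[tuple(indexer)]
            # If the new batch is greater, duplicate the last slice until
            # the desired batch size is reached.
            elif new_batch_size > batch_size:
                indexer = [slice(None)] * t.ndim
                indexer[i] = -1
                last_slice = t[tuple(indexer)].unsqueeze(i)
                repeat_shape = list(t.shape)
                repeat_shape[i] = new_batch_size - batch_size
                padding = last_slice.expand(*repeat_shape)
                t = torch.cat([t, padding], dim=i)
            break
    return t

x = torch.randn(4, 3)
print(vary_batch(x, 2).shape)  # torch.Size([2, 3]) -- truncated
print(vary_batch(x, 6).shape)  # torch.Size([6, 3]) -- padded with copies of the last row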
@@ -1071,6 +1107,7 @@ def maybe_mark_profile(*args, **kwargs):
             return_result=True,
             times=times,
             collect_outputs=args.collect_outputs,
+            batch_size=kwargs.get("batch_size"),
         )

         # call mark_step between the 2 calls to make the comparison fair.
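The per-iteration sampling that timed() now performs can be pictured with the self-contained sketch below: each iteration draws a batch size within ±20% of the original hint, clamps it to at least 2, and maps the change over every tensor leaf of example_inputs via tree_map_only. The input names and the simplified dim-0 truncation are assumptions for illustration and stand in for the real vary_batch above.

import random
import torch
from torch.utils._pytree import tree_map_only

batch_size = 32  # assumed original hinted batch size
example_inputs = {
    "input": torch.randn(batch_size, 128),
    "mask": torch.ones(batch_size, 128, dtype=torch.bool),
}

for step in range(3):
    # Vary the original batch size by up to 20%; keep it greater than 1.
    variation = random.uniform(0.8, 1.2)
    new_batch_size = max(2, int(batch_size * variation))
    # Simplified stand-in for vary_batch: truncate dim 0 only (no padding).
    varied = tree_map_only(
        torch.Tensor,
        lambda t: t[: min(new_batch_size, t.size(0))],
        example_inputs,
    )
    print(step, new_batch_size, {k: tuple(v.shape) for k, v in varied.items()})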
@@ -2478,7 +2515,14 @@ def warmup(fn, model, example_inputs, mode, niters=10):
         return " ".join(map(str, results))

     def run_performance_test(
-        self, name, model, example_inputs, optimize_ctx, experiment, tag=None
+        self,
+        name,
+        model,
+        example_inputs,
+        optimize_ctx,
+        experiment,
+        tag=None,
+        batch_size=None,
     ):
         if self.args.xla:
             with self.pick_grad(name, self.args.training):
@@ -2536,6 +2580,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
         with self.pick_grad(name, self.args.training), ctx:
             ok, total = Stats.reset_counters()
             experiment_kwargs = {}
+            experiment_kwargs["batch_size"] = batch_size
             if tag is not None:
                 experiment_kwargs["tag"] = tag
             results = []
@@ -2699,6 +2744,7 @@ def run_one_model(
         experiment,
         explain=False,
         tag=None,
+        batch_size=None,
     ):
         mode = "train" if self.args.training else "eval"
         msg = f"{current_device:4} {mode:5} {current_name:34} "
@@ -2727,7 +2773,13 @@ def run_one_model(
             )
         else:
             status = self.run_performance_test(
-                name, model, example_inputs, optimize_ctx, experiment, tag
+                name,
+                model,
+                example_inputs,
+                optimize_ctx,
+                experiment,
+                tag,
+                batch_size=batch_size,
             )
         print(status)
         empty_gpu_cache(current_device)
@@ -4064,6 +4116,7 @@ def detect_and_mark_batch(t):
             experiment,
             explain=args.explain,
             tag=args.tag,
+            batch_size=batch_size if args.dynamic_batch_only else None,
         )
         if args.generate_aot_autograd_stats:
             stats_file = output_filename.split(".csv")[0] + "_stats.csv"
