Commit edf3233

shunting314 authored and facebook-github-bot committed
don't return logits for benchmark script (#151075)
Summary:
The PT2 benchmark scripts have a training pattern like:

```
def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
    cloned_inputs = clone_inputs(inputs)
    self.optimizer_zero_grad(mod)
    with self.autocast(**self.autocast_arg):
        pred = mod(**cloned_inputs)
        loss = self.compute_loss(pred)
    self.grad_scaler.scale(loss).backward()
    self.optimizer_step()
    if collect_outputs:
        return collect_results(mod, pred, loss, cloned_inputs)
    return None
```

The `collect_outputs` argument is True only for accuracy testing and False for performance testing. For the HF benchmark suite, a model usually returns the tuple (loss, logits). For performance testing, even though the logits are never used anywhere, dynamo has to keep them because of the control flow. Keeping the logits here has two downsides:
1. Peak memory is higher, since the logits tensor is large and its memory cannot be released earlier.
2. We cannot apply optimizations such as chunking to the logits, because the tensor has to be returned from the pre-grad graph.

I think it is fine not to return the logits at all:
- For training cases, checking the loss and gradients for accuracy is good enough. It is hard to imagine two runs having mismatched logits but matching loss/gradients.
- Discarding the logits as soon as possible also makes the perf benchmarking fairer for us.

On the other hand, it may be interesting to let dynamo support something like dynamo.constexpr (similar to tl.constexpr). A variable annotated as dynamo.constexpr would be specialized at compile time, enabling more compile-time optimization (e.g. DCE). (A small [repro](https://gist.github.com/shunting314/0912a8947028a904c34f361021b8024d).)

Benchmark results are here: [link](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Fri%2C%2004%20Apr%202025%2018%3A03%3A26%20GMT&stopTime=Fri%2C%2011%20Apr%202025%2018%3A03%3A26%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/shunting314/204/head&lCommit=fe25dab3f65e1b0e9db0af03f7664af70fcc9c66&rBranch=main&rCommit=55e62ff74ad5614faf80b060c7bfc551e3b7af5a)
- HF: 15% peak memory improvement (1.51 -> 1.66 compression ratio).
- I also see a 5% (2.74x -> 2.79x) perf win for HF. It could be real: we may generate more efficient kernels since we no longer need to keep the logits and return them from the pre-grad graph. But I'll double check.

X-link: pytorch/pytorch#151075

Approved by: https://github.com/eellison, https://github.com/jansel

Reviewed By: Camyll

Differential Revision: D73068291

fbshipit-source-id: 709218784f3f0673f434cf4fc5094f0fd64dfeee
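Below is a minimal, hedged sketch of the effect described in the summary (it is not the linked gist, and all names are illustrative rather than part of the benchmark scripts): a compiled function that returns its large logits tensor must keep it as a graph output, while one that returns only the loss lets the buffer be released right after the reduction.

```python
import torch

@torch.compile
def step(x, w, collect_outputs: bool):
    logits = x @ w                                  # large intermediate, analogous to HF logits
    loss = logits.float().logsumexp(dim=-1).mean()  # stand-in for compute_loss
    if collect_outputs:
        # Accuracy-style run: the logits become a graph output and stay alive
        # until the caller releases them.
        return loss, logits
    # Perf-style run: the logits can be freed (or fused away) right after the
    # reduction, lowering peak memory.
    return loss, None

x = torch.randn(64, 2048)
w = torch.randn(2048, 50257)
loss, _ = step(x, w, collect_outputs=False)
```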
1 parent 2a7fdca commit edf3233

File tree

3 files changed, +6 -5 lines


userbenchmark/dynamo/dynamobench/huggingface.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -536,7 +536,7 @@ def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
         self.grad_scaler.scale(loss).backward()
         self.optimizer_step()
         if collect_outputs:
-            return collect_results(mod, pred, loss, cloned_inputs)
+            return collect_results(mod, None, loss, cloned_inputs)
         return None
```

userbenchmark/dynamo/dynamobench/timm_models.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -428,7 +428,7 @@ def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
         self.grad_scaler.scale(loss).backward()
         self.optimizer_step()
         if collect_outputs:
-            return collect_results(mod, pred, loss, cloned_inputs)
+            return collect_results(mod, None, loss, cloned_inputs)
         return None
```

userbenchmark/dynamo/dynamobench/torchbench.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -85,8 +85,9 @@ def process_hf_whisper_output(out):
     out_ret = []
     for i, elem in enumerate(out):
         if i == 0:
-            assert isinstance(elem, dict)
-            out_ret.append({k: v for k, v in elem.items() if k != "logits"})
+            if elem is not None:
+                assert isinstance(elem, dict)
+                out_ret.append({k: v for k, v in elem.items() if k != "logits"})
         elif i != 1:
             out_ret.append(elem)

@@ -470,7 +471,7 @@ def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
         self.grad_scaler.scale(loss).backward()
         self.optimizer_step()
         if collect_outputs:
-            return collect_results(mod, pred, loss, cloned_inputs)
+            return collect_results(mod, None, loss, cloned_inputs)
         return None
```
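Across all three files, the accuracy path now passes None in place of pred, so collect_results compares the module state (gradients), the loss, and the inputs rather than the logits. As a rough sketch of the commit's argument that loss plus gradients is enough, assuming an illustrative helper (grads_and_loss_match is not the benchmark's actual collect_results/accuracy machinery):

```python
import torch

def grads_and_loss_match(ref_mod, test_mod, ref_loss, test_loss, rtol=1e-3, atol=1e-3):
    # Two runs are treated as matching if the losses agree and every parameter
    # gradient agrees; the logits are never consulted.
    if not torch.allclose(ref_loss, test_loss, rtol=rtol, atol=atol):
        return False
    for p_ref, p_test in zip(ref_mod.parameters(), test_mod.parameters()):
        if (p_ref.grad is None) != (p_test.grad is None):
            return False
        if p_ref.grad is not None and not torch.allclose(
            p_ref.grad, p_test.grad, rtol=rtol, atol=atol
        ):
            return False
    return True
```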
