 import json
 import logging
 import os
+import random
 import shutil
 import signal
 import subprocess
@@ -690,17 +691,52 @@ def timed(
     times=1,
     return_result=False,
     collect_outputs=False,
+    batch_size=None,
 ):
     use_xla = tensor_is_on_xla(example_inputs)
     synchronize()

+    if batch_size:
+        patch_torch_manual_seed()
+
     if use_xla:
         xm.mark_step()
         xm.wait_device_ops()

+    def vary_batch(t: torch.Tensor, new_batch_size) -> torch.Tensor:
+        for i, s in enumerate(t.size()):
+            if s == batch_size:
+                # If new batch is smaller, we truncate
+                if new_batch_size < batch_size:
+                    indexer = [slice(None)] * t.ndim
+                    indexer[i] = slice(0, new_batch_size)
+                    t = t[tuple(indexer)]
+                # If new batch is greater, we just duplicate the last row
+                # over and over until we hit the desired batch size
+                elif new_batch_size > batch_size:
+                    indexer = [slice(None)] * t.ndim
+                    indexer[i] = -1
+                    last_slice = t[tuple(indexer)].unsqueeze(i)
+                    repeat_shape = list(t.shape)
+                    repeat_shape[i] = new_batch_size - batch_size
+                    padding = last_slice.expand(*repeat_shape)
+                    t = torch.cat([t, padding], dim=i)
+                break
+        return t
+
     time_total = 0
     # Dont collect outputs to correctly measure timing
     for _ in range(times):
+        # If batch_size is 1, it too often collides with other non batch size
+        # dimensions resulting in errors.
+        if batch_size and batch_size > 1:
+            # Calculate new batch size by varying the original batch size by up to 20%
+            # Ensure it's at least greater than 1
+            variation = random.uniform(0.8, 1.2)
+            new_batch_size = max(2, int(batch_size * variation))
+            example_inputs = tree_map_only(
+                torch.Tensor, lambda x: vary_batch(x, new_batch_size), example_inputs
+            )
         # Put this call inside the loop to reset the seed for each iteration.
         # Don't include reset_rng_state() to correctly measure timing
         reset_rng_state(use_xla)
@@ -1071,6 +1107,7 @@ def maybe_mark_profile(*args, **kwargs):
                 return_result=True,
                 times=times,
                 collect_outputs=args.collect_outputs,
+                batch_size=kwargs.get("batch_size"),
             )

         # call mark_step between the 2 calls to make the comparison fair.
@@ -2478,7 +2515,14 @@ def warmup(fn, model, example_inputs, mode, niters=10):
         return " ".join(map(str, results))

     def run_performance_test(
-        self, name, model, example_inputs, optimize_ctx, experiment, tag=None
+        self,
+        name,
+        model,
+        example_inputs,
+        optimize_ctx,
+        experiment,
+        tag=None,
+        batch_size=None,
     ):
         if self.args.xla:
             with self.pick_grad(name, self.args.training):
@@ -2536,6 +2580,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
         with self.pick_grad(name, self.args.training), ctx:
             ok, total = Stats.reset_counters()
             experiment_kwargs = {}
+            experiment_kwargs["batch_size"] = batch_size
             if tag is not None:
                 experiment_kwargs["tag"] = tag
             results = []
@@ -2699,6 +2744,7 @@ def run_one_model(
         experiment,
         explain=False,
         tag=None,
+        batch_size=None,
     ):
         mode = "train" if self.args.training else "eval"
         msg = f"{current_device:4} {mode:5} {current_name:34} "
@@ -2727,7 +2773,13 @@ def run_one_model(
             )
         else:
             status = self.run_performance_test(
-                name, model, example_inputs, optimize_ctx, experiment, tag
+                name,
+                model,
+                example_inputs,
+                optimize_ctx,
+                experiment,
+                tag,
+                batch_size=batch_size,
             )
             print(status)
             empty_gpu_cache(current_device)
@@ -4064,6 +4116,7 @@ def detect_and_mark_batch(t):
             experiment,
             explain=args.explain,
             tag=args.tag,
+            batch_size=batch_size if args.dynamic_batch_only else None,
         )
         if args.generate_aot_autograd_stats:
             stats_file = output_filename.split(".csv")[0] + "_stats.csv"
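For illustration, a minimal standalone sketch of the vary_batch logic added inside timed(). Here batch_size is passed explicitly rather than captured from the enclosing scope, and the toy input shape is an assumption; the benchmark itself applies the same transform to every tensor in example_inputs via tree_map_only after drawing a new batch size within roughly 20% of the original:

import random

import torch


def vary_batch(t: torch.Tensor, batch_size: int, new_batch_size: int) -> torch.Tensor:
    # Find the dimension whose size matches the original batch size, then
    # either truncate it or pad it (by repeating its last slice) to new_batch_size.
    for i, s in enumerate(t.size()):
        if s == batch_size:
            if new_batch_size < batch_size:
                indexer = [slice(None)] * t.ndim
                indexer[i] = slice(0, new_batch_size)
                t = t[tuple(indexer)]
            elif new_batch_size > batch_size:
                indexer = [slice(None)] * t.ndim
                indexer[i] = -1
                last_slice = t[tuple(indexer)].unsqueeze(i)
                repeat_shape = list(t.shape)
                repeat_shape[i] = new_batch_size - batch_size
                t = torch.cat([t, last_slice.expand(*repeat_shape)], dim=i)
            break
    return t


batch_size = 8  # assumed toy value
x = torch.randn(batch_size, 3, 224, 224)
# Same +/-20% variation as the benchmark loop, clamped to at least 2.
new_batch_size = max(2, int(batch_size * random.uniform(0.8, 1.2)))
print(vary_batch(x, batch_size, new_batch_size).shape)  # e.g. torch.Size([7, 3, 224, 224])

When args.dynamic_batch_only is set, run_one_model forwards batch_size through run_performance_test and experiment_kwargs into timed(), so each timing iteration runs on a slightly different batch dimension.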