Skip to content

Commit b2b4158

Browse files
int3 authored and facebook-github-bot committed
Use NVTX filtering to limit NCU profile collection
Summary: Previously, we used `--replay-mode range`, but that did not give us per-kernel metrics, so it was changed to `--replay-mode kernel` (the default). However, that can cause us to profile many more kernels outside the ones in the desired benchmark. It appears we can instead use NVTX filtering to solve this problem. Relevant docs: https://docs.nvidia.com/nsight-compute/NsightComputeCli/index.html#nvtx-filtering I also tacked on a minor change to the ncu invocation, adding `--import-source yes`. This makes it easier to analyze the traces on a different machine from the one doing the profiling. Reviewed By: chenyang78 Differential Revision: D58711358 fbshipit-source-id: 28aec4f71a736c7427b1886335297ece4a2a54a8
1 parent 62e2609 commit b2b4158

File tree

2 files changed

+34
-23
lines changed

2 files changed

+34
-23
lines changed

torchbenchmark/_components/ncu/__init__.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1-
21
from typing import Callable
32

4-
def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=True, output_dir=None) -> None:
3+
4+
def do_bench_ncu_in_task(
5+
fn: Callable,
6+
warmup=25,
7+
grad_to_none=None,
8+
fast_flush=True,
9+
output_dir=None,
10+
range_name: str = "",
11+
) -> None:
512
"""
613
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
714
the 20-th and 80-th performance percentile.
@@ -46,8 +53,6 @@ def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=
4653
# Warm-up
4754
for _ in range(n_warmup):
4855
fn()
49-
# Start ncu profiling
50-
torch.cuda.cudart().cudaProfilerStart()
5156
# we don't want `fn` to accumulate gradient values
5257
# if it contains a backward pass. So we clear the
5358
# provided gradients
@@ -56,5 +61,5 @@ def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=
5661
x.grad = None
5762
# we clear the L2 cache before run
5863
cache.zero_()
59-
fn()
60-
torch.cuda.cudart().cudaProfilerStop()
64+
with torch.cuda.nvtx.range(range_name):
65+
fn()

torchbenchmark/util/triton_op.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,10 @@ def __call__(cls, *args, **kwargs):
414414
obj.__post__init__()
415415
return obj
416416

417+
418+
_RANGE_NAME = "tritonbench_range"
419+
420+
417421
class BenchmarkOperator(metaclass=PostInitProcessor):
418422
mode: Mode = Mode.FWD
419423
test: str = "eval"
@@ -827,6 +831,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
827831
fn=fn,
828832
warmup=warmup,
829833
grad_to_none=self.get_grad_to_none(self.example_inputs),
834+
range_name=_RANGE_NAME,
830835
)
831836
metrics.extra_metrics["_ncu_trace_in_task"] = "success"
832837
# generate customized metrics
@@ -901,26 +906,27 @@ def ncu_trace(self, input_id: int, fn_name: str, replay: bool=False) -> str:
901906
"ncu",
902907
"--set",
903908
"full",
904-
"--replay-mode",
905-
"kernel",
909+
"--nvtx",
910+
"--nvtx-include",
911+
f"{_RANGE_NAME}/",
906912
"--target-processes",
907913
"all",
908-
"--csv",
909-
"-f",
910-
"--log-file",
911-
str(ncu_output_file.resolve()),
912-
] if not replay else [
913-
"ncu",
914-
"--set",
915-
"full",
916-
"--replay-mode",
917-
"kernel",
918-
"--target-processes",
919-
"all",
920-
"-f",
921-
"-o",
922-
str(ncu_output_file.resolve()),
914+
"--import-source",
915+
"yes",
923916
]
917+
if replay:
918+
ncu_args.extend([
919+
"-f",
920+
"-o",
921+
str(ncu_output_file.resolve()),
922+
])
923+
else:
924+
ncu_args.extend([
925+
"--csv",
926+
"-f",
927+
"--log-file",
928+
str(ncu_output_file.resolve()),
929+
])
924930
ncu_args.extend(op_task_args)
925931
subprocess.check_call(ncu_args)
926932
return str(ncu_output_file.resolve())

0 commit comments

Comments
 (0)