Commit 6bff330

xuzhao9 authored and facebook-github-bot committed

Add NCU Trace generation

Summary: Generate an NCU trace for the Triton kernel and input batch.

Reviewed By: chenyang78
Differential Revision: D56047231
fbshipit-source-id: a0a18f12daeeeae9f5c9e8adc1568f3be98bd9b1

1 parent 1a1a1f8 commit 6bff330

File tree

2 files changed: +110 -11 lines changed
torchbenchmark/_components/ncu/__init__.py (new file, per the import added in triton_op.py below)

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+
+from typing import Callable
+
+def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=True, output_dir=None) -> None:
+    """
+    Run the provided function between CUDA profiler start/stop markers so that
+    an externally launched NCU process can capture its kernels.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param fast_flush: Use faster kernel to flush L2 between measurements
+    :type fast_flush: bool
+    :param output_dir: Output directory to store the trace
+    :type output_dir: str, optional
+    """
+    import torch
+
+    fn()
+    torch.cuda.synchronize()
+
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2
+    # doesn't contain any input data before the run
+    if fast_flush:
+        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+    else:
+        cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
+
+    # Estimate the runtime of the function
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    torch.cuda.synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # Compute the number of warmup iterations from the estimated runtime
+    n_warmup = max(1, int(warmup / estimate_ms))
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    # Start ncu profiling
+    torch.cuda.cudart().cudaProfilerStart()
+    # We don't want `fn` to accumulate gradient values
+    # if it contains a backward pass, so we clear the
+    # provided gradients
+    if grad_to_none is not None:
+        for x in grad_to_none:
+            x.grad = None
+    # Clear the L2 cache before the run
+    cache.zero_()
+    fn()
+    torch.cuda.cudart().cudaProfilerStop()
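
A minimal sketch of how this helper is meant to be driven (not part of the commit): NCU's range replay mode profiles the region between cudaProfilerStart and cudaProfilerStop, so the script itself must be launched under ncu. Here torch.mm is a stand-in for a real Triton kernel, and a CUDA device is assumed.

# sketch.py -- illustrative driver for do_bench_ncu_in_task.
# Launch under NCU so the profiler-start/stop range is captured, e.g.:
#   ncu --set full --replay-mode range --target-processes all \
#       --csv -f --log-file /tmp/ncu_output.csv python sketch.py
import torch

from torchbenchmark._components.ncu import do_bench_ncu_in_task

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Profile a single matmul; warmup is a time budget in milliseconds.
do_bench_ncu_in_task(fn=lambda: torch.mm(a, b), warmup=25)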

torchbenchmark/util/triton_op.py

Lines changed: 50 additions & 11 deletions
@@ -24,7 +24,7 @@
 REGISTERED_BENCHMARKS: Dict[str, List[str]] = {}
 REGISTERED_METRICS: Dict[str, List[str]] = {}
 BASELINE_BENCHMARKS: Dict[str, str] = {}
-BUILTIN_METRICS = ["latency", "tflops", "speedup", "accuracy", "compile_time"]
+BUILTIN_METRICS = ["latency", "tflops", "speedup", "accuracy", "compile_time", "ncu_trace"]
 BASELINE_SKIP_METRICS = ["speedup", "accuracy"]
 PRECISION_DTYPE_MAPPING = {
     "fp32": torch.float32,
@@ -70,6 +70,15 @@ def do_bench_walltime(fn, warmup=25, rep=100):
     wall_time_ms = (end_time - start_time) * 1e3 / n_repeat
     return wall_time_ms

+def _find_param_loc(l, key: str) -> int:
+    try:
+        return l.index(key)
+    except ValueError:
+        return -1
+def _remove_params(l, loc):
+    if loc == -1:
+        return l
+    return l[:loc] + l[loc+2:]

 @dataclass
 class BenchmarkOperatorMetrics:
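
As a quick illustration (not in the commit) of what these two helpers do to an argument list, with made-up values:

# _find_param_loc returns the index of a flag, or -1 if it is absent;
# _remove_params drops the flag together with the value that follows it.
args = ["--op", "softmax", "--only", "triton_softmax", "--metrics", "latency"]
loc = _find_param_loc(args, "--metrics")   # -> 4
args = _remove_params(args, loc)
# args is now ["--op", "softmax", "--only", "triton_softmax"]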
@@ -85,6 +94,8 @@ class BenchmarkOperatorMetrics:
     walltime: Optional[float]
     # compile time
     compile_time: Optional[float]
+    # ncu trace file
+    ncu_trace: Optional[str]
     # error message
     error_msg: Optional[str]
     # extra metrics
@@ -544,13 +555,16 @@ def _do_bench(
             accuracy=accuracy,
             walltime=walltime,
             compile_time=None,
+            ncu_trace=None,
             error_msg=error_msg,
             extra_metrics={},
         )
         if "tflops" in self.required_metrics:
             metric.tflops = self.tflops(fn_name, self.example_inputs, metric)
         if "compile_time" in self.required_metrics:
             metric.compile_time = self.compile_time(batch_id, fn_name, metric)
+        if "ncu_trace" in self.required_metrics:
+            metric.ncu_trace = self.ncu_trace(batch_id, fn_name)
         extra_metrics = {}
         # run the hidden metric "_compile_time_in_task"
         # to get the compile time in parent process
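
An aside, not part of the diff: once populated here, ncu_trace is read like any other builtin metric. A hypothetical consumer:

# Hypothetical: after a run with --metrics ncu_trace, each
# BenchmarkOperatorMetrics carries the path of its NCU CSV log file.
if metric.ncu_trace is not None:
    print(f"NCU trace written to {metric.ncu_trace}")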
@@ -559,6 +573,13 @@ def _do_bench(
                 "_compile_time_in_task must be measured by itself. " \
                 f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
             extra_metrics["_compile_time_in_task"] = self._compile_time_in_task(fn)
+        if "_ncu_trace_in_task" in self.required_metrics:
+            assert self.required_metrics == ["_ncu_trace_in_task"] and self._only and (self._batch_id is not None), \
+                "_ncu_trace_in_task must be measured by itself. " \
+                f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
+            from torchbenchmark._components.ncu import do_bench_ncu_in_task
+            do_bench_ncu_in_task(fn=fn, warmup=warmup, grad_to_none=self.get_grad_to_none(self.example_inputs))
+            extra_metrics["_ncu_trace_in_task"] = "success"
         # generate customized metrics
         if self.name in REGISTERED_METRICS:
             for metric_name in REGISTERED_METRICS[self.name]:
@@ -577,29 +598,47 @@ def _do_bench(
                 accuracy=None,
                 walltime=None,
                 compile_time=None,
+                ncu_trace=None,
                 error_msg="CUDA OOM",
                 extra_metrics={},
             )
         return metric


+    @register_metric()
+    def ncu_trace(self, batch_id: int, fn_name: str) -> str:
+        # collect the ncu trace
+        import sys
+        import subprocess
+        from pathlib import Path
+        op_task_args = copy.deepcopy(sys.argv)
+        for override_option in ["--only", "--batch-id", "--metrics"]:
+            op_task_args = _remove_params(op_task_args, _find_param_loc(op_task_args, override_option))
+        op_task_args.extend(["--only", fn_name, "--batch-id", str(batch_id), "--metrics", "_ncu_trace_in_task"])
+        # Disable DCGM
+        try:
+            disable_dcgm = ["sudo", "dyno", "dcgm_profiling", "--mute=true", "--duration=1000_s"]
+            subprocess.run(disable_dcgm, check=True)
+        except subprocess.SubprocessError:
+            warnings.warn("Cannot find dyno to disable DCGM. Proceed to collect NCU Trace.")
+        ncu_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn_name}_{batch_id}")
+        ncu_output_dir.mkdir(parents=True, exist_ok=True)
+        ncu_output_file = ncu_output_dir.joinpath("ncu_output.csv").resolve()
+        ncu_args = ["ncu", "--set", "full", "--replay-mode", "range", "--target-processes", "all",
+            "--csv", "-f", "--log-file", str(ncu_output_file.resolve())]
+        ncu_args.extend(op_task_args)
+        subprocess.check_call(ncu_args)
+        return str(ncu_output_file.resolve())
+
+
     @register_metric()
     def compile_time(self, batch_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics) -> float:
         # We need to spawn a subprocess when user wants to measure the compile time
         # of multiple batches and backends.
-        def _find_loc(l, key: str) -> int:
-            try:
-                return l.index(key)
-            except ValueError:
-                return -1
-        def _remove_element(l, loc):
-            if loc == -1:
-                return l
-            return l[:loc] + l[loc+2:]
         from torchbenchmark.operators.op_task import OpTask
         op_task_args = copy.deepcopy(self._raw_extra_args)
         for override_option in ["--only", "--batch-id", "--metrics"]:
-            op_task_args = _remove_element(op_task_args, _find_loc(op_task_args, override_option))
+            op_task_args = _remove_params(op_task_args, _find_param_loc(op_task_args, override_option))
         op_task_args.extend(["--only", fn_name, "--batch-id", str(batch_id), "--metrics", "_compile_time_in_task"])
         op_task = OpTask(name=self.name)
         op_task.make_operator_instance(mode=self.mode.value, device=self.device, extra_args=op_task_args)