
Commit d5d2762

int3 authored and facebook-github-bot committed
Log more errors + make CSV writing more robust
Summary:
Previously we only caught specific CUDA OOM errors, but benchmarks can fail in other ways too. Let's make this more robust by catching and logging all exceptions.

While there was already code to log exception messages, it often produced malformed CSVs because no quoting was applied; we now use Python's csv module to avoid this issue. Additionally, the previous logic recorded the error message in every metric column of the failed benchmark, which was redundant, so I've changed it to emit the message only once per row.

Finally, since Python's csv writer writes directly to a file instead of building a string first, the previous naming convention of using the hash of the CSV's contents in the file name no longer applies. Instead, I've used NamedTemporaryFile to get a unique file name.

Reviewed By: chenyang78

Differential Revision: D57785120

fbshipit-source-id: 73c76bba7661b60a7357aaba3d5b9659b533479e
1 parent d1a6363 commit d5d2762
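
To illustrate the quoting problem the summary describes (this snippet is illustrative only and not part of the commit; the header and metric values are made up): a metric or error value that itself contains the "; " delimiter silently shifts every later column under a naive string join, while csv.QUOTE_MINIMAL wraps just that field in quotes.

import csv
import io

headers = ["x_val", "latency", "error_msg"]
row = ["(1024, 1024)", None, "Expected size 1024; got 512"]  # value contains the delimiter

# Naive join (the old behavior): the embedded "; " turns 3 fields into 4.
print("; ".join(str(v) for v in row))

# csv module with minimal quoting: the offending field is quoted, columns stay aligned.
buf = io.StringIO()
writer = csv.writer(buf, delimiter=";", quoting=csv.QUOTE_MINIMAL)
writer.writerow(headers)
writer.writerow(row)
print(buf.getvalue())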

File tree

2 files changed (+32, -19 lines)

torchbenchmark/util/triton_op.py

Lines changed: 28 additions & 8 deletions
@@ -3,6 +3,7 @@
 import functools
 import gc
 import json
+import os
 import random
 import time
 import warnings
@@ -206,10 +207,11 @@ def select_metric(m):
             row.append(x_only_metric_dict[x_only_metric])
         for k in y_val_keys:
             metrics_dict = asdict(y_val[k])
+            if metrics_dict["error_msg"]:
+                row.append(metrics_dict["error_msg"])
+                row.extend([None] * (len(key_metrics[k]) - 1))
+                continue
             for metric in key_metrics[k]:
-                if metrics_dict["error_msg"]:
-                    row.append(metrics_dict["error_msg"])
-                    continue
                 _metrics_dict = (
                     metrics_dict["extra_metrics"]
                     if metric in metrics_dict["extra_metrics"]
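
A small sketch of what the new early-exit branch in this hunk produces (the variable contents here are made up): when a backend failed, its first metric column holds the error message and the remaining metric columns are padded with None, so each row keeps the same width as the header.

# Suppose backend "triton_foo" reports the metrics ["latency", "tflops", "speedup"].
key_metrics = {"triton_foo": ["latency", "tflops", "speedup"]}

row = ["x_val_0"]   # the x-value column(s) appended earlier
error_msg = "CUDA OOM"

k = "triton_foo"
row.append(error_msg)                           # error message emitted once
row.extend([None] * (len(key_metrics[k]) - 1))  # pad the remaining metric columns
print(row)  # ['x_val_0', 'CUDA OOM', None, None]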
@@ -224,12 +226,28 @@ def select_metric(m):
             table.append(row)
         return headers, table
 
-    @property
-    def csv(self):
+    def write_csv_to_file(self, fileobj):
+        import csv
+
         headers, table = self._table()
-        headers = "; ".join(headers)
-        table = "\n".join(["; ".join([str(v) for v in row]) for row in table])
-        return f"{headers}\n{table}"
+        writer = csv.writer(fileobj, delimiter=";", quoting=csv.QUOTE_MINIMAL)
+        writer.writerow(headers)
+        writer.writerows(table)
+
+    def write_csv(self, dir_path):
+        import tempfile
+
+        # This is just a way to create a unique filename. It's not actually a
+        # temporary file (since delete=False).
+        with tempfile.NamedTemporaryFile(
+            mode='w',
+            prefix=os.path.join(dir_path, f"op_{self.op_name}_"),
+            suffix=".csv",
+            newline="",
+            delete=False,
+        ) as fileobj:
+            self.write_csv_to_file(fileobj)
+        return fileobj.name
 
     @property
     def x_vals(self):
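
As an aside on the write_csv method above, a minimal standalone sketch of the unique-filename trick (the directory and prefix below are examples, not the real dump path): with delete=False, NamedTemporaryFile only contributes a collision-free name under the given prefix, and the file survives the with block.

import os
import tempfile

dump_dir = "/tmp/tb_csv_demo"  # example directory
os.makedirs(dump_dir, exist_ok=True)

with tempfile.NamedTemporaryFile(
    mode="w",
    # an absolute dump_dir keeps the file under it; tempfile would otherwise
    # join the prefix onto its own default temp directory
    prefix=os.path.join(dump_dir, "op_example_"),
    suffix=".csv",
    newline="",
    delete=False,  # keep the file; we only want a unique name
) as fileobj:
    fileobj.write("metric;value\n")

print(fileobj.name)  # e.g. /tmp/tb_csv_demo/op_example_ab12cd34.csv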
@@ -779,6 +797,8 @@ def _init_extra_metrics() -> Dict[str, Any]:
             metrics.extra_metrics[metric_name] = func(fn, self.example_inputs, metrics)
         except torch.cuda.OutOfMemoryError:
             metrics.error_msg = "CUDA OOM"
+        except Exception as e:
+            metrics.error_msg = str(e)
         return metrics
 
     def get_peak_mem(
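
A minimal sketch of the exception ordering above, using stdlib exceptions as stand-ins (the helper names are hypothetical): the specific handler has to come before the blanket except Exception, or it would be unreachable, and any other failure is stringified into error_msg instead of aborting the run.

class _Metrics:
    error_msg = None

def _run_metric(metrics, fn):
    try:
        return fn()
    except MemoryError:        # stand-in for torch.cuda.OutOfMemoryError
        metrics.error_msg = "CUDA OOM"
    except Exception as e:     # any other benchmark failure is recorded, not fatal
        metrics.error_msg = str(e)

def _failing_benchmark():
    raise ValueError("shape mismatch")

m = _Metrics()
_run_metric(m, _failing_benchmark)
print(m.error_msg)  # shape mismatch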

userbenchmark/triton/run.py

Lines changed: 4 additions & 11 deletions
@@ -59,7 +59,7 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorRe
     metrics = opbench.run(args.warmup, args.iter)
     if not args.skip_print:
         if args.csv:
-            print(metrics.csv)
+            metrics.write_csv_to_file(sys.stdout)
         else:
             print(metrics)
     if not hasattr(torch_version, "git_version") and args.log_scuba:
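
The reason write_csv_to_file can target sys.stdout here is that csv.writer only needs an object with a write() method; a quick standalone check (column names are made up):

import csv
import sys

writer = csv.writer(sys.stdout, delimiter=";", quoting=csv.QUOTE_MINIMAL)
writer.writerow(["x_val", "latency"])
writer.writerow(["(1024, 1024)", 0.123])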
@@ -73,16 +73,9 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorRe
         print(f"Plotting is not implemented for {args.op}")
 
     if args.dump_csv:
-        if not os.path.exists(TRITON_BENCH_CSV_DUMP_PATH):
-            os.mkdir(TRITON_BENCH_CSV_DUMP_PATH)
-
-        csv_str = metrics.csv
-        csv_str_hash = abs(hash(csv_str)) % (10**8)
-        file_name = f"op_{args.op}_{csv_str_hash}.csv"
-        file_path = os.path.join(TRITON_BENCH_CSV_DUMP_PATH, file_name)
-        with open(file_path, "w") as f:
-            f.write(csv_str)
-        print(f"[TritonBench] Dumped csv to {file_path}")
+        os.makedirs(TRITON_BENCH_CSV_DUMP_PATH, exist_ok=True)
+        path = metrics.write_csv(TRITON_BENCH_CSV_DUMP_PATH)
+        print(f"[TritonBench] Dumped csv to {path}")
     return metrics
 
 def run(args: List[str] = []):
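
A short note on the os.makedirs change in this hunk (the path below is an example): the single call with exist_ok=True replaces the check-then-create pair, which could race with another process and cannot create missing parent directories.

import os

dump_dir = "/tmp/tritonbench_csv_demo"  # example path, not the real TRITON_BENCH_CSV_DUMP_PATH

# Old pattern: exists() followed by mkdir() can fail with FileExistsError if the
# directory appears in between, and mkdir() does not create missing parents.
if not os.path.exists(dump_dir):
    os.mkdir(dump_dir)

# New pattern: idempotent, creates parents as needed.
os.makedirs(dump_dir, exist_ok=True)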
