
Commit 7b7276d

Ivan Kobzarev authored and facebook-github-bot committed
Inductor freezing bfloat16 conv folding needs high tolerance (#145623)
Summary: Issue: pytorch/pytorch#144888. The Torchbench run of the timm lcnet_050 model fails the accuracy check with `--freezing` `--inference` `--bfloat16` (`res_error == 0.12`). If inductor convolution constant folding is turned off, `res_error == 0.016`. For comparison, the float16 error is ~0.00669, and float16 without conv folding is ~0.0018. Convolution folding therefore increases the error by almost an order of magnitude. We should revisit conv folding and try to improve its accuracy, e.g. by doing the folding at compilation time in float64. For now, this change adds counters to identify whether convolution folding happened and, in the bfloat16 + conv folding case, raises the multiplier to the maximum level (10) so the accuracy test passes.
X-link: pytorch/pytorch#145623
Approved by: https://github.com/eellison
Reviewed By: ZainRizvi
Differential Revision: D68897700
fbshipit-source-id: f407528b4b37eb45273a8c66f791c44e86c6632e
1 parent 373ffb1 commit 7b7276d
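
For reference, the `res_error` values quoted in the summary are RMSEs against an fp64 baseline, and the check this commit relaxes compares them as `res_error <= multiplier * ref_error + tol / 10` (see the utils.py diff below). A minimal sketch of that criterion with made-up tensors and tolerance (only the formula itself comes from the diff):

    import torch

    def rmse(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Root-mean-square error, computed in the dtype of the reference.
        return torch.sqrt(torch.mean(torch.square(a - b.to(a.dtype))))

    fp64_ref = torch.randn(1000, dtype=torch.float64)   # "ground truth" run in fp64
    ref = fp64_ref.to(torch.bfloat16)                    # eager bf16 baseline
    res = (fp64_ref + 0.01 * torch.randn_like(fp64_ref)).to(torch.bfloat16)  # compiled result

    ref_error = rmse(fp64_ref, ref).item()
    res_error = rmse(fp64_ref, res).item()

    tol = 1e-2
    multiplier = 10.0  # the "max level" this commit forces for bf16 + conv folding
    passes_test = res_error <= multiplier * ref_error + tol / 10.0
    print(passes_test, ref_error, res_error)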

File tree

userbenchmark/dynamo/dynamobench/_dynamo/utils.py
userbenchmark/dynamo/dynamobench/common.py

2 files changed: +49 −26 lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 39 additions & 26 deletions
@@ -2528,6 +2528,7 @@ def same(
     ignore_non_fp=False,
     log_error=log.error,
     use_larger_multiplier_for_smaller_tensor=False,
+    force_max_multiplier: bool = False,
 ):
     """Check correctness to see if ref and res match"""
     if fp64_ref is None:
@@ -2554,6 +2555,7 @@ def same(
                 ignore_non_fp,
                 log_error=log_error,
                 use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor,
+                force_max_multiplier=force_max_multiplier,
             )
             for ai, bi, fp64_refi in zip(ref, res, fp64_ref)
         )
@@ -2573,6 +2575,7 @@ def same(
             ignore_non_fp,
             log_error=log_error,
             use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor,
+            force_max_multiplier=force_max_multiplier,
         )
     elif isinstance(ref, dict):
         assert isinstance(res, dict)
@@ -2593,6 +2596,7 @@ def same(
                     ignore_non_fp=ignore_non_fp,
                     log_error=log_error,
                     use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor,
+                    force_max_multiplier=force_max_multiplier,
                 )
             ):
                 log_error("Accuracy failed for key name %s", k)
@@ -2685,33 +2689,42 @@ def to_tensor(t):
 
         res_error = rmse(fp64_ref, res).item()
 
-        # In the case of using AMP (Automatic Mixed Precision), certain models have
-        # failed the benchmark's correctness check. However, the end-to-end model's
-        # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%.
-        # Thus, it's possible that the correctness check failures for these models are
-        # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms.
-        multiplier = (
-            3.0 if res.dtype in (torch.float16, torch.bfloat16) else 2.0
-        )
+        def get_multiplier():
+            # In some particular cases, we expect high difference in results.
+            # At the moment one of this cases is inductor freezing bfloat16 convolution const folding.
+            # In case of it the res_error is at least one order of magnitude higher.
+            if force_max_multiplier:
+                return 10.0
+            # In the case of using AMP (Automatic Mixed Precision), certain models have
+            # failed the benchmark's correctness check. However, the end-to-end model's
+            # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%.
+            # Thus, it's possible that the correctness check failures for these models are
+            # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms.
+            multiplier = (
+                3.0 if res.dtype in (torch.float16, torch.bfloat16) else 2.0
+            )
 
-        if use_larger_multiplier_for_smaller_tensor and (
-            fp64_ref.numel() <= 10 and tol >= 4 * 1e-2
-        ):
-            multiplier = 10.0
-        elif use_larger_multiplier_for_smaller_tensor and (
-            fp64_ref.numel() <= 500 and tol >= 4 * 1e-2
-        ):
-            multiplier = 5.0
-        elif (
-            fp64_ref.numel() < 1000
-            or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1)
-            # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE
-            or tol >= 2 * 1e-2
-        ):
-            # In the presence of noise, noise might dominate our error
-            # metric for smaller tensors.
-            # Similary, for 1x1 kernels, there seems to be high noise with amp.
-            multiplier = 3.0
+            if use_larger_multiplier_for_smaller_tensor and (
+                fp64_ref.numel() <= 10 and tol >= 4 * 1e-2
+            ):
+                multiplier = 10.0
+            elif use_larger_multiplier_for_smaller_tensor and (
+                fp64_ref.numel() <= 500 and tol >= 4 * 1e-2
+            ):
+                multiplier = 5.0
+            elif (
+                fp64_ref.numel() < 1000
+                or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1)
+                # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE
+                or tol >= 2 * 1e-2
+            ):
+                # In the presence of noise, noise might dominate our error
+                # metric for smaller tensors.
+                # Similary, for 1x1 kernels, there seems to be high noise with amp.
+                multiplier = 3.0
+            return multiplier
+
+        multiplier = get_multiplier()
 
         passes_test = res_error <= (multiplier * ref_error + tol / 10.0)
         if (
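
A hedged usage sketch of the new flag, assuming the `same()` signature shown in this diff (positional `ref`, `res`, `fp64_ref` plus the keyword arguments above) and assuming the companion PyTorch change (pytorch/pytorch#145623) is present in your build; the tensors are invented for illustration:

    import torch
    from torch._dynamo.utils import same

    # Hypothetical inputs: an fp64 reference, an eager bf16 baseline, and a
    # "compiled" bf16 result that is noticeably noisier, as with conv folding.
    fp64_ref = torch.randn(4, 8, 1, 1, dtype=torch.float64)
    ref = fp64_ref.to(torch.bfloat16)
    res = (fp64_ref + 0.05 * torch.randn_like(fp64_ref)).to(torch.bfloat16)

    # With force_max_multiplier=True the check allows res_error up to
    # 10 * ref_error + tol / 10 instead of the usual 3x for bf16.
    ok = same(ref, res, fp64_ref, tol=1e-2, force_max_multiplier=True)
    print(ok)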

userbenchmark/dynamo/dynamobench/common.py

Lines changed: 10 additions & 0 deletions
@@ -3054,6 +3054,7 @@ def record_status(accuracy_status, dynamo_start_stats):
         # Run with Dynamo
         reset_rng_state()
         torch._dynamo.reset()
+        torch._dynamo.utils.counters.clear()
         model_copy = None
         try:
             model_copy = self.deepcopy_and_maybe_parallelize(model)
@@ -3114,6 +3115,14 @@ def record_status(accuracy_status, dynamo_start_stats):
         # The downside and potential problem, is that the output formats may be different.
         # E.g., the output order might not match, None might be part of output, etc.
 
+        force_max_multiplier = False
+        if (
+            self.args.freezing
+            and self.args.bfloat16
+            and torch._dynamo.utils.counters["inductor"]["binary_folding_conv"] > 0
+        ):
+            force_max_multiplier = True
+
         try:
             if self.args.training and self.args.amp:
                 if process_fn := self.get_output_amp_train_process_func.get(
@@ -3133,6 +3142,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 ),
                 cos_similarity=cos_similarity,
                 tol=tolerance,
+                force_max_multiplier=force_max_multiplier,
             ):
                 is_same = False
         except Exception:
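
Outside the harness, the detection pattern added above can be reproduced roughly as follows. The toy conv+bn model, input shape, and use of `torch._inductor.config.freezing` are assumptions for illustration; the counter name `binary_folding_conv` comes from this diff, and it is only incremented by inductor builds that include the PyTorch-side change:

    import torch
    import torch._inductor.config as inductor_config
    from torch._dynamo.utils import counters

    # Assumed toy model whose conv + batchnorm pattern is a candidate for constant folding.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=1),
        torch.nn.BatchNorm2d(8),
    ).eval().to(torch.bfloat16)
    x = torch.randn(1, 3, 32, 32, dtype=torch.bfloat16)

    inductor_config.freezing = True  # mirrors the --freezing benchmark flag
    counters.clear()                 # same reset the harness now performs before the Dynamo run

    with torch.no_grad():
        torch.compile(model)(x)

    # Relax the accuracy multiplier only when inductor actually folded a conv.
    force_max_multiplier = counters["inductor"]["binary_folding_conv"] > 0
    print(force_max_multiplier)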

0 commit comments