Commit 330a3d1

mengluy authored and facebook-github-bot committed
[1/n][Optimus][Auto-AC] Support activation quantization without scaling (pytorch#148380)
Summary:

X-link: pytorch/benchmark#2607

Pull Request resolved: pytorch#148380

We enable activation quantization in the forward pass, and users can customize the dtype they want to quantize to.

Test Plan:

# unit test

```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization -- test_activation_quantization_aten
```

Buck UI: https://www.internalfb.com/buck2/9a53c909-d3ea-479a-874e-cc917999ca88
Test UI: https://www.internalfb.com/intern/testinfra/testrun/12384899050440719
Network: Up: 62KiB Down: 81KiB (reSessionID-913ca82d-c395-4492-818e-6e004df37f87)
Executing actions. Remaining 0/4 6.1s exec time total
Command: test. Finished 2 local
Time elapsed: 3:22.9s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

# E2E

### how to enable

```
post_grad_fusion_options={
    "activation_quantization_aten_pass": {"quant_type": "torch.float8_e5m2"},
},
```

See D51860030 for how to set the config under dynamo_config_map.

Note: you can change the quant_type; if none is given, the default type torch.float8_e5m2 is used to quantize.

#### If you use FSDP

- You may also need to set inline_inbuilt_nn_modules to true for models that use FSDP (see D70023488 for the config setting).
- Remove UNSAFE_SKIP_FSDP_MODULE_GUARDS=1 (context: https://fb.workplace.com/groups/1075192433118967/permalink/1629608671010671/).

```
buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=mast_omnifm_v1-5_mwb launcher.max_retries=3 data_loader.dataset.batch_size=8 launcher.data_project=oncall_ads_model_platform launcher.fbl_entitlement=ads_global_tc_training_efficiency_qps max_ind_range=1 launcher.num_workers=8 data_loader.reading_service.num_remote_dpp_workers=30 data_loader.dataset.num_batches=100 trainer.gpu_tracer.wait=50 trainer.gpu_tracer.active=3 trainer.gpu_tracer.overhead_detection=10 launcher.tags=[ads_ranking_taxonomy_mc_qps_optimization]
```

aps-512_8_remove_fsdp_guards-92ae3972ba

tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/aps-512_8_remove_fsdp_guards-92ae3972ba/attempt_0/version_0/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

baseline w/o fp8 quantization: aps-mengluy_remove_fsdp-ce75b306fa

w/ fp8 quantization: aps-mengluy_remove_fsdp_fp8-96541deec4

### QPS

{F1977040587}

### memory

baseline {F1977040640}

memory snapshot: https://www.internalfb.com/ai_infra/zoomer/profiling-run/insights?profilingRunID=1767027467197075&tab=INSIGHTS&primarySubtab=Memory%20Analysis&secondarySubtab=Memory%20Snapshot

with fp8 {F1977040641}

memory snapshot: https://www.internalfb.com/ai_infra/zoomer/profiling-run/insights?profilingRunID=639378375763157&tab=INSIGHTS&primarySubtab=Memory%20Analysis&secondarySubtab=Memory%20Snapshot

### conclusion

- ~9% QPS improvement; peak memory drops from 82.01 to 78.97.
- For NE, we need a longer verification run; WIP with the scaling version.

Differential Revision: D70522237
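For reference, a minimal open-source-side sketch of enabling the pass with `torch.compile` (the config keys mirror the unit test added in this commit; the toy module, device, and shapes are illustrative only, not part of the change):

```python
import torch
import torch._inductor.config as inductor_config

# Enable the post-grad activation-quantization pass (same keys as the unit
# test in this commit; "size_in_mb" is omitted, so the pass default applies).
with inductor_config.patch(
    post_grad_fusion_options={
        "activation_quantization_aten_pass": {"quant_type": "torch.float8_e5m2"},
    }
):
    model = torch.nn.Linear(64, 64).to(device="cuda", dtype=torch.bfloat16)
    compiled = torch.compile(model)
    out = compiled(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))
    out.sum().backward()  # activations saved for backward go through the fp8 cast
```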
1 parent 5bb154e commit 330a3d1

File tree

9 files changed: +415 −18 lines changed

test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect

Lines changed: 1 addition & 1 deletion
```diff
@@ -8,7 +8,7 @@ torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.
 torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph_module.GraphModule
 torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable])
 torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None, tracer_extras: Optional[Dict[str, Any]] = None)
-torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None, name: Optional[str] = None) -> torch.fx.node.Node
 torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
 torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
 torch.fx.graph.Graph.create_node(self, op: str, target: 'Target', args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
```
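The only signature change recorded here is the new optional `name` parameter on `Graph.call_function`. A minimal sketch of what it allows, assuming `name` is handled like `create_node`'s `name` (a candidate name that may be uniquified):

```python
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
# Previously the node name was always auto-generated; the new kwarg lets the
# caller suggest one explicitly.
n = g.call_function(torch.relu, args=(x,), name="relu_out")
g.output(n)
print([node.name for node in g.nodes])  # ['x', 'relu_out', 'output']
```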

test/inductor/test_quantization.py

Lines changed: 163 additions & 0 deletions
```python
# Owner(s): ["module: inductor"]

import logging

import numpy as np

import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu


log = logging.getLogger(__name__)


class TargetCPModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        relued = torch.relu(x1)
        tanhed = torch.tanh(relued)
        tensor = torch.matmul(
            tanhed,
            x2,
        )
        return tensor


class FeedforwardNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        tanh_x = torch.tanh(x)
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(tanh_x))
        x = self.fc4(x)
        return x


class TestQuantization(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            # if both of them are None, continue
            if (
                not isinstance(ref_dict[key1], torch.Tensor)
                and not isinstance(res_dict[key2], torch.Tensor)
                and ref_dict[key1] is None
                and res_dict[key2] is None
            ):
                log.info(
                    "None found with key1 and value 1: %s, %s, key2 and value2 %s, %s",
                    key1,
                    ref_dict[key1],
                    key2,
                    res_dict[key2],
                )
                continue
            elif not torch.allclose(
                ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol, equal_nan=True
            ):
                log.info(
                    "gradient mismatch for eager and compiled modules, with eager: %s and compiled: %s",
                    ref_dict[key1],
                    res_dict[key2],
                )
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {
                "quant_type": "torch.float8_e5m2",
                "size_in_mb": 0.0,
            },
        },
    )
    def test_activation_quantization_aten(self):
        counters.clear()
        module = TargetCPModule().to(GPU_TYPE)
        input = [
            torch.rand(
                (16, 10), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
            torch.rand(
                (10, 16), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

        module = FeedforwardNN().to(GPU_TYPE)
        X = np.linspace(-10, 10, 100).reshape(-1, 1).astype(np.float32)
        input = [
            torch.from_numpy(X).to(GPU_TYPE),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()
```
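The counters assert that the pass fired exactly once in the forward and once in the backward graph. Conceptually, "without scaling" means the saved activation is a plain dtype cast with no per-tensor scale factor; a rough sketch of the save/restore round trip (illustrative only, not the pass implementation):

```python
import torch

act = torch.randn(16, 10, device="cuda", dtype=torch.bfloat16)

# Forward: the pass rewrites the graph so the activation saved for backward
# is stored in fp8; no scale is computed or applied.
act_fp8 = act.to(torch.float8_e5m2)

# Backward: the saved tensor is cast back to the compute dtype before use,
# so gradients pay the fp8 rounding error but half the storage cost.
act_restored = act_fp8.to(torch.bfloat16)
```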

torch/_dynamo/utils.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -4588,3 +4588,7 @@ def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
             yield
     else:
         yield
+
+
+def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool:
+    return node is None or "example_value" in node.meta or "val" in node.meta
```
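A quick illustration of what the new helper accepts (the nodes below are hypothetical; `example_value` and `val` are the meta keys dynamo and AOTAutograd populate during tracing):

```python
import torch
import torch.fx as fx
from torch._dynamo.utils import is_node_meta_valid

g = fx.Graph()
node = g.placeholder("x")

print(is_node_meta_valid(None))   # True: a missing node counts as valid
print(is_node_meta_valid(node))   # False: no example_value/val recorded yet

node.meta["example_value"] = torch.zeros(2)
print(is_node_meta_valid(node))   # True once tracing has attached a value
```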
