
Commit ee57ce6

mengluy0125 authored and facebook-github-bot committed
[1/n][Optimus][Auto-AC] Support activation quantization without scaling (pytorch#148380)
Summary:
X-link: pytorch/benchmark#2607

We enable activation quantization in the forward pass, and users can customize the dtype they want to quantize to.

Test Plan:

# unit test

```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization -- test_activation_quantization_aten
```

Buck UI: https://www.internalfb.com/buck2/9a53c909-d3ea-479a-874e-cc917999ca88
Test UI: https://www.internalfb.com/intern/testinfra/testrun/12384899050440719
Network: Up: 62KiB  Down: 81KiB  (reSessionID-913ca82d-c395-4492-818e-6e004df37f87)
Executing actions. Remaining 0/4. 6.1s exec time total
Command: test. Finished 2 local
Time elapsed: 3:22.9s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

# E2E

### How to enable

```
post_grad_fusion_options={
    "activation_quantization_aten_pass": {"quant_type": "torch.float8_e5m2"},
},
```

See D51860030 for how to set the config under dynamo_config_map.

Note: you can change the quant_type; if none is given, the default type torch.float8_e5m2 is used for quantization.

#### If you use FSDP

- You may also need to set inline_inbuilt_nn_modules to true for models that use FSDP (see D70023488 for the config setting).
- Remove UNSAFE_SKIP_FSDP_MODULE_GUARDS=1 (context: https://fb.workplace.com/groups/1075192433118967/permalink/1629608671010671/)

```
buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=mast_omnifm_v1-5_mwb launcher.max_retries=3 data_loader.dataset.batch_size=8 launcher.data_project=oncall_ads_model_platform launcher.fbl_entitlement=ads_global_tc_training_efficiency_qps max_ind_range=1 launcher.num_workers=8 data_loader.reading_service.num_remote_dpp_workers=30 data_loader.dataset.num_batches=100 trainer.gpu_tracer.wait=50 trainer.gpu_tracer.active=3 trainer.gpu_tracer.overhead_detection=10 launcher.tags=[ads_ranking_taxonomy_mc_qps_optimization]
```

aps-512_8_remove_fsdp_guards-92ae3972ba

tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/aps-512_8_remove_fsdp_guards-92ae3972ba/attempt_0/version_0/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Differential Revision: D70522237
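For readers outside the internal launcher setup, here is a minimal sketch of enabling the pass with plain `torch.compile`, mirroring the unit test added in this commit. The toy model, tensor shapes, and use of `.cuda()` are illustrative assumptions, not part of the commit itself; the config keys and the `torch.float8_e5m2` default come from the summary and test above.

```python
# Minimal sketch (assumes a CUDA GPU); the model and shapes are illustrative only.
import torch
import torch._inductor.config


class TinyModel(torch.nn.Module):
    def forward(self, x1, x2):
        return torch.matmul(torch.tanh(torch.relu(x1)), x2)


# Patch the post-grad fusion options so the activation-quantization pass runs.
# Per the summary, omitting "quant_type" falls back to torch.float8_e5m2.
with torch._inductor.config.patch(
    post_grad_fusion_options={
        "activation_quantization_aten_pass": {"quant_type": "torch.float8_e5m2"},
    },
):
    model = TinyModel().cuda()
    compiled = torch.compile(model)
    x1 = torch.rand((16, 10), device="cuda", dtype=torch.bfloat16, requires_grad=True)
    x2 = torch.rand((10, 16), device="cuda", dtype=torch.bfloat16, requires_grad=True)
    out = compiled(x1, x2)        # compilation happens here, with the config active
    out.sum().backward()          # backward also goes through the pass
```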
1 parent 4273e5d commit ee57ce6

9 files changed: +1,728 additions, −1,260 deletions


test/inductor/test_quantization.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
```python
# Owner(s): ["module: inductor"]

import logging

import numpy as np

import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu


log = logging.getLogger(__name__)


class TargetCPModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        relued = torch.relu(x1)
        tanhed = torch.tanh(relued)
        tensor = torch.matmul(
            tanhed,
            x2,
        )
        return tensor


class FeedforwardNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        tanh_x = torch.tanh(x)
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(tanh_x))
        x = self.fc4(x)
        return x


class TestQuantization(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            # if both of them are None, continue
            if (
                not isinstance(ref_dict[key1], torch.Tensor)
                and not isinstance(res_dict[key2], torch.Tensor)
                and ref_dict[key1] is None
                and res_dict[key2] is None
            ):
                log.info(
                    "None found with key1 and value 1: %s, %s, key2 and value2 %s, %s",
                    key1,
                    ref_dict[key1],
                    key2,
                    res_dict[key2],
                )
                continue
            elif not torch.allclose(
                ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol, equal_nan=True
            ):
                log.info(
                    "gradient mismatch for eager and compiled modules, with eager: %s and compiled: %s",
                    ref_dict[key1],
                    res_dict[key2],
                )
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {"quant_type": "torch.float8_e5m2"},
        },
    )
    def test_activation_quantization_aten(self):
        counters.clear()
        module = TargetCPModule().to(GPU_TYPE)
        input = [
            torch.rand(
                (16, 10), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
            torch.rand(
                (10, 16), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

        module = FeedforwardNN().to(GPU_TYPE)
        X = np.linspace(-10, 10, 100).reshape(-1, 1).astype(np.float32)
        input = [
            torch.from_numpy(X).to(GPU_TYPE),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()
```

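Outside the test harness, the same Dynamo counters the test asserts on can be used to confirm that the pass fired on an arbitrary model. Below is a small sketch; the helper name and the pass/fail criterion are assumptions, while the counter keys come directly from the test above, and it assumes the config patch shown earlier is active when compilation happens.

```python
# Sketch: check that the forward/backward quantization passes ran, using the
# same counters asserted on in the unit test above. Helper name is hypothetical.
import torch
from torch._dynamo.utils import counters


def ran_activation_quantization(model: torch.nn.Module, *inputs: torch.Tensor) -> bool:
    counters.clear()
    out = torch.compile(model)(*inputs)
    out.sum().backward()
    fwd = counters["inductor"]["activation_quantization_fwd_aten_pass"]
    bwd = counters["inductor"]["activation_quantization_bwd_aten_pass"]
    return fwd > 0 and bwd > 0
```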
torch/_dynamo/utils.py

Lines changed: 4 additions & 0 deletions
@@ -4566,3 +4566,7 @@ def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
```python
            yield
    else:
        yield


def is_node_meta_valid(node: Optional[torch.fx.Node]):
    return node is None or "example_value" in node.meta or "val" in node.meta
```

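The new `is_node_meta_valid` helper is a small guard for graph passes that rely on example metadata being present on FX nodes. A hypothetical use inside a pass might look like the sketch below; the pass body, the dtype filter, and the `should_quantize` marker are illustrative only and are not the actual Inductor activation-quantization pass.

```python
# Hypothetical sketch: skip nodes whose meta lacks the "example_value"/"val"
# entries a pass would need, using the helper added in this commit.
import torch
from torch._dynamo.utils import is_node_meta_valid


def annotate_float8_candidates(graph: torch.fx.Graph) -> None:
    for node in graph.nodes:
        if not is_node_meta_valid(node):
            # No example metadata to inspect, so leave this node alone.
            continue
        example = node.meta.get("example_value", node.meta.get("val"))
        if isinstance(example, torch.Tensor) and example.dtype == torch.bfloat16:
            node.meta["should_quantize"] = True  # illustrative marker only
```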