Commit 2c5bc4a

anijain2305 authored and facebook-github-bot committed
Switch off inference mode during compilation (#149321)
Summary: This PR does the following:

* Sets `inference_mode` to False and enables `no_grad` for `convert_frame`, if inference_mode is on globally.
* Turns off inference_mode for fake tensor propagation. This ensures that converting a real inference tensor to a fake tensor removes the inference-ness.
* Graph breaks on `is_inference` and `is_inference_mode_enabled`.

X-link: pytorch/pytorch#149321
Approved by: https://github.com/jansel, https://github.com/zou3519
Reviewed By: izaitsevfb
Differential Revision: D71451966
fbshipit-source-id: 8b5cd2178fe37f2843238efafdb798df7ced2391
1 parent d1b2abb commit 2c5bc4a
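For context, here is a minimal sketch of the user-facing scenario this commit targets: a `torch.compile`-d function called under a global `torch.inference_mode()` context. This is illustrative only; it exercises the changed code path but does not show the dynamo internals.

import torch

@torch.compile
def double(x):
    return x * 2

with torch.inference_mode():
    # Compilation now proceeds as if inference_mode were off (no_grad is
    # substituted instead); inference_mode still applies when the compiled
    # code actually executes.
    out = double(torch.randn(4))

Note that, per the summary above, `is_inference` and `is_inference_mode_enabled` now cause a graph break when encountered inside a compiled region rather than being traced.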

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 38 additions & 0 deletions
@@ -4474,3 +4474,41 @@ def get_optimize_ddp_mode():
             f"Invalid dynamo config optimize_ddp value {mode=}"
         )
     return mode
+
+
+@contextmanager
+def maybe_disable_inference_mode() -> Generator[None, None, None]:
+    """
+    Disables torch.inference_mode for the compilation (still on at runtime).
+    This simplifies the compile stack where we can assume that inference_mode
+    will always be off.
+
+    Since inference_mode is equivalent to no_grad + some optimizations (version
+    counts etc), we turn on no_grad here. The other optimizations are not
+    relevant to torch.compile.
+    """
+    is_inference_mode_on = (
+        config.fake_tensor_disable_inference_mode and torch.is_inference_mode_enabled()
+    )
+    if is_inference_mode_on:
+        with (
+            torch.inference_mode(False),
+            torch.no_grad(),
+        ):
+            yield
+    else:
+        yield
+
+
+@contextmanager
+def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
+    """
+    Turns off tracking of inference_mode for fake tensor propagation. With this
+    context manager, when a real tensor is converted to a fake tensor, the fake
+    tensor loses its inference-ness.
+    """
+    if config.fake_tensor_disable_inference_mode:
+        with torch._subclasses.meta_utils.disable_inference_mode_for_fake_prop():
+            yield
+    else:
+        yield
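The core substitution performed by `maybe_disable_inference_mode` can be demonstrated with public PyTorch APIs alone. The following runnable sketch (not the dynamo code path itself) shows that re-entering with `inference_mode(False)` plus `no_grad()` yields an environment where inference mode is off but gradients stay disabled:

import torch

with torch.inference_mode():
    assert torch.is_inference_mode_enabled()
    # The helper's substitution: turn inference_mode off while keeping
    # gradients disabled, so compilation sees a uniform no-grad environment.
    with torch.inference_mode(False), torch.no_grad():
        assert not torch.is_inference_mode_enabled()
        assert not torch.is_grad_enabled()
        t = torch.ones(2)
        assert not t.is_inference()  # tensors created here are ordinary tensors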

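The second helper concerns fake tensor propagation. Below is a hedged sketch of the behavior it enables, assuming the `fake_tensor_disable_inference_mode` config is on; `FakeTensorMode.from_tensor` is a public entry point, but the exact conversion path inside dynamo differs:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with torch.inference_mode():
    real = torch.ones(2)
assert real.is_inference()  # a real inference tensor

# Under disable_inference_mode_for_fake_prop(), converting `real` to a fake
# tensor drops its inference-ness, so the rest of the compile stack does not
# need to special-case inference tensors.
fake = FakeTensorMode().from_tensor(real)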