@@ -1616,6 +1616,40 @@ def _export(
1616
1616
return onnx_program
1617
1617
1618
1618
1619
+ class OnnxModelFromDynamoAotOptimize (OnnxModelFromDynamo ):
1620
+ """Dynamo and Fx based export, with AOT optimize post export. `torch.onnx.dynamo_export`."""
1621
+
1622
+ _COMPILER_NAME = "dynamo_aot_optimize"
1623
+
1624
+ def _export (
1625
+ self , model , example_inputs , output_path : str
1626
+ ) -> torch .onnx .ONNXProgram :
1627
+ if self .copy_before_export :
1628
+ # Deepcopy model before export to avoid modification to baseline model.
1629
+ model , example_inputs = self .deepcopy_model_and_inputs_to_device (
1630
+ model , example_inputs , self ._determine_deepcopy_target_device ()
1631
+ )
1632
+
1633
+ example_args , example_kwargs = _normalize_bench_inputs (example_inputs )
1634
+ options = torch .onnx .ExportOptions (dynamic_shapes = self ._dynamic_shapes )
1635
+ export_output = torch .onnx .dynamo_export (
1636
+ model , * example_args , ** example_kwargs , export_options = options
1637
+ )
1638
+
1639
+ import onnx
1640
+ from onnxscript .rewriter .onnxruntime import rewrite
1641
+
1642
+ model_proto = rewrite (export_output .model_proto )
1643
+ onnx .save_model (
1644
+ model_proto ,
1645
+ output_path ,
1646
+ save_as_external_data = True ,
1647
+ all_tensors_to_one_file = True ,
1648
+ )
1649
+
1650
+ return export_output
1651
+
1652
+
1619
1653
class _OnnxPatch :
1620
1654
@classmethod
1621
1655
def patch_non_tensor_outputs (cls , correct_result , new_result , fp64_outputs ):
@@ -3475,6 +3509,12 @@ def get_example_inputs(self):
3475
3509
action = "store_true" ,
3476
3510
help = "Measure speedup with Dynamo ONNX AOT Inline, i.e. `torch.onnx.dynamo_export`" ,
3477
3511
)
3512
+ group .add_argument (
3513
+ "--dynamo-onnx-aot-optimize" ,
3514
+ "--dynamo_onnx_aot_optimize" ,
3515
+ action = "store_true" ,
3516
+ help = "Measure speedup with Dynamo ONNX w/ ort fusions, i.e. `torch.onnx.dynamo_export`" ,
3517
+ )
3478
3518
group .add_argument (
3479
3519
"--backend" ,
3480
3520
choices = torch ._dynamo .list_backends (exclude_tags = None ),
@@ -3839,6 +3879,17 @@ def run(runner, args, original_dir=None):
3839
3879
experiment = speedup_experiment_onnx
3840
3880
output_filename = "dynamo_onnx_aot_inline.csv"
3841
3881
current_onnx_compiler = "dynamo"
3882
+ elif args .dynamo_onnx_aot_optimize :
3883
+ optimize_ctx = functools .partial (
3884
+ optimize_onnx_ctx ,
3885
+ args .output_directory or "." ,
3886
+ OnnxModelFromDynamoAotOptimize ,
3887
+ dynamic_shapes = args .dynamic_shapes ,
3888
+ copy_before_export = args .performance ,
3889
+ )
3890
+ experiment = speedup_experiment_onnx
3891
+ output_filename = "dynamo_onnx_aot_optimize.csv"
3892
+ current_onnx_compiler = "dynamo"
3842
3893
elif args .speedup_dynamo_ts :
3843
3894
optimize_ctx = torch ._dynamo .optimize ("ts" , nopython = args .nopython )
3844
3895
experiment = speedup_experiment
0 commit comments