@@ -317,60 +317,126 @@ def get_decomp_table(passes_job) -> Dict[torch._ops.OperatorBase, Callable]:
317
317
318
318
319
319
def to_edge_transform_and_lower_to_qnn(
    module: Union[
        torch.nn.Module,
        torch.fx.GraphModule,
        Dict[str, torch.nn.Module],
        Dict[str, torch.fx.GraphModule],
    ],
    inputs: Union[Tuple[torch.Tensor], Dict[str, Tuple[torch.Tensor]]],
    compiler_specs: Union[List[Any], Dict[str, List[Any]]],
    constant_methods: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Dict] = None,
    dep_table: Optional[Dict] = None,
    passes_job: Optional[Union[OrderedDict, Dict[str, OrderedDict]]] = None,
    skip_node_id_set: Optional[set] = None,
    skip_node_op_set: Optional[set] = None,
) -> EdgeProgramManager:
    """
    Transforms and lowers a given PyTorch module to the QNN backend.

    Both single-graph and multi-graph forms are accepted: each of ``module``,
    ``inputs``, ``compiler_specs``, ``dynamic_shapes``, ``dep_table`` and
    ``passes_job`` may either be a single value (applied to the implicit
    graph name ``"forward"``) or a dict keyed by graph name.

    Args:
        module (Union[torch.nn.Module, torch.fx.GraphModule, Dict[str, torch.nn.Module], Dict[str, torch.fx.GraphModule]]):
            The PyTorch module or fx.GraphModule to be transformed, or a
            mapping of graph name to module for multi-graph export.
        inputs (Union[Tuple[torch.Tensor], Dict[str, Tuple[torch.Tensor]]]):
            The input tensors for the module (per graph when a dict).
        compiler_specs (Union[List[Any], Dict[str, List[Any]]]):
            Compiler specifications for Qualcomm AI Engine Direct.
        constant_methods (Optional[Dict[str, Any]]):
            An optional dictionary mapping method names to constant values
            returned by those methods in eager mode. Often used to store
            configuration information on Edge models.
        dynamic_shapes (Optional[Dict]):
            Information about dynamic shapes, forwarded to torch.export.
        dep_table (Optional[Dict]):
            Dependency table for the transformation passes.
        passes_job (Optional[Union[OrderedDict, Dict[str, OrderedDict]]]):
            Ordered dictionary of transformation passes.
        skip_node_id_set (Optional[set]):
            Set of node IDs to skip during partitioning.
        skip_node_op_set (Optional[set]):
            Set of node operations to skip during partitioning.

    Returns:
        EdgeProgramManager:
            The manager for the edge program after transformation and lowering.
    """

    def ensure_graph_specific_dict(value, graph_names):
        """
        Normalize *value* into a dict keyed by *graph_names*.

        - None                      -> {name: None for each graph name}
        - dict with exactly the     -> returned unchanged (already
          graph names as keys          graph-specific)
        - anything else (including  -> the same value assigned to every
          a non-graph-specific dict)   graph name
        """
        if value is None:
            return {graph_name: None for graph_name in graph_names}
        # dict_keys views compare as sets, so key order does not matter here.
        if isinstance(value, dict) and graph_names == value.keys():
            return value
        return {graph_name: value for graph_name in graph_names}

    # Single-graph call: wrap everything under the default method name.
    if not isinstance(module, dict):
        module = {"forward": module}

    # Ensure attributes are graph-specific dictionaries.
    graph_names = module.keys()
    inputs = ensure_graph_specific_dict(inputs, graph_names)
    compiler_specs = ensure_graph_specific_dict(compiler_specs, graph_names)
    dynamic_shapes = ensure_graph_specific_dict(dynamic_shapes, graph_names)
    dep_table = ensure_graph_specific_dict(dep_table, graph_names)
    passes_job = ensure_graph_specific_dict(passes_job, graph_names)

    # Prepare programs and partitioners (one partitioner list per graph).
    aten_programs = {}
    transform_passes = {}
    qnn_partitioners = {
        graph_name: [
            QnnPartitioner(
                compiler_specs[graph_name],
                skip_node_id_set=skip_node_id_set,
                skip_node_op_set=skip_node_op_set,
            )
        ]
        for graph_name in graph_names
    }

    for graph_name, m in module.items():
        ep = torch.export.export(
            m,
            inputs[graph_name],
            dynamic_shapes=dynamic_shapes[graph_name],
            strict=True,
        )
        # This transformation is primarily intended for the LiftConstantScalarOperands pass
        # to avoid creating temporary tensors in the operation builder.
        # However, this pass will create a get_attr node, which should be converted
        # into a lifted tensor constant by the lift_constant_tensor_pass.
        # If placed in the to_edge_transform_passes, it will be executed
        # after the lift_constant_tensor_pass, causing the operation builder
        # to fail to correctly retrieve the parameter by the get_parameter.
        aten_programs[graph_name] = QnnPassManager().transform_for_export_pipeline(ep)
        transform_passes[graph_name] = QnnPassManager().get_to_edge_transform_passes(
            ep, passes_job=passes_job[graph_name], dep_table=dep_table[graph_name]
        )

    return to_edge_transform_and_lower(
        aten_programs,
        transform_passes=transform_passes,
        partitioner=qnn_partitioners,
        constant_methods=constant_methods,
        compile_config=qnn_edge_config(),
    )
def capture_program (
0 commit comments