Skip to content

Commit 4643bba

Browse files
Qualcomm AI Engine Direct - multi-method support in to_edge_transform_and_lower_to_qnn
Summary:
- add support for multi-method in to_edge_transform_and_lower_to_qnn
- deprecate capture_program in llama.py
1 parent aed9c7e commit 4643bba

File tree

3 files changed

+156
-110
lines changed

3 files changed

+156
-110
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 25 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
generate_htp_compiler_spec,
4848
generate_qnn_executorch_compiler_spec,
4949
PyQnnManagerAdaptor,
50-
QnnPartitioner,
5150
rewrite_prepared_observer,
5251
skip_annotation,
5352
to_edge_transform_and_lower_to_qnn,
@@ -89,12 +88,8 @@
8988
from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel
9089

9190
from executorch.examples.models.wav2letter import Wav2LetterModel
92-
from executorch.exir import EdgeProgramManager, to_edge
93-
from executorch.exir.backend.backend_api import (
94-
disable_validation,
95-
MethodProgramsPartitionerSpec,
96-
to_backend,
97-
)
91+
from executorch.exir import to_edge
92+
from executorch.exir.backend.backend_api import disable_validation
9893

9994

10095
class TestQNNFloatingPointOperator(TestQNN):
@@ -2701,22 +2696,18 @@ def test_qnn_backend_multi_graphs(self):
27012696
)
27022697
for graph_name in graph_names
27032698
]
2704-
# TODO: retire capture_program once we figure out how to extract
2705-
# intermediate graph from official lowering API
2706-
edge_progs = {
2707-
graph_name: capture_program(module, sample_input).exported_program
2708-
for graph_name, module, sample_input in zip(
2709-
graph_names, modules, sample_inputs
2710-
)
2711-
}
2712-
partitioners = {
2713-
graph_name: QnnPartitioner(compiler_spec)
2714-
for graph_name, compiler_spec in zip(graph_names, compiler_specs)
2715-
}
2716-
lowered_ep_dict = to_backend(
2717-
MethodProgramsPartitionerSpec(edge_progs, partitioners)
2699+
2700+
modules_dict = {}
2701+
sample_inputs_dict = {}
2702+
compiler_specs_dict = {}
2703+
for i, graph_name in enumerate(graph_names):
2704+
modules_dict[graph_name] = modules[i]
2705+
sample_inputs_dict[graph_name] = sample_inputs[i]
2706+
compiler_specs_dict[graph_name] = compiler_specs[i]
2707+
delegated_program = to_edge_transform_and_lower_to_qnn(
2708+
modules_dict, sample_inputs_dict, compiler_specs_dict
27182709
)
2719-
executorch_prog = EdgeProgramManager(lowered_ep_dict).to_executorch()
2710+
executorch_prog = delegated_program.to_executorch()
27202711
for index, module in enumerate(modules):
27212712
self.verify_output(
27222713
module=module,
@@ -3375,28 +3366,21 @@ def test_qnn_backend_multi_graphs(self):
33753366
)
33763367
for graph_name in graph_names
33773368
]
3378-
# TODO: retire capture_program once we figure out how to extract
3379-
# intermediate graph from official lowering API
3380-
for i, module in enumerate(modules):
3381-
module_exported = torch.export.export(module, sample_inputs[i]).module()
3369+
modules_dict = {}
3370+
sample_inputs_dict = {}
3371+
compiler_specs_dict = {}
3372+
for i, graph_name in enumerate(graph_names):
3373+
module_exported = torch.export.export(modules[i], sample_inputs[i]).module()
33823374
module_prepared = prepare_pt2e(module_exported, make_quantizer())
33833375
module_prepared(*sample_inputs[i])
3384-
modules[i] = convert_pt2e(module_prepared)
3385-
3386-
edge_progs = {
3387-
graph_name: capture_program(module, sample_input).exported_program
3388-
for graph_name, module, sample_input in zip(
3389-
graph_names, modules, sample_inputs
3390-
)
3391-
}
3392-
partitioners = {
3393-
graph_name: QnnPartitioner(compiler_spec)
3394-
for graph_name, compiler_spec in zip(graph_names, compiler_specs)
3395-
}
3396-
lowered_ep_dict = to_backend(
3397-
MethodProgramsPartitionerSpec(edge_progs, partitioners)
3376+
modules_dict[graph_name] = convert_pt2e(module_prepared)
3377+
sample_inputs_dict[graph_name] = sample_inputs[i]
3378+
compiler_specs_dict[graph_name] = compiler_specs[i]
3379+
delegated_program = to_edge_transform_and_lower_to_qnn(
3380+
modules_dict, sample_inputs_dict, compiler_specs_dict
33983381
)
3399-
executorch_prog = EdgeProgramManager(lowered_ep_dict).to_executorch()
3382+
3383+
executorch_prog = delegated_program.to_executorch()
34003384
for index, module in enumerate(modules):
34013385
self.verify_output(
34023386
module=module,

backends/qualcomm/utils/utils.py

Lines changed: 104 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -317,60 +317,126 @@ def get_decomp_table(passes_job) -> Dict[torch._ops.OperatorBase, Callable]:
317317

318318

319319
def to_edge_transform_and_lower_to_qnn(
320-
module: Union[torch.nn.Module, torch.fx.GraphModule],
321-
inputs: Tuple[torch.Tensor],
322-
compiler_specs: List[CompileSpec],
320+
module: Union[
321+
torch.nn.Module,
322+
torch.fx.GraphModule,
323+
Dict[str, torch.nn.Module],
324+
Dict[str, torch.fx.GraphModule],
325+
],
326+
inputs: Union[Tuple[torch.Tensor], Dict[str, Tuple[torch.Tensor]]],
327+
compiler_specs: Union[List[Any], Dict[str, List[Any]]],
323328
constant_methods: Optional[Dict[str, Any]] = None,
324329
dynamic_shapes: Optional[Dict] = None,
325330
dep_table: Optional[Dict] = None,
326-
passes_job: Optional[OrderedDict] = None,
331+
passes_job: Optional[Union[OrderedDict, Dict[str, OrderedDict]]] = None,
327332
skip_node_id_set: Optional[set] = None,
328333
skip_node_op_set: Optional[set] = None,
329334
) -> EdgeProgramManager:
330335
"""
331-
Transforms and lowers a given PyTorch module to QNN backend.
336+
Transforms and lowers a given PyTorch module to the QNN backend.
332337
333338
Args:
334-
module (Union[torch.nn.Module, torch.fx.GraphModule]): The PyTorch module or fx.GraphModule to be transformed.
335-
inputs (Tuple[torch.Tensor]): The input tensors for the module.
336-
compiler_specs (List[CompileSpec]): Compiler specs for Qualcomm AI Engine Direct.
337-
constant_methods (Optional[Dict[str, Any]]): An optional dictionary of method name to the constant value
338-
returned by that method in eager mode. Often used to store config information on
339-
Edge models.
340-
dynamic_shapes (Optional[Dict]): Information about dynamic shapes.
341-
dep_table (Optional[Dict]): Dependency table for the transformation passes.
342-
passes_job (Optional[OrderedDict]): Ordered dictionary of transformation passes.
343-
skip_node_id_set (Optional[set]): Set of node IDs to skip during partitioning.
344-
skip_node_op_set (Optional[set]): Set of node operations to skip during partitioning.
339+
module (Union[torch.nn.Module, torch.fx.GraphModule, Dict[str, torch.nn.Module], Dict[str, torch.fx.GraphModule]]):
340+
The PyTorch module or fx.GraphModule to be transformed.
341+
inputs (Union[Tuple[torch.Tensor], Dict[str, Tuple[torch.Tensor]]]):
342+
The input tensors for the module.
343+
compiler_specs (Union[List[Any], Dict[str, List[Any]]]):
344+
Compiler specifications for Qualcomm AI Engine Direct.
345+
constant_methods (Optional[Dict[str, Any]]):
346+
An optional dictionary mapping method names to constant values returned by those methods in eager mode.
347+
Often used to store configuration information on Edge models.
348+
dynamic_shapes (Optional[Dict]):
349+
Information about dynamic shapes.
350+
dep_table (Optional[Dict]):
351+
Dependency table for the transformation passes.
352+
passes_job (Optional[Union[OrderedDict, Dict[str, OrderedDict]]]):
353+
Ordered dictionary of transformation passes.
354+
skip_node_id_set (Optional[set]):
355+
Set of node IDs to skip during partitioning.
356+
skip_node_op_set (Optional[set]):
357+
Set of node operations to skip during partitioning.
345358
346359
Returns:
347-
EdgeProgramManager: The manager for the edge program after transformation and lowering.
360+
EdgeProgramManager:
361+
The manager for the edge program after transformation and lowering.
348362
"""
349-
ep = torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes, strict=True)
350-
# This transformation is primarily intended for the LiftConstantScalarOperands pass
351-
# to avoid creating temporary tensors in the operation builder.
352-
# However, this pass will create a get_attr node, which should be converted
353-
# into a lifted tensor constant by the lift_constant_tensor_pass.
354-
# If placed in the to_edge_transform_passes, it will be executed
355-
# after the lift_constant_tensor_pass, causing the operation builder
356-
# to fail to correctly retrieve the parameter by the get_parameter.
357-
ep = QnnPassManager().transform_for_export_pipeline(ep)
358-
transform_passes = QnnPassManager().get_to_edge_transform_passes(
359-
ep, passes_job=passes_job, dep_table=dep_table
360-
)
361-
qnn_partitioner = QnnPartitioner(
362-
compiler_specs,
363-
skip_node_id_set=skip_node_id_set,
364-
skip_node_op_set=skip_node_op_set,
365-
)
366-
edge_program_manager = to_edge_transform_and_lower(
367-
ep,
363+
364+
def ensure_graph_specific_dict(value, graph_names, callback=None):
365+
"""
366+
Ensures the input value is a dictionary with keys matching the provided graph names.
367+
If the input is not a dictionary or its keys do not match the graph names, a new dictionary
368+
is created with the graph names as keys and the input value assigned to each key.
369+
370+
Examples:
371+
1. Input is None:
372+
>>> ensure_graph_specific_dict(None, ["forward1", "forward2"])
373+
{'forward1': None, 'forward2': None}
374+
375+
2. Input is a single value:
376+
>>> ensure_graph_specific_dict(input, ["forward1", "forward2"])
377+
{'forward1': input, 'forward2': input}
378+
379+
3. Input is a non-graph specific dict:
380+
>>> ensure_graph_specific_dict({Any: input}, ["forward1", "forward2"])
381+
{'forward1': {Any: input}, 'forward2': {Any: input}}
382+
"""
383+
if value is None:
384+
return {graph_name: None for graph_name in graph_names}
385+
if isinstance(value, dict) and graph_names == value.keys():
386+
return value
387+
return {graph_name: value for graph_name in graph_names}
388+
389+
if not isinstance(module, dict):
390+
module = {"forward": module}
391+
392+
# Ensure attributes are graph-specific dictionaries
393+
graph_names = module.keys()
394+
inputs = ensure_graph_specific_dict(inputs, graph_names)
395+
compiler_specs = ensure_graph_specific_dict(compiler_specs, graph_names)
396+
dynamic_shapes = ensure_graph_specific_dict(dynamic_shapes, graph_names)
397+
dep_table = ensure_graph_specific_dict(dep_table, graph_names)
398+
passes_job = ensure_graph_specific_dict(passes_job, graph_names)
399+
400+
# Prepare programs and partitioners
401+
aten_programs = {}
402+
transform_passes = {}
403+
qnn_partitioners = {
404+
graph_name: [
405+
QnnPartitioner(
406+
compiler_specs[graph_name],
407+
skip_node_id_set=skip_node_id_set,
408+
skip_node_op_set=skip_node_op_set,
409+
)
410+
]
411+
for graph_name in graph_names
412+
}
413+
414+
for graph_name, m in module.items():
415+
ep = torch.export.export(
416+
m,
417+
inputs[graph_name],
418+
dynamic_shapes=dynamic_shapes[graph_name],
419+
strict=True,
420+
)
421+
# This transformation is primarily intended for the LiftConstantScalarOperands pass
422+
# to avoid creating temporary tensors in the operation builder.
423+
# However, this pass will create a get_attr node, which should be converted
424+
# into a lifted tensor constant by the lift_constant_tensor_pass.
425+
# If placed in the to_edge_transform_passes, it will be executed
426+
# after the lift_constant_tensor_pass, causing the operation builder
427+
# to fail to correctly retrieve the parameter by the get_parameter.
428+
aten_programs[graph_name] = QnnPassManager().transform_for_export_pipeline(ep)
429+
transform_passes[graph_name] = QnnPassManager().get_to_edge_transform_passes(
430+
ep, passes_job=passes_job[graph_name], dep_table=dep_table[graph_name]
431+
)
432+
433+
return to_edge_transform_and_lower(
434+
aten_programs,
368435
transform_passes=transform_passes,
369-
partitioner=[qnn_partitioner],
436+
partitioner=qnn_partitioners,
370437
constant_methods=constant_methods,
371438
compile_config=qnn_edge_config(),
372439
)
373-
return edge_program_manager
374440

375441

376442
def capture_program(

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -692,44 +692,43 @@ def permute(w, heads):
692692
)
693693
for graph_name in graph_names
694694
]
695-
696-
# TODO: retire capture_program once we figure out how to extract
697-
# intermediate graph from official lowering API
698-
edge_progs = {
699-
graph_name: capture_program(
700-
module=llama_instance.llama_graph_module,
701-
inputs=sample_input,
702-
dep_table=llama_instance.dep_table,
703-
passes_job=llama_instance.passes_job,
704-
).exported_program
705-
for graph_name, llama_instance, sample_input in zip(
706-
graph_names, llama_instance_list, sample_inputs_list
707-
)
708-
}
709-
for n in edge_progs[graph_names[0]].graph.nodes:
695+
edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
696+
{
697+
graph_name: instance.llama_graph_module
698+
for graph_name, instance in zip(graph_names, llama_instance_list)
699+
},
700+
{
701+
graph_name: inputs
702+
for graph_name, inputs in zip(graph_names, sample_inputs_list)
703+
},
704+
{
705+
graph_name: compiler_spec
706+
for graph_name, compiler_spec in zip(graph_names, compiler_specs)
707+
},
708+
llama_instance_list[1].llama_meta,
709+
dep_table={
710+
graph_name: instance.dep_table
711+
for graph_name, instance in zip(graph_names, llama_instance_list)
712+
},
713+
passes_job={
714+
graph_name: instance.passes_job
715+
for graph_name, instance in zip(graph_names, llama_instance_list)
716+
},
717+
skip_node_op_set={"llama.fallback.default"},
718+
)
719+
for n in list(edge_prog_mgr._edge_programs.values())[0].graph.nodes:
710720
if n.op == "output":
711721
for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items():
712722
if node.meta["val"].size() in llama_instance_list[0].io_shape:
713723
quant_attrs = output_encoding
714724

715-
partitioners = {
716-
graph_name: QnnPartitioner(
717-
compiler_spec, skip_node_op_set={"llama.fallback.default"}
718-
)
719-
for graph_name, compiler_spec in zip(graph_names, compiler_specs)
720-
}
721-
722-
lowered_ep_dict = to_backend(
723-
MethodProgramsPartitionerSpec(edge_progs, partitioners)
724-
)
725-
726725
if args.num_sharding > 1:
727726
# TODO: add arg parser of spill_fill_size since weight-sharing based
728727
# context binaries cannot be opened in x86 host
729728
pass
730729

731730
if args.verbose:
732-
for ep in lowered_ep_dict.values():
731+
for ep in edge_prog_mgr._edge_programs.values():
733732
print_delegation_info(ep.graph_module)
734733

735734
executorch_config = ExecutorchBackendConfig(
@@ -743,10 +742,7 @@ def permute(w, heads):
743742
),
744743
extract_delegate_segments=True,
745744
)
746-
exec_prog_mgr = EdgeProgramManager(
747-
edge_programs=lowered_ep_dict,
748-
constant_methods=llama_instance_list[1].llama_meta,
749-
).to_executorch(executorch_config)
745+
exec_prog_mgr = edge_prog_mgr.to_executorch(executorch_config)
750746

751747
with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
752748
exec_prog_mgr.write_to_file(file)

0 commit comments

Comments
 (0)