@@ -171,27 +171,14 @@ def _build_call_graph(
171
171
g .add_edge (bloq , callee , n = n )
172
172
173
173
174
- def _compute_sigma (root_bloq : Bloq , g : nx .DiGraph ) -> Dict [Bloq , Union [int , sympy .Expr ]]:
175
- """Iterate over nodes to sum up the counts of leaf bloqs."""
176
- bloq_sigmas : Dict [Bloq , Dict [Bloq , Union [int , sympy .Expr ]]] = defaultdict (
177
- lambda : defaultdict (lambda : 0 )
178
- )
179
- for bloq in reversed (list (nx .topological_sort (g ))):
180
- callees = list (g .successors (bloq ))
181
- sigma = bloq_sigmas [bloq ]
182
- if not callees :
183
- # 1. `bloq` is a leaf node. Its count is one of itself.
184
- sigma [bloq ] = 1
185
- continue
186
-
187
- for callee in callees :
188
- callee_sigma = bloq_sigmas [callee ]
189
- # 2. Otherwise, sigma of the caller is sum(n * sigma of callee) for all the callees.
190
- n = g .edges [bloq , callee ]['n' ]
191
- for k in callee_sigma .keys ():
192
- sigma [k ] += callee_sigma [k ] * n
174
+ def _compute_sigma (
175
+ root_bloq : Bloq , g : nx .DiGraph , generalizer : 'GeneralizerT'
176
+ ) -> Dict [Bloq , Union [int , sympy .Expr ]]:
177
+ """Shim for compatibility with old 'sigma' that used the call graph to count leaf bloqs."""
178
+ from qualtran .resource_counting import BloqCount , get_cost_value
193
179
194
- return dict (bloq_sigmas [root_bloq ])
180
+ leaf_counts = BloqCount .for_call_graph_leaf_bloqs (g )
181
+ return get_cost_value (root_bloq , leaf_counts , generalizer = generalizer )
195
182
196
183
197
184
def get_bloq_call_graph (
@@ -239,7 +226,7 @@ def get_bloq_call_graph(
239
226
if bloq is None :
240
227
raise ValueError ("You can't generalize away the root bloq." )
241
228
_build_call_graph (bloq , generalizer , ssa , keep , max_depth , g = g , depth = 0 )
242
- sigma = _compute_sigma (bloq , g )
229
+ sigma = _compute_sigma (bloq , g , generalizer )
243
230
return g , sigma
244
231
245
232
0 commit comments