15
15
"""Classes for drawing bloq counts graphs with Graphviz."""
16
16
import abc
17
17
import html
18
- from typing import Any , Dict , Iterable , Optional , Tuple , Union
18
+ import warnings
19
+ from typing import Any , cast , Dict , Iterable , Mapping , Optional , Tuple , TYPE_CHECKING , Union
19
20
20
21
import attrs
21
22
import IPython .display
24
25
import sympy
25
26
26
27
from qualtran import Bloq , CompositeBloq
28
+ from qualtran .symbolics import SymbolicInt
29
+
30
+ if TYPE_CHECKING :
31
+ from qualtran .resource_counting import CostKey , CostValT , GateCounts
27
32
28
33
29
34
class _CallGraphDrawerBase (metaclass = abc .ABCMeta ):
@@ -176,7 +181,15 @@ class GraphvizCallGraph(_CallGraphDrawerBase):
176
181
Each edge is labeled with the number of times the "caller" (predecessor) bloq calls the
177
182
"callee" (successor) bloq.
178
183
179
- This class follows the behavior described in https://github.com/quantumlib/Qualtran/issues/791
184
+ The constructor of this class assumes you have already generated the call graph as a networkx
185
+ graph and constructed any associated data. See the factory method
186
+ `GraphvizCallGraph.from_bloq()` to set up a call graph diagram from a bloq with sensible
187
+ defaults.
188
+
189
+ This class uses a bloq's `__str__` string to title the bloq. Arbitrary additional tabular
190
+ data can be provided with `bloq_data`.
191
+
192
+ This graph drawer is the successor to the `GraphvizCounts` existing drawer,
180
193
and will replace `GraphvizCounts` when all bloqs have been migrated to use `__str__()`.
181
194
182
195
Args:
@@ -193,6 +206,133 @@ def __init__(self, g: nx.DiGraph, bloq_data: Optional[Dict['Bloq', Dict[Any, Any
193
206
194
207
self .bloq_data = bloq_data
195
208
209
+ @classmethod
210
+ def format_qubit_count (cls , val : SymbolicInt ) -> Dict [str , str ]:
211
+ """Format `QubitCount` cost values as a string.
212
+
213
+ Args:
214
+ val: The qubit count value, which should be an integer
215
+
216
+ Returns:
217
+ A dictionary mapping a string cost name to a string cost value.
218
+ """
219
+ return {'Qubits' : f'{ val } ' }
220
+
221
+ @classmethod
222
+ def format_qec_gates_cost (cls , val : 'GateCounts' , agg : Optional [str ] = None ) -> Dict [str , str ]:
223
+ """Format `QECGatesCost` cost values as a string.
224
+
225
+ Args:
226
+ val: The qec gate costs value, which should be a `GateCounts` dataclass.
227
+ agg: One of 'factored', 'total_t', 't_and_ccz', or 'beverland' to
228
+ (optionally) aggregate the gate counts. If not specified, the 'factored'
229
+ approach is used where each type of gate is counted individually. See the
230
+ methods on `GateCounts` for more information.
231
+
232
+ Returns:
233
+ A dictionary mapping string cost names to string cost values.
234
+ """
235
+ labels = {
236
+ 't' : 'Ts' ,
237
+ 'n_t' : 'Ts' ,
238
+ 'toffoli' : 'Toffolis' ,
239
+ 'n_ccz' : 'CCZs' ,
240
+ 'cswap' : 'CSwaps' ,
241
+ 'and_bloq' : 'Ands' ,
242
+ 'clifford' : 'Cliffords' ,
243
+ 'rotation' : 'Rotations' ,
244
+ 'measurement' : 'Measurements' ,
245
+ }
246
+ counts_dict : Mapping [str , SymbolicInt ]
247
+ if agg is None or agg == 'factored' :
248
+ counts_dict = val .asdict ()
249
+ elif agg == 'total_t' :
250
+ counts_dict = {'t' : val .total_t_count ()}
251
+ elif agg == 't_and_ccz' :
252
+ counts_dict = val .total_t_and_ccz_count ()
253
+ elif agg == 'beverland' :
254
+ counts_dict = val .total_beverland_count ()
255
+ else :
256
+ raise ValueError (f"Unknown aggregation mode { agg } ." )
257
+
258
+ return {labels .get (gate_k , gate_k ): f'{ gate_v } ' for gate_k , gate_v in counts_dict .items ()}
259
+
260
+ @classmethod
261
+ def format_cost_data (
262
+ cls ,
263
+ cost_data : Dict ['Bloq' , Dict ['CostKey' , 'CostValT' ]],
264
+ agg_gate_counts : Optional [str ] = None ,
265
+ ) -> Dict ['Bloq' , Dict [str , str ]]:
266
+ """Format `cost_data` as human-readable strings.
267
+
268
+ Args:
269
+ cost_data: The cost data, likely returned from a call to `query_costs()`. This
270
+ class method will delegate to `format_qubit_count` and `format_qec_gates_cost`
271
+ for `QubitCount` and `QECGatesCost` cost keys, respectively.
272
+ agg_gate_counts: One of 'factored', 'total_t', 't_and_ccz', or 'beverland' to
273
+ (optionally) aggregate the gate counts. If not specified, the 'factored'
274
+ approach is used where each type of gate is counted individually. See the
275
+ methods on `GateCounts` for more information.
276
+
277
+ Returns:
278
+ For each bloq key, a table of label/value pairs consisting of
279
+ human-readable labels and formatted values.
280
+ """
281
+ from qualtran .resource_counting import GateCounts , QECGatesCost , QubitCount
282
+
283
+ disp_data : Dict ['Bloq' , Dict [str , str ]] = {}
284
+ for bloq in cost_data .keys ():
285
+ bloq_disp : Dict [str , str ] = {}
286
+ for cost_key , cost_val in cost_data [bloq ].items ():
287
+ if isinstance (cost_key , QubitCount ):
288
+ bloq_disp |= cls .format_qubit_count (cast (SymbolicInt , cost_val ))
289
+ elif isinstance (cost_key , QECGatesCost ):
290
+ assert isinstance (cost_val , GateCounts )
291
+ bloq_disp |= cls .format_qec_gates_cost (cost_val , agg = agg_gate_counts )
292
+ else :
293
+ warnings .warn (f"Unknown cost key { cost_key } " )
294
+ bloq_disp [str (cost_key )] = str (cost_val )
295
+
296
+ disp_data [bloq ] = bloq_disp
297
+ return disp_data
298
+
299
+ @classmethod
300
+ def from_bloq (
301
+ cls , bloq : Bloq , * , max_depth : Optional [int ] = None , agg_gate_counts : Optional [str ] = None
302
+ ) -> 'GraphvizCallGraph' :
303
+ """Draw a bloq call graph.
304
+
305
+ This factory method will generate a call graph from the bloq, query the `QECGatesCost`
306
+ and `QubitCount` costs, format the cost data, and merge it with the call graph
307
+ to create a call graph diagram with annotated costs.
308
+
309
+ For additional customization, users can construct the call graph and bloq data themselves
310
+ and use the normal constructor, or provide minor display customizations by
311
+ overriding the `format_xxx` class methods.
312
+
313
+ Args:
314
+ bloq: The bloq from which we construct the call graph and query the costs.
315
+ max_depth: The maximum depth (from the root bloq) of the call graph to draw. Note
316
+ that the cost computations will walk the whole call graph, but only the nodes
317
+ within this depth will be drawn.
318
+ agg_gate_counts: One of 'factored', 'total_t', 't_and_ccz', or 'beverland' to
319
+ (optionally) aggregate the gate counts. If not specified, the 'factored'
320
+ approach is used where each type of gate is counted individually. See the
321
+ methods on `GateCounts` for more information.
322
+
323
+ Returns:
324
+ A `GraphvizCallGraph` diagram-drawer, whose methods can be used to generate
325
+ graphviz inputs or SVG diagrams.
326
+ """
327
+ from qualtran .resource_counting import QECGatesCost , QubitCount , query_costs
328
+
329
+ call_graph , _ = bloq .call_graph (max_depth = max_depth )
330
+ cost_data : Dict ['Bloq' , Dict [CostKey , Any ]] = query_costs (
331
+ bloq , [QubitCount (), QECGatesCost ()]
332
+ )
333
+ formatted_cost_data = cls .format_cost_data (cost_data , agg_gate_counts = agg_gate_counts )
334
+ return cls (g = call_graph , bloq_data = formatted_cost_data )
335
+
196
336
def get_node_title (self , b : Bloq ):
197
337
return str (b )
198
338
0 commit comments