Skip to content

Commit 35a1b4a

Browse files
committed
Merge branch 'main' into 2024/08/06-t-compl-use-rotation-eps
2 parents f5f69aa + dd3d2fa commit 35a1b4a

File tree

8 files changed

+290
-30
lines changed

8 files changed

+290
-30
lines changed

qualtran/bloqs/for_testing/atom.py

+6
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Signature,
2929
)
3030
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
31+
from qualtran.resource_counting import CostKey, GateCounts, QECGatesCost
3132

3233
if TYPE_CHECKING:
3334
import quimb.tensor as qtn
@@ -66,6 +67,11 @@ def my_tensors(
6667
)
6768
]
6869

70+
def my_static_costs(self, cost_key: 'CostKey'):
71+
if cost_key == QECGatesCost():
72+
return GateCounts(t=100)
73+
return NotImplemented
74+
6975
def _t_complexity_(self) -> 'TComplexity':
7076
return TComplexity(100)
7177

qualtran/drawing/_show_funcs.py

+49-7
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
"""Convenience functions for showing rich displays in Jupyter notebook."""
1616

1717
import os
18-
from typing import Dict, Optional, Sequence, TYPE_CHECKING, Union
18+
from typing import Dict, Optional, overload, Sequence, TYPE_CHECKING, Union
1919

2020
import IPython.display
2121
import ipywidgets
2222

23-
from .bloq_counts_graph import format_counts_sigma, GraphvizCounts
23+
from qualtran import Bloq
24+
25+
from .bloq_counts_graph import format_counts_sigma, GraphvizCallGraph, GraphvizCounts
2426
from .flame_graph import get_flame_graph_svg_data
2527
from .graphviz import PrettyGraphDrawer, TypedGraphDrawer
2628
from .musical_score import draw_musical_score, get_musical_score_data
@@ -30,8 +32,6 @@
3032
import networkx as nx
3133
import sympy
3234

33-
from qualtran import Bloq
34-
3535

3636
def show_bloq(bloq: 'Bloq', type: str = 'graph'): # pylint: disable=redefined-builtin
3737
"""Display a visual representation of the bloq in IPython.
@@ -75,9 +75,51 @@ def show_bloqs(bloqs: Sequence['Bloq'], labels: Optional[Sequence[Optional[str]]
7575
IPython.display.display(box)
7676

7777

78-
def show_call_graph(g: 'nx.DiGraph') -> None:
79-
"""Display a graph representation of the counts graph `g`."""
80-
IPython.display.display(GraphvizCounts(g).get_svg())
78+
@overload
79+
def show_call_graph(
80+
item: 'Bloq', /, *, max_depth: Optional[int] = None, agg_gate_counts: Optional[str] = None
81+
) -> None:
82+
...
83+
84+
85+
@overload
86+
def show_call_graph(
87+
item: 'nx.Graph', /, *, max_depth: Optional[int] = None, agg_gate_counts: Optional[str] = None
88+
) -> None:
89+
...
90+
91+
92+
def show_call_graph(
93+
item: Union['Bloq', 'nx.Graph'],
94+
/,
95+
*,
96+
max_depth: Optional[int] = None,
97+
agg_gate_counts: Optional[str] = None,
98+
) -> None:
99+
"""Display a graph representation of the call graph.
100+
101+
Args:
102+
item: Either a networkx graph or a bloq. If a networkx graph, it should be a "call graph"
103+
which is passed verbatim to the graph drawer and the additional arguments to this
104+
function are ignored. If it is a bloq, the factory
105+
method `GraphvizCallGraph.from_bloq()` is used to construct the call graph, compute
106+
relevant costs, and display the call graph annotated with the costs.
107+
max_depth: The maximum depth (from the root bloq) of the call graph to draw. Note
108+
that the cost computations will walk the whole call graph, but only the nodes
109+
within this depth will be drawn.
110+
agg_gate_counts: One of 'factored', 'total_t', 't_and_ccz', or 'beverland' to
111+
(optionally) aggregate the gate counts. If not specified, the 'factored'
112+
approach is used where each type of gate is counted individually.
113+
114+
"""
115+
if isinstance(item, Bloq):
116+
IPython.display.display(
117+
GraphvizCallGraph.from_bloq(
118+
item, max_depth=max_depth, agg_gate_counts=agg_gate_counts
119+
).get_svg()
120+
)
121+
else:
122+
IPython.display.display(GraphvizCounts(item).get_svg())
81123

82124

83125
def show_counts_sigma(sigma: Dict['Bloq', Union[int, 'sympy.Expr']]):

qualtran/drawing/bloq_counts_graph.py

+142-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
"""Classes for drawing bloq counts graphs with Graphviz."""
1616
import abc
1717
import html
18-
from typing import Any, Dict, Iterable, Optional, Tuple, Union
18+
import warnings
19+
from typing import Any, cast, Dict, Iterable, Mapping, Optional, Tuple, TYPE_CHECKING, Union
1920

2021
import attrs
2122
import IPython.display
@@ -24,6 +25,10 @@
2425
import sympy
2526

2627
from qualtran import Bloq, CompositeBloq
28+
from qualtran.symbolics import SymbolicInt
29+
30+
if TYPE_CHECKING:
31+
from qualtran.resource_counting import CostKey, CostValT, GateCounts
2732

2833

2934
class _CallGraphDrawerBase(metaclass=abc.ABCMeta):
@@ -176,7 +181,15 @@ class GraphvizCallGraph(_CallGraphDrawerBase):
176181
Each edge is labeled with the number of times the "caller" (predecessor) bloq calls the
177182
"callee" (successor) bloq.
178183
179-
This class follows the behavior described in https://github.com/quantumlib/Qualtran/issues/791
184+
The constructor of this class assumes you have already generated the call graph as a networkx
185+
graph and constructed any associated data. See the factory method
186+
`GraphvizCallGraph.from_bloq()` to set up a call graph diagram from a bloq with sensible
187+
defaults.
188+
189+
This class uses a bloq's `__str__` string to title the bloq. Arbitrary additional tabular
190+
data can be provided with `bloq_data`.
191+
192+
This graph drawer is the successor to the `GraphvizCounts` existing drawer,
180193
and will replace `GraphvizCounts` when all bloqs have been migrated to use `__str__()`.
181194
182195
Args:
@@ -193,6 +206,133 @@ def __init__(self, g: nx.DiGraph, bloq_data: Optional[Dict['Bloq', Dict[Any, Any
193206

194207
self.bloq_data = bloq_data
195208

209+
@classmethod
210+
def format_qubit_count(cls, val: SymbolicInt) -> Dict[str, str]:
211+
"""Format `QubitCount` cost values as a string.
212+
213+
Args:
214+
val: The qubit count value, which should be an integer
215+
216+
Returns:
217+
A dictionary mapping a string cost name to a string cost value.
218+
"""
219+
return {'Qubits': f'{val}'}
220+
221+
@classmethod
222+
def format_qec_gates_cost(cls, val: 'GateCounts', agg: Optional[str] = None) -> Dict[str, str]:
223+
"""Format `QECGatesCost` cost values as a string.
224+
225+
Args:
226+
val: The qec gate costs value, which should be a `GateCounts` dataclass.
227+
agg: One of 'factored', 'total_t', 't_and_ccz', or 'beverland' to
228+
(optionally) aggregate the gate counts. If not specified, the 'factored'
229+
approach is used where each type of gate is counted individually. See the
230+
methods on `GateCounts` for more information.
231+
232+
Returns:
233+
A dictionary mapping string cost names to string cost values.
234+
"""
235+
labels = {
236+
't': 'Ts',
237+
'n_t': 'Ts',
238+
'toffoli': 'Toffolis',
239+
'n_ccz': 'CCZs',
240+
'cswap': 'CSwaps',
241+
'and_bloq': 'Ands',
242+
'clifford': 'Cliffords',
243+
'rotation': 'Rotations',
244+
'measurement': 'Measurements',
245+
}
246+
counts_dict: Mapping[str, SymbolicInt]
247+
if agg is None or agg == 'factored':
248+
counts_dict = val.asdict()
249+
elif agg == 'total_t':
250+
counts_dict = {'t': val.total_t_count()}
251+
elif agg == 't_and_ccz':
252+
counts_dict = val.total_t_and_ccz_count()
253+
elif agg == 'beverland':
254+
counts_dict = val.total_beverland_count()
255+
else:
256+
raise ValueError(f"Unknown aggregation mode {agg}.")
257+
258+
return {labels.get(gate_k, gate_k): f'{gate_v}' for gate_k, gate_v in counts_dict.items()}
259+
260+
@classmethod
261+
def format_cost_data(
262+
cls,
263+
cost_data: Dict['Bloq', Dict['CostKey', 'CostValT']],
264+
agg_gate_counts: Optional[str] = None,
265+
) -> Dict['Bloq', Dict[str, str]]:
266+
"""Format `cost_data` as human-readable strings.
267+
268+
Args:
269+
cost_data: The cost data, likely returned from a call to `query_costs()`. This
270+
class method will delegate to `format_qubit_count` and `format_qec_gates_cost`
271+
for `QubitCount` and `QECGatesCost` cost keys, respectively.
272+
agg_gate_counts: One of 'factored', 'total_t', 't_and_ccz', or 'beverland' to
273+
(optionally) aggregate the gate counts. If not specified, the 'factored'
274+
approach is used where each type of gate is counted individually. See the
275+
methods on `GateCounts` for more information.
276+
277+
Returns:
278+
For each bloq key, a table of label/value pairs consisting of
279+
human-readable labels and formatted values.
280+
"""
281+
from qualtran.resource_counting import GateCounts, QECGatesCost, QubitCount
282+
283+
disp_data: Dict['Bloq', Dict[str, str]] = {}
284+
for bloq in cost_data.keys():
285+
bloq_disp: Dict[str, str] = {}
286+
for cost_key, cost_val in cost_data[bloq].items():
287+
if isinstance(cost_key, QubitCount):
288+
bloq_disp |= cls.format_qubit_count(cast(SymbolicInt, cost_val))
289+
elif isinstance(cost_key, QECGatesCost):
290+
assert isinstance(cost_val, GateCounts)
291+
bloq_disp |= cls.format_qec_gates_cost(cost_val, agg=agg_gate_counts)
292+
else:
293+
warnings.warn(f"Unknown cost key {cost_key}")
294+
bloq_disp[str(cost_key)] = str(cost_val)
295+
296+
disp_data[bloq] = bloq_disp
297+
return disp_data
298+
299+
@classmethod
300+
def from_bloq(
301+
cls, bloq: Bloq, *, max_depth: Optional[int] = None, agg_gate_counts: Optional[str] = None
302+
) -> 'GraphvizCallGraph':
303+
"""Draw a bloq call graph.
304+
305+
This factory method will generate a call graph from the bloq, query the `QECGatesCost`
306+
and `QubitCount` costs, format the cost data, and merge it with the call graph
307+
to create a call graph diagram with annotated costs.
308+
309+
For additional customization, users can construct the call graph and bloq data themselves
310+
and use the normal constructor, or provide minor display customizations by
311+
overriding the `format_xxx` class methods.
312+
313+
Args:
314+
bloq: The bloq from which we construct the call graph and query the costs.
315+
max_depth: The maximum depth (from the root bloq) of the call graph to draw. Note
316+
that the cost computations will walk the whole call graph, but only the nodes
317+
within this depth will be drawn.
318+
agg_gate_counts: One of 'factored', 'total_t', 't_and_ccz', or 'beverland' to
319+
(optionally) aggregate the gate counts. If not specified, the 'factored'
320+
approach is used where each type of gate is counted individually. See the
321+
methods on `GateCounts` for more information.
322+
323+
Returns:
324+
A `GraphvizCallGraph` diagram-drawer, whose methods can be used to generate
325+
graphviz inputs or SVG diagrams.
326+
"""
327+
from qualtran.resource_counting import QECGatesCost, QubitCount, query_costs
328+
329+
call_graph, _ = bloq.call_graph(max_depth=max_depth)
330+
cost_data: Dict['Bloq', Dict[CostKey, Any]] = query_costs(
331+
bloq, [QubitCount(), QECGatesCost()]
332+
)
333+
formatted_cost_data = cls.format_cost_data(cost_data, agg_gate_counts=agg_gate_counts)
334+
return cls(g=call_graph, bloq_data=formatted_cost_data)
335+
196336
def get_node_title(self, b: Bloq):
197337
return str(b)
198338

qualtran/drawing/bloq_counts_graph_test.py

+46
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import re
1616
from typing import List
1717

18+
import networkx as nx
19+
1820
from qualtran.bloqs.for_testing import TestBloqWithCallGraph
1921
from qualtran.bloqs.mcmt.and_bloq import MultiAnd
2022
from qualtran.drawing import (
@@ -128,3 +130,47 @@ def test_graphviz_call_graph_with_data():
128130
'<tr><td colspan="2"><font point-size="10">TestBloqWithCallGraph</font></td></tr>\n'
129131
'<tr><td>T count</td><td>100*_n0 + 600</td></tr><tr><td>clifford</td><td>0</td></tr><tr><td>rot</td><td>0</td></tr></table></font>>'
130132
)
133+
134+
135+
def test_graphviz_call_graph_from_bloq():
136+
bloq = TestBloqWithCallGraph()
137+
drawer = GraphvizCallGraph.from_bloq(bloq)
138+
139+
node_labels = _get_node_labels_from_pydot_graph(drawer)
140+
for nl in node_labels:
141+
# Spot check one of the nodes
142+
if 'TestBloqWithCallGraph' in nl:
143+
assert nl == (
144+
'<<font point-size="10"><table border="0" cellborder="1" cellspacing="0" cellpadding="5">\n'
145+
'<tr><td colspan="2"><font point-size="10">TestBloqWithCallGraph</font></td></tr>\n'
146+
'<tr><td>Qubits</td><td>3</td></tr>'
147+
'<tr><td>Ts</td><td>100*_n0 + 600</td></tr>'
148+
'</table></font>>'
149+
)
150+
151+
152+
def test_graphviz_call_graph_from_bloq_agg():
153+
bloq = TestBloqWithCallGraph()
154+
drawer = GraphvizCallGraph.from_bloq(bloq, agg_gate_counts='t_and_ccz')
155+
156+
node_labels = _get_node_labels_from_pydot_graph(drawer)
157+
for nl in node_labels:
158+
# Spot check one of the nodes
159+
# Note the additional cell.
160+
if 'TestBloqWithCallGraph' in nl:
161+
assert nl == (
162+
'<<font point-size="10"><table border="0" cellborder="1" cellspacing="0" cellpadding="5">\n'
163+
'<tr><td colspan="2"><font point-size="10">TestBloqWithCallGraph</font></td></tr>\n'
164+
'<tr><td>Qubits</td><td>3</td></tr>'
165+
'<tr><td>Ts</td><td>100*_n0 + 600</td></tr>'
166+
'<tr><td>CCZs</td><td>0</td></tr>'
167+
'</table></font>>'
168+
)
169+
170+
171+
def test_graphviz_call_graph_from_bloq_max_depth():
172+
bloq = TestBloqWithCallGraph()
173+
drawer = GraphvizCallGraph.from_bloq(bloq)
174+
assert len(list(nx.topological_generations(drawer.g))) == 3
175+
drawer2 = GraphvizCallGraph.from_bloq(bloq, max_depth=1)
176+
assert len(list(nx.topological_generations(drawer2.g))) == 2

0 commit comments

Comments
 (0)