
Commit 2aadd3e

Added timing summary and ESS

1 parent 1cfdb25
File tree

2 files changed: +448 −109 lines

examples/samplers/fast_sampling_with_jax_and_numba.ipynb

Lines changed: 262 additions & 69 deletions
Large diffs are not rendered by default.

examples/samplers/fast_sampling_with_jax_and_numba.myst.md

Lines changed: 186 additions & 40 deletions
@@ -62,6 +62,12 @@ BlackJAX offers another JAX-based sampling implementation focused on flexibility
 
 +++
 
+## Installation Requirements
+
+To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.
+
++++
+
 ## Performance Guidelines
 
 Understanding when to use each sampler depends on several key factors including model size, variable types, and computational requirements.
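The hunk above adds an installation note for three optional backends. As a small standalone sketch (not part of the commit; the helper name `available_backends` is invented for illustration), you can check which of them are importable in the current environment using only the standard library:

```python
import importlib.util


def available_backends(packages=("nutpie", "numpyro", "blackjax")):
    """Map each optional sampler package name to True if it is importable."""
    return {pkg: importlib.util.find_spec(pkg) is not None for pkg in packages}


# Missing packages can then be installed with pip or conda,
# e.g. `pip install nutpie` or `conda install nutpie`.
print(available_backends())
```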
@@ -73,28 +79,57 @@ Models containing **discrete variables** must use PyMC's built-in sampler, as it
 **Numba** excels at CPU optimization and provides consistent performance across different model types. It's particularly effective for models with complex mathematical operations that benefit from just-in-time compilation. **JAX** offers superior performance for very large models and provides natural GPU acceleration, making it ideal when computational resources are a limiting factor. The **C** backend serves as a reliable fallback option with broad compatibility but typically offers lower performance than the alternatives.
 
 ```{code-cell} ipython3
-import platform
+import time
+
+from collections import defaultdict
 
 import arviz as az
 import matplotlib.pyplot as plt
 import numpy as np
+import numpyro
+import pandas as pd
 import pymc as pm
 
-if platform.system() == "linux":
-    import multiprocessing
+numpyro.set_host_device_count(4)
 
-    multiprocessing.set_start_method("spawn", force=True)
+%config InlineBackend.figure_format = 'retina'
+az.style.use("arviz-darkgrid")
 
 rng = np.random.default_rng(seed=42)
 print(f"Running on PyMC v{pm.__version__}")
 ```
 
 ```{code-cell} ipython3
-%config InlineBackend.figure_format = 'retina'
-az.style.use("arviz-darkgrid")
-```
+import time
+
+from collections import defaultdict
+
+# Dictionary to store all results
+results = defaultdict(dict)
+
+
+class TimingContext:
+    def __init__(self, name):
+        self.name = name
+
+    def __enter__(self):
+        self.start_wall = time.perf_counter()
+        self.start_cpu = time.process_time()
+        return self
+
+    def __exit__(self, *args):
+        self.end_wall = time.perf_counter()
+        self.end_cpu = time.process_time()
+
+        wall_time = self.end_wall - self.start_wall
+        cpu_time = self.end_cpu - self.start_cpu
 
-We'll demonstrate the performance differences using a Probabilistic Principal Component Analysis (PPCA) model.
+        results[self.name]["wall_time"] = wall_time
+        results[self.name]["cpu_time"] = cpu_time
+
+        print(f"Wall time: {wall_time:.1f} s")
+        print(f"CPU time: {cpu_time:.1f} s")
+```
 
 ```{code-cell} ipython3
 def build_toy_dataset(N, D, K, sigma=1):
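The `TimingContext` helper added in this hunk can be exercised on its own. Here is a self-contained sketch using only the standard library, with a trivial workload standing in for the `pm.sample(...)` calls it wraps in the notebook:

```python
import time
from collections import defaultdict

# Local results store, mirroring the notebook's module-level dict
results = defaultdict(dict)


class TimingContext:
    """Record wall-clock and CPU time for a named block of work."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start_wall = time.perf_counter()
        self.start_cpu = time.process_time()
        return self

    def __exit__(self, *args):
        results[self.name]["wall_time"] = time.perf_counter() - self.start_wall
        results[self.name]["cpu_time"] = time.process_time() - self.start_cpu


with TimingContext("demo"):
    sum(range(100_000))  # stand-in for pm.sample(...)

print(results["demo"])
```

Because `__exit__` writes into a shared dict keyed by name, the notebook can collect timings from several sampler runs and summarize them later.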
@@ -129,10 +164,14 @@ plt.title("Simulated data set")
 ```
 
 ```{code-cell} ipython3
-with pm.Model() as PPCA:
-    w = pm.Normal("w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered())
-    z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
-    x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
+def ppca_model():
+    with pm.Model() as model:
+        w = pm.Normal(
+            "w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered()
+        )
+        z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
+        x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
+    return model
 ```
 
 ## Performance Comparison
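The hunk above replaces a module-level `PPCA` model with a `ppca_model()` factory, so each `with ppca_model():` block builds a fresh model object rather than reusing shared state across runs. A minimal sketch of the pattern, with a dummy `Model` class standing in for `pm.Model` (an assumption for illustration only):

```python
class Model:
    """Dummy stand-in for pm.Model: a context manager that holds variables."""

    def __init__(self):
        self.variables = []

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


def ppca_model():
    # Build and return a fresh model on every call, as in the commit
    with Model() as model:
        model.variables.extend(["w", "z", "x"])
    return model


m1, m2 = ppca_model(), ppca_model()
print(m1 is m2)  # False: each sampler run gets its own model instance
```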
@@ -142,44 +181,154 @@ Now let's compare the performance of different sampling backends on our PPCA mod
 ### 1. PyMC Default Sampler (Python NUTS)
 
 ```{code-cell} ipython3
-%%time
-with PPCA:
-    idata_pymc = pm.sample(progressbar=False)
+n_draws = 2000
+n_tune = 2000
+
+with TimingContext("PyMC Default"):
+    with ppca_model():
+        idata_pymc = pm.sample(draws=n_draws, tune=n_tune, progressbar=False)
+
+ess_pymc = az.ess(idata_pymc)
+min_ess = min([ess_pymc[var].values.min() for var in ess_pymc.data_vars])
+mean_ess = np.mean([ess_pymc[var].values.mean() for var in ess_pymc.data_vars])
+results["PyMC Default"]["min_ess"] = min_ess
+results["PyMC Default"]["mean_ess"] = mean_ess
+print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
 ```
 
-### 2. Nutpie with Numba Backend
+### 2. Nutpie Sampler with Numba Backend
 
 ```{code-cell} ipython3
-%%time
-with PPCA:
-    idata_nutpie_numba = pm.sample(
-        nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "numba"}, progressbar=False
-    )
+with TimingContext("Nutpie Numba"):
+    with ppca_model():
+        idata_nutpie_numba = pm.sample(
+            draws=n_draws,
+            tune=n_tune,
+            nuts_sampler="nutpie",
+            nuts_sampler_kwargs={"backend": "numba"},
+            progressbar=False,
+        )
+
+ess_nutpie_numba = az.ess(idata_nutpie_numba)
+min_ess = min([ess_nutpie_numba[var].values.min() for var in ess_nutpie_numba.data_vars])
+mean_ess = np.mean([ess_nutpie_numba[var].values.mean() for var in ess_nutpie_numba.data_vars])
+results["Nutpie Numba"]["min_ess"] = min_ess
+results["Nutpie Numba"]["mean_ess"] = mean_ess
+print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
 ```
 
-### 3. Nutpie with JAX Backend
+### 3. Nutpie Sampler with JAX Backend
 
 ```{code-cell} ipython3
-%%time
-with PPCA:
-    idata_nutpie_jax = pm.sample(
-        nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax"}, progressbar=False
-    )
+with TimingContext("Nutpie JAX"):
+    with ppca_model():
+        idata_nutpie_jax = pm.sample(
+            draws=n_draws,
+            tune=n_tune,
+            nuts_sampler="nutpie",
+            nuts_sampler_kwargs={"backend": "jax"},
+            progressbar=False,
+        )
+
+ess_nutpie_jax = az.ess(idata_nutpie_jax)
+min_ess = min([ess_nutpie_jax[var].values.min() for var in ess_nutpie_jax.data_vars])
+mean_ess = np.mean([ess_nutpie_jax[var].values.mean() for var in ess_nutpie_jax.data_vars])
+results["Nutpie JAX"]["min_ess"] = min_ess
+results["Nutpie JAX"]["mean_ess"] = mean_ess
+print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
 ```
 
 ### 4. NumPyro Sampler
 
 ```{code-cell} ipython3
-%%time
-with PPCA:
-    idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
+with TimingContext("NumPyro"):
+    with ppca_model():
+        idata_numpyro = pm.sample(
+            draws=n_draws, tune=n_tune, nuts_sampler="numpyro", progressbar=False
+        )
+
+ess_numpyro = az.ess(idata_numpyro)
+min_ess = min([ess_numpyro[var].values.min() for var in ess_numpyro.data_vars])
+mean_ess = np.mean([ess_numpyro[var].values.mean() for var in ess_numpyro.data_vars])
+results["NumPyro"]["min_ess"] = min_ess
+results["NumPyro"]["mean_ess"] = mean_ess
+print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
 ```
 
-## Installation Requirements
+```{code-cell} ipython3
+timing_data = []
+for backend_name, metrics in results.items():
+    wall_time = metrics.get("wall_time", 0)
+    cpu_time = metrics.get("cpu_time", 0)
+    min_ess = metrics.get("min_ess", 0)
+    mean_ess = metrics.get("mean_ess", 0)
+    ess_per_sec = mean_ess / wall_time if wall_time > 0 else 0
+
+    timing_data.append(
+        {
+            "Sampling Backend": backend_name,
+            "Wall Time (s)": f"{wall_time:.1f}",
+            "CPU Time (s)": f"{cpu_time:.1f}",
+            "Min ESS": f"{min_ess:.0f}",
+            "Mean ESS": f"{mean_ess:.0f}",
+            "ESS/sec": f"{ess_per_sec:.0f}",
+            "Parallel Efficiency": f"{cpu_time/wall_time:.2f}" if wall_time > 0 else "N/A",
+        }
+    )
 
-To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.
+timing_df = pd.DataFrame(timing_data)
+timing_df = timing_df.sort_values("ESS/sec", ascending=False)
 
-+++
+print("\nPerformance Summary Table:")
+print("=" * 100)
+print(timing_df.to_string(index=False))
+print("=" * 100)
+
+best_backend = timing_df.iloc[0]["Sampling Backend"]
+best_ess_per_sec = timing_df.iloc[0]["ESS/sec"]
+print(f"\nMost efficient backend: {best_backend} with {best_ess_per_sec} ESS/second")
+```
+
+```{code-cell} ipython3
+fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
+
+backends = timing_df["Sampling Backend"].tolist()
+wall_times = [float(val) for val in timing_df["Wall Time (s)"].tolist()]
+mean_ess_values = [float(val) for val in timing_df["Mean ESS"].tolist()]
+ess_per_sec_values = [float(val) for val in timing_df["ESS/sec"].tolist()]
+
+ax1.bar(backends, wall_times, color="skyblue")
+ax1.set_ylabel("Wall Time (seconds)")
+ax1.set_title("Sampling Time")
+ax1.tick_params(axis="x", rotation=45)
+
+ax2.bar(backends, mean_ess_values, color="lightgreen")
+ax2.set_ylabel("Mean ESS")
+ax2.set_title("Effective Sample Size")
+ax2.tick_params(axis="x", rotation=45)
+
+ax3.bar(backends, ess_per_sec_values, color="coral")
+ax3.set_ylabel("ESS per Second")
+ax3.set_title("Sampling Efficiency")
+ax3.tick_params(axis="x", rotation=45)
+
+ax4.scatter(wall_times, mean_ess_values, s=200, alpha=0.6)
+for i, backend in enumerate(backends):
+    ax4.annotate(
+        backend,
+        (wall_times[i], mean_ess_values[i]),
+        xytext=(5, 5),
+        textcoords="offset points",
+        fontsize=9,
+    )
+ax4.set_xlabel("Wall Time (seconds)")
+ax4.set_ylabel("Mean ESS")
+ax4.set_title("Time vs. Effective Sample Size")
+ax4.grid(True, alpha=0.3)
+
+plt.tight_layout()
+plt.show()
+```
 
 ## Special Cases and Advanced Usage
 
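The summary-table cell added in this hunk needs real sampling runs (and pandas). The core ESS-per-second and parallel-efficiency arithmetic can be sketched standalone; the numbers below are invented placeholders, not measurements from the notebook:

```python
# Hand-filled stand-in for the notebook's `results` dict (values are made up)
results = {
    "PyMC Default": {"wall_time": 60.0, "cpu_time": 210.0, "mean_ess": 1800.0},
    "Nutpie Numba": {"wall_time": 20.0, "cpu_time": 70.0, "mean_ess": 2100.0},
}

rows = []
for backend, m in results.items():
    wall = m.get("wall_time", 0)
    rows.append(
        {
            "backend": backend,
            # Effective samples generated per wall-clock second
            "ess_per_sec": m["mean_ess"] / wall if wall > 0 else 0.0,
            # CPU time / wall time: approaches the chain count when
            # chains run fully in parallel
            "parallel_efficiency": m["cpu_time"] / wall if wall > 0 else 0.0,
        }
    )

rows.sort(key=lambda r: r["ess_per_sec"], reverse=True)
print(rows[0]["backend"])  # most efficient backend under these numbers
```

Ranking by ESS/sec rather than wall time alone is the point of the change: a slower sampler that yields many more effective samples can still win.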
@@ -190,13 +339,13 @@ In certain scenarios, you may need to use PyMC's Python-based sampler while stil
 The following examples demonstrate how to use PyMC's built-in sampler with different compilation targets. The `fast_run` mode uses optimized C compilation, which provides good performance while maintaining full compatibility. The `numba` mode offers the only way to access Numba's just-in-time compilation benefits when using PyMC's sampler. The `jax` mode enables JAX compilation, though for JAX workflows, Nutpie or NumPyro typically provide better performance.
 
 ```{code-cell} ipython3
-with PPCA:
+with ppca_model():
     idata_c = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "fast_run"}, progressbar=False)
 
-# with PPCA:
+# with ppca_model():
 #     idata_pymc_numba = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "numba"}, progressbar=False)
 
-# with PPCA:
+# with ppca_model():
 #     idata_pymc_jax = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "jax"}, progressbar=False)
 ```
 

@@ -221,12 +370,9 @@ with pm.Model() as discrete_model:
 ## Authors
 
 - Originally authored by Thomas Wiecki in July 2023
-- Substantially updated and expanded by Chris Fonnesbeck in May 2025
+- Updated and expanded by Chris Fonnesbeck in May 2025
 
 ```{code-cell} ipython3
 %load_ext watermark
 %watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie
 ```
-
-:::{include} ../page_footer.md
-:::

0 commit comments