@@ -62,6 +62,12 @@ BlackJAX offers another JAX-based sampling implementation focused on flexibility

+++

+ ## Installation Requirements
+
+ To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.
+
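+ The corresponding install commands, sketched here as notebook magics (package names as above; channels and version pins depend on your environment):
+
+ ```
+ %pip install nutpie numpyro blackjax
+ # or, with conda/mamba:
+ # %conda install -c conda-forge nutpie numpyro blackjax
+ ```
+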
+ +++
+

## Performance Guidelines

Understanding when to use each sampler depends on several key factors including model size, variable types, and computational requirements.
@@ -73,28 +79,57 @@ Models containing **discrete variables** must use PyMC's built-in sampler, as it
**Numba** excels at CPU optimization and provides consistent performance across different model types. It's particularly effective for models with complex mathematical operations that benefit from just-in-time compilation. **JAX** offers superior performance for very large models and provides natural GPU acceleration, making it ideal when computational resources are a limiting factor. The **C** backend serves as a reliable fallback option with broad compatibility but typically offers lower performance than the alternatives.

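For orientation, the backend is chosen through the `nuts_sampler` argument of `pm.sample`, with `nuts_sampler_kwargs` selecting Nutpie's compilation backend. A minimal sketch on a throwaway model (the benchmark cells below use the same calls):

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x")
    # Nutpie's Rust NUTS, compiled through Numba; swap "numba" for "jax" to use JAX
    idata = pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "numba"})
    # JAX-based alternatives:
    # idata = pm.sample(nuts_sampler="numpyro")
    # idata = pm.sample(nuts_sampler="blackjax")
```
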
```{code-cell} ipython3
- import platform
+ import time
+
+ from collections import defaultdict

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
+ import numpyro
+ import pandas as pd
import pymc as pm

- if platform.system() == "linux":
-     import multiprocessing
+ # Make 4 CPU devices visible to JAX so NumPyro can run 4 chains in parallel
+ numpyro.set_host_device_count(4)

-     multiprocessing.set_start_method("spawn", force=True)
+ %config InlineBackend.figure_format = 'retina'
+ az.style.use("arviz-darkgrid")

rng = np.random.default_rng(seed=42)
print(f"Running on PyMC v{pm.__version__}")
```

```{code-cell} ipython3
- %config InlineBackend.figure_format = 'retina'
- az.style.use("arviz-darkgrid")
- ```
+ # Dictionary to store timing and ESS results for each sampling backend
+ results = defaultdict(dict)
+
+
+ # Context manager that records wall-clock and CPU time for the enclosed block
+ class TimingContext:
+     def __init__(self, name):
+         self.name = name
+
+     def __enter__(self):
+         self.start_wall = time.perf_counter()
+         self.start_cpu = time.process_time()
+         return self
+
+     def __exit__(self, *args):
+         self.end_wall = time.perf_counter()
+         self.end_cpu = time.process_time()
+
+         wall_time = self.end_wall - self.start_wall
+         cpu_time = self.end_cpu - self.start_cpu
+
+         results[self.name]["wall_time"] = wall_time
+         results[self.name]["cpu_time"] = cpu_time
+
+         print(f"Wall time: {wall_time:.1f} s")
+         print(f"CPU time: {cpu_time:.1f} s")
+ ```
- We'll demonstrate the performance differences using a Probabilistic Principal Component Analysis (PPCA) model.

```{code-cell} ipython3
def build_toy_dataset(N, D, K, sigma=1):
@@ -129,10 +164,14 @@ plt.title("Simulated data set")
```

```{code-cell} ipython3
- with pm.Model() as PPCA:
-     w = pm.Normal("w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered())
-     z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
-     x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
+ def ppca_model():
+     with pm.Model() as model:
+         w = pm.Normal(
+             "w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered()
+         )
+         z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
+         x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
+     return model
```

## Performance Comparison
@@ -142,44 +181,154 @@ Now let's compare the performance of different sampling backends on our PPCA mod
### 1. PyMC Default Sampler (Python NUTS)

```{code-cell} ipython3
- %%time
- with PPCA:
-     idata_pymc = pm.sample(progressbar=False)
+ n_draws = 2000
+ n_tune = 2000
+
+ with TimingContext("PyMC Default"):
+     with ppca_model():
+         idata_pymc = pm.sample(draws=n_draws, tune=n_tune, progressbar=False)
+
+ ess_pymc = az.ess(idata_pymc)
+ min_ess = min([ess_pymc[var].values.min() for var in ess_pymc.data_vars])
+ mean_ess = np.mean([ess_pymc[var].values.mean() for var in ess_pymc.data_vars])
+ results["PyMC Default"]["min_ess"] = min_ess
+ results["PyMC Default"]["mean_ess"] = mean_ess
+ print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
```

- ### 2. Nutpie with Numba Backend
+ ### 2. Nutpie Sampler with Numba Backend

```{code-cell} ipython3
- %%time
- with PPCA:
-     idata_nutpie_numba = pm.sample(
-         nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "numba"}, progressbar=False
-     )
+ with TimingContext("Nutpie Numba"):
+     with ppca_model():
+         idata_nutpie_numba = pm.sample(
+             draws=n_draws,
+             tune=n_tune,
+             nuts_sampler="nutpie",
+             nuts_sampler_kwargs={"backend": "numba"},
+             progressbar=False,
+         )
+
+ ess_nutpie_numba = az.ess(idata_nutpie_numba)
+ min_ess = min([ess_nutpie_numba[var].values.min() for var in ess_nutpie_numba.data_vars])
+ mean_ess = np.mean([ess_nutpie_numba[var].values.mean() for var in ess_nutpie_numba.data_vars])
+ results["Nutpie Numba"]["min_ess"] = min_ess
+ results["Nutpie Numba"]["mean_ess"] = mean_ess
+ print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
```

- ### 3. Nutpie with JAX Backend
+ ### 3. Nutpie Sampler with JAX Backend

```{code-cell} ipython3
- %%time
- with PPCA:
-     idata_nutpie_jax = pm.sample(
-         nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax"}, progressbar=False
-     )
+ with TimingContext("Nutpie JAX"):
+     with ppca_model():
+         idata_nutpie_jax = pm.sample(
+             draws=n_draws,
+             tune=n_tune,
+             nuts_sampler="nutpie",
+             nuts_sampler_kwargs={"backend": "jax"},
+             progressbar=False,
+         )
+
+ ess_nutpie_jax = az.ess(idata_nutpie_jax)
+ min_ess = min([ess_nutpie_jax[var].values.min() for var in ess_nutpie_jax.data_vars])
+ mean_ess = np.mean([ess_nutpie_jax[var].values.mean() for var in ess_nutpie_jax.data_vars])
+ results["Nutpie JAX"]["min_ess"] = min_ess
+ results["Nutpie JAX"]["mean_ess"] = mean_ess
+ print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
```

### 4. NumPyro Sampler

```{code-cell} ipython3
- %%time
- with PPCA:
-     idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
+ with TimingContext("NumPyro"):
+     with ppca_model():
+         idata_numpyro = pm.sample(
+             draws=n_draws, tune=n_tune, nuts_sampler="numpyro", progressbar=False
+         )
+
+ ess_numpyro = az.ess(idata_numpyro)
+ min_ess = min([ess_numpyro[var].values.min() for var in ess_numpyro.data_vars])
+ mean_ess = np.mean([ess_numpyro[var].values.mean() for var in ess_numpyro.data_vars])
+ results["NumPyro"]["min_ess"] = min_ess
+ results["NumPyro"]["mean_ess"] = mean_ess
+ print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
```

- ## Installation Requirements
- To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.
- +++
+ ```{code-cell} ipython3
+ timing_data = []
+ for backend_name, metrics in results.items():
+     wall_time = metrics.get("wall_time", 0)
+     cpu_time = metrics.get("cpu_time", 0)
+     min_ess = metrics.get("min_ess", 0)
+     mean_ess = metrics.get("mean_ess", 0)
+     ess_per_sec = mean_ess / wall_time if wall_time > 0 else 0
+
+     timing_data.append(
+         {
+             "Sampling Backend": backend_name,
+             "Wall Time (s)": f"{wall_time:.1f}",
+             "CPU Time (s)": f"{cpu_time:.1f}",
+             "Min ESS": f"{min_ess:.0f}",
+             "Mean ESS": f"{mean_ess:.0f}",
+             "ESS/sec": f"{ess_per_sec:.0f}",
+             "Parallel Efficiency": f"{cpu_time/wall_time:.2f}" if wall_time > 0 else "N/A",
+         }
+     )
+
+ timing_df = pd.DataFrame(timing_data)
+ # "ESS/sec" holds formatted strings, so sort on its numeric value rather than lexicographically
+ timing_df = timing_df.sort_values("ESS/sec", key=lambda s: s.astype(float), ascending=False)
+
+ print("\nPerformance Summary Table:")
+ print("=" * 100)
+ print(timing_df.to_string(index=False))
+ print("=" * 100)
+
+ best_backend = timing_df.iloc[0]["Sampling Backend"]
+ best_ess_per_sec = timing_df.iloc[0]["ESS/sec"]
+ print(f"\nMost efficient backend: {best_backend} with {best_ess_per_sec} ESS/second")
+ ```
+
+ ```{code-cell} ipython3
+ fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
+
+ backends = timing_df["Sampling Backend"].tolist()
+ wall_times = [float(val) for val in timing_df["Wall Time (s)"].tolist()]
+ mean_ess_values = [float(val) for val in timing_df["Mean ESS"].tolist()]
+ ess_per_sec_values = [float(val) for val in timing_df["ESS/sec"].tolist()]
+
+ ax1.bar(backends, wall_times, color="skyblue")
+ ax1.set_ylabel("Wall Time (seconds)")
+ ax1.set_title("Sampling Time")
+ ax1.tick_params(axis="x", rotation=45)
+
+ ax2.bar(backends, mean_ess_values, color="lightgreen")
+ ax2.set_ylabel("Mean ESS")
+ ax2.set_title("Effective Sample Size")
+ ax2.tick_params(axis="x", rotation=45)
+
+ ax3.bar(backends, ess_per_sec_values, color="coral")
+ ax3.set_ylabel("ESS per Second")
+ ax3.set_title("Sampling Efficiency")
+ ax3.tick_params(axis="x", rotation=45)
+
+ ax4.scatter(wall_times, mean_ess_values, s=200, alpha=0.6)
+ for i, backend in enumerate(backends):
+     ax4.annotate(
+         backend,
+         (wall_times[i], mean_ess_values[i]),
+         xytext=(5, 5),
+         textcoords="offset points",
+         fontsize=9,
+     )
+ ax4.set_xlabel("Wall Time (seconds)")
+ ax4.set_ylabel("Mean ESS")
+ ax4.set_title("Time vs. Effective Sample Size")
+ ax4.grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.show()
+ ```

## Special Cases and Advanced Usage
@@ -190,13 +339,13 @@ In certain scenarios, you may need to use PyMC's Python-based sampler while stil
The following examples demonstrate how to use PyMC's built-in sampler with different compilation targets. The `fast_run` mode uses optimized C compilation, which provides good performance while maintaining full compatibility. The `numba` mode offers the only way to access Numba's just-in-time compilation benefits when using PyMC's sampler. The `jax` mode enables JAX compilation, though for JAX workflows, Nutpie or NumPyro typically provide better performance.

```{code-cell} ipython3
- with PPCA:
+ with ppca_model():
    idata_c = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "fast_run"}, progressbar=False)

- # with PPCA:
+ # with ppca_model():
# idata_pymc_numba = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "numba"}, progressbar=False)

- # with PPCA:
+ # with ppca_model():
# idata_pymc_jax = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "jax"}, progressbar=False)
```

@@ -221,12 +370,9 @@ with pm.Model() as discrete_model:
## Authors

- Originally authored by Thomas Wiecki in July 2023
- - Substantially updated and expanded by Chris Fonnesbeck in May 2025
+ - Updated and expanded by Chris Fonnesbeck in May 2025

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie
```
-
- :::{include} ../page_footer.md
- :::