Skip to content

Commit 1cfdb25

Browse files
committed
Updated watermark, improved text
1 parent 85e1145 commit 1cfdb25

File tree

2 files changed

+73
-168
lines changed

2 files changed

+73
-168
lines changed

examples/samplers/fast_sampling_with_jax_and_numba.ipynb

Lines changed: 51 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"pm.sample()\n",
2828
"```\n",
2929
"\n",
30-
"The default PyMC sampler uses a Python-based NUTS implementation that provides maximum compatibility with all PyMC features. This sampler is always used when working with models that contain discrete variables, as it's the only option that supports non-gradient based samplers like Slice and Metropolis. While this sampler can compile the underlying model to different backends (C, Numba, or JAX) using the `compile_kwargs` parameter, it still maintains Python overhead that can limit performance for large models.\n",
30+
"The default PyMC sampler uses a Python-based NUTS implementation that provides maximum compatibility with all PyMC features. This sampler is required when working with models that contain discrete variables, as it's the only option that supports non-gradient based samplers like Slice and Metropolis. While this sampler can compile the underlying model to different backends (C, Numba, or JAX) using PyTensor's compilation system via the `compile_kwargs` parameter, it maintains Python overhead that can limit performance for large models.\n",
3131
"\n",
3232
"### Nutpie Sampler\n",
3333
"\n",
@@ -37,7 +37,7 @@
3737
"pm.sample(nuts_sampler=\"nutpie\", nuts_sampler_kwargs={\"backend\": \"jax\", \"gradient_backend\": \"pytensor\"})\n",
3838
"```\n",
3939
"\n",
40-
"Nutpie is on the cutting-edge of PyMC sampling performance. Written in Rust, it eliminates most Python overhead and provides exceptional performance for continuous models. The Numba backend typically offers the highest performance for most use cases, while the JAX backend excels with very large models and provides GPU acceleration capabilities. Nutpie is particularly well-suited for production workflows where sampling speed is critical.\n",
40+
"Nutpie is PyMC's cutting-edge performance sampler. Written in Rust, it eliminates Python overhead and provides exceptional performance for continuous models. The Numba backend typically offers the highest performance for most use cases, while the JAX backend excels with very large models and provides GPU acceleration capabilities. Nutpie is particularly well-suited for production workflows where sampling speed is critical.\n",
4141
"\n",
4242
"### NumPyro Sampler\n",
4343
"\n",
@@ -47,33 +47,30 @@
4747
"pm.sample(nuts_sampler=\"numpyro\", nuts_sampler_kwargs={\"chain_method\": \"vectorized\"})\n",
4848
"```\n",
4949
"\n",
50-
"NumPyro provides a mature JAX-based sampling implementation that integrates seamlessly with the broader JAX ecosystem. This sampler typically performs best with small to medium-sized models and offers excellent GPU support. NumPyro benefits from years of development within the JAX community and provides reliable performance characteristics, though it may have compilation overhead for very large models.\n",
50+
"NumPyro provides a mature JAX-based sampling implementation that integrates seamlessly with the broader JAX ecosystem. This sampler benefits from years of development within the JAX community and provides reliable performance characteristics, with excellent GPU support for accelerated computation.\n",
5151
"\n",
5252
"### BlackJAX Sampler\n",
5353
"\n",
5454
"```python\n",
5555
"pm.sample(nuts_sampler=\"blackjax\")\n",
5656
"```\n",
5757
"\n",
58-
"BlackJAX offers another JAX-based sampling implementation focused on flexibility and research applications. While it provides similar capabilities to NumPyro, it's less commonly used in production environments. BlackJAX can be valuable for experimental workflows or when specific JAX-based features are required that aren't available in other samplers.\n",
59-
"\n",
58+
"BlackJAX offers another JAX-based sampling implementation focused on flexibility and research applications. While it provides similar capabilities to NumPyro, it's less commonly used in production environments. BlackJAX can be valuable for experimental workflows or when specific JAX-based features are required."
59+
]
60+
},
61+
{
62+
"cell_type": "markdown",
63+
"metadata": {},
64+
"source": [
6065
"## Performance Guidelines\n",
6166
"\n",
6267
"Understanding when to use each sampler depends on several key factors including model size, variable types, and computational requirements.\n",
6368
"\n",
64-
"**Model Size Considerations**\n",
65-
"\n",
66-
"For small models, NumPyro typically provides the best balance of performance and reliability. The compilation overhead is minimal, and the mature JAX implementation handles these models efficiently. Larger models often benefit from Nutpie with the Numba backend, which provides excellent performance without the memory overhead sometimes associated with JAX compilation.\n",
67-
"\n",
68-
"Large models generally perform best with either Nutpie's JAX backend or Nutpie's Numba backend. The choice between these depends on whether GPU acceleration is needed and how the model's computational graph interacts with each backend's optimization strategies.\n",
69+
"For **small models**, NumPyro typically provides the best balance of performance and reliability. The compilation overhead is minimal, and its mature JAX implementation handles these models efficiently. **Large models** generally perform best with Nutpie's Numba backend for consistent CPU performance or Nutpie's JAX backend when GPU acceleration is needed or memory efficiency is critical.\n",
6970
"\n",
70-
"**Variable Type Requirements**\n",
71+
"Models containing **discrete variables** must use PyMC's built-in sampler, as it's the only implementation that supports compatible (*i.e.*, non-gradient based) sampling algorithms. For purely continuous models, all sampling backends are available, making performance the primary consideration.\n",
7172
"\n",
72-
"Models containing discrete variables have no choice but to use PyMC's built-in sampler, as it's the only implementation that supports the necessary Slice and Metropolis sampling algorithms. For purely continuous models, all sampling backends are available, making performance the primary consideration.\n",
73-
"\n",
74-
"**Computational Backend Selection**\n",
75-
"\n",
76-
"Numba excels at CPU optimization and provides consistent performance across different model types. It's particularly effective for models with complex mathematical operations that benefit from just-in-time compilation. JAX offers superior performance for very large models and provides natural GPU acceleration, making it ideal when computational resources are a limiting factor. The traditional C backend serves as a reliable fallback option with broad compatibility but typically offers lower performance than the alternatives."
73+
"**Numba** excels at CPU optimization and provides consistent performance across different model types. It's particularly effective for models with complex mathematical operations that benefit from just-in-time compilation. **JAX** offers superior performance for very large models and provides natural GPU acceleration, making it ideal when computational resources are a limiting factor. The **C** backend serves as a reliable fallback option with broad compatibility but typically offers lower performance than the alternatives."
7774
]
7875
},
7976
{
@@ -90,11 +87,18 @@
9087
}
9188
],
9289
"source": [
90+
"import platform\n",
91+
"\n",
9392
"import arviz as az\n",
9493
"import matplotlib.pyplot as plt\n",
9594
"import numpy as np\n",
9695
"import pymc as pm\n",
9796
"\n",
97+
"if platform.system() == \"linux\":\n",
98+
" import multiprocessing\n",
99+
"\n",
100+
" multiprocessing.set_start_method(\"spawn\", force=True)\n",
101+
"\n",
98102
"rng = np.random.default_rng(seed=42)\n",
99103
"print(f\"Running on PyMC v{pm.__version__}\")"
100104
]
@@ -228,55 +232,23 @@
228232
"text": [
229233
"Initializing NUTS using jitter+adapt_diag...\n",
230234
"Multiprocess sampling (4 chains in 4 jobs)\n",
231-
"NUTS: [w, z]\n"
232-
]
233-
},
234-
{
235-
"data": {
236-
"application/vnd.jupyter.widget-view+json": {
237-
"model_id": "c34297902e6f4d118f552495bdace798",
238-
"version_major": 2,
239-
"version_minor": 0
240-
},
241-
"text/plain": [
242-
"Output()"
243-
]
244-
},
245-
"metadata": {},
246-
"output_type": "display_data"
247-
},
248-
{
249-
"data": {
250-
"text/html": [
251-
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
252-
],
253-
"text/plain": []
254-
},
255-
"metadata": {},
256-
"output_type": "display_data"
257-
},
258-
{
259-
"name": "stderr",
260-
"output_type": "stream",
261-
"text": [
262-
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.\n",
263-
"The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n",
264-
"The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n"
235+
"NUTS: [w, z]\n",
236+
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 5 seconds.\n"
265237
]
266238
},
267239
{
268240
"name": "stdout",
269241
"output_type": "stream",
270242
"text": [
271-
"CPU times: user 16 s, sys: 417 ms, total: 16.4 s\n",
272-
"Wall time: 22.4 s\n"
243+
"CPU times: user 7.8 s, sys: 375 ms, total: 8.17 s\n",
244+
"Wall time: 13.8 s\n"
273245
]
274246
}
275247
],
276248
"source": [
277249
"%%time\n",
278250
"with PPCA:\n",
279-
" idata_pymc = pm.sample()"
251+
" idata_pymc = pm.sample(progressbar=False)"
280252
]
281253
},
282254
{
@@ -295,8 +267,8 @@
295267
"name": "stdout",
296268
"output_type": "stream",
297269
"text": [
298-
"CPU times: user 45.6 s, sys: 813 ms, total: 46.5 s\n",
299-
"Wall time: 35.6 s\n"
270+
"CPU times: user 42.3 s, sys: 798 ms, total: 43.1 s\n",
271+
"Wall time: 32.2 s\n"
300272
]
301273
}
302274
],
@@ -324,8 +296,8 @@
324296
"name": "stdout",
325297
"output_type": "stream",
326298
"text": [
327-
"CPU times: user 33.8 s, sys: 9.67 s, total: 43.5 s\n",
328-
"Wall time: 16.9 s\n"
299+
"CPU times: user 32.7 s, sys: 11.7 s, total: 44.4 s\n",
300+
"Wall time: 17.1 s\n"
329301
]
330302
}
331303
],
@@ -363,8 +335,8 @@
363335
"name": "stdout",
364336
"output_type": "stream",
365337
"text": [
366-
"CPU times: user 53.8 s, sys: 2.47 s, total: 56.3 s\n",
367-
"Wall time: 44.3 s\n"
338+
"CPU times: user 53.1 s, sys: 2.65 s, total: 55.8 s\n",
339+
"Wall time: 43.4 s\n"
368340
]
369341
}
370342
],
@@ -409,60 +381,24 @@
409381
"Multiprocess sampling (4 chains in 4 jobs)\n",
410382
"NUTS: [w, z]\n",
411383
"/var/home/fonnesbeck/repos/pymc-examples/.pixi/envs/default/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
412-
" self.pid = os.fork()\n"
413-
]
414-
},
415-
{
416-
"data": {
417-
"application/vnd.jupyter.widget-view+json": {
418-
"model_id": "715185d8daef43cdaed775149ca32369",
419-
"version_major": 2,
420-
"version_minor": 0
421-
},
422-
"text/plain": [
423-
"Output()"
424-
]
425-
},
426-
"metadata": {},
427-
"output_type": "display_data"
428-
},
429-
{
430-
"name": "stderr",
431-
"output_type": "stream",
432-
"text": [
384+
" self.pid = os.fork()\n",
433385
"/var/home/fonnesbeck/repos/pymc-examples/.pixi/envs/default/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
434-
" self.pid = os.fork()\n"
435-
]
436-
},
437-
{
438-
"data": {
439-
"text/html": [
440-
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
441-
],
442-
"text/plain": []
443-
},
444-
"metadata": {},
445-
"output_type": "display_data"
446-
},
447-
{
448-
"name": "stderr",
449-
"output_type": "stream",
450-
"text": [
451-
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 5 seconds.\n",
386+
" self.pid = os.fork()\n",
387+
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.\n",
452388
"The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n",
453389
"The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n"
454390
]
455391
}
456392
],
457393
"source": [
458394
"with PPCA:\n",
459-
" idata_c = pm.sample(nuts_sampler=\"pymc\", compile_kwargs={\"mode\": \"fast_run\"})\n",
395+
" idata_c = pm.sample(nuts_sampler=\"pymc\", compile_kwargs={\"mode\": \"fast_run\"}, progressbar=False)\n",
460396
"\n",
461397
"# with PPCA:\n",
462-
"# idata_pymc_numba = pm.sample(nuts_sampler=\"pymc\", compile_kwargs={\"mode\": \"numba\"})\n",
398+
"# idata_pymc_numba = pm.sample(nuts_sampler=\"pymc\", compile_kwargs={\"mode\": \"numba\"}, progressbar=False)\n",
463399
"\n",
464400
"# with PPCA:\n",
465-
"# idata_pymc_jax = pm.sample(nuts_sampler=\"pymc\", compile_kwargs={\"mode\": \"jax\"})"
401+
"# idata_pymc_jax = pm.sample(nuts_sampler=\"pymc\", compile_kwargs={\"mode\": \"jax\"}, progressbar=False)"
466402
]
467403
},
468404
{
@@ -495,47 +431,11 @@
495431
">BinaryGibbsMetropolis: [cluster]\n",
496432
">NUTS: [mu, sigma]\n",
497433
"/var/home/fonnesbeck/repos/pymc-examples/.pixi/envs/default/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
498-
" self.pid = os.fork()\n"
499-
]
500-
},
501-
{
502-
"data": {
503-
"application/vnd.jupyter.widget-view+json": {
504-
"model_id": "0d08b347ee5d43dca776a9844f714ae6",
505-
"version_major": 2,
506-
"version_minor": 0
507-
},
508-
"text/plain": [
509-
"Output()"
510-
]
511-
},
512-
"metadata": {},
513-
"output_type": "display_data"
514-
},
515-
{
516-
"name": "stderr",
517-
"output_type": "stream",
518-
"text": [
434+
" self.pid = os.fork()\n",
519435
"/var/home/fonnesbeck/repos/pymc-examples/.pixi/envs/default/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
520-
" self.pid = os.fork()\n"
521-
]
522-
},
523-
{
524-
"data": {
525-
"text/html": [
526-
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
527-
],
528-
"text/plain": []
529-
},
530-
"metadata": {},
531-
"output_type": "display_data"
532-
},
533-
{
534-
"name": "stderr",
535-
"output_type": "stream",
536-
"text": [
436+
" self.pid = os.fork()\n",
537437
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 5 seconds.\n",
538-
"There were 19 divergences after tuning. Increase `target_accept` or reparameterize.\n",
438+
"There were 9 divergences after tuning. Increase `target_accept` or reparameterize.\n",
539439
"The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n",
540440
"The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n"
541441
]
@@ -548,7 +448,7 @@
548448
" sigma = pm.HalfNormal(\"sigma\", 1, shape=2)\n",
549449
" obs = pm.Normal(\"obs\", mu=mu[cluster], sigma=sigma[cluster], observed=rng.normal(0, 1, 100))\n",
550450
"\n",
551-
" trace_discrete = pm.sample()"
451+
" trace_discrete = pm.sample(progressbar=False)"
552452
]
553453
},
554454
{
@@ -570,20 +470,24 @@
570470
"name": "stdout",
571471
"output_type": "stream",
572472
"text": [
573-
"Last updated: Sat May 24 2025\n",
473+
"Last updated: Mon May 26 2025\n",
574474
"\n",
575475
"Python implementation: CPython\n",
576476
"Python version : 3.12.10\n",
577477
"IPython version : 9.2.0\n",
578478
"\n",
579479
"pytensor: 2.30.3\n",
580-
"aeppl : not installed\n",
581-
"xarray : 2025.4.0\n",
480+
"arviz : 0.21.0\n",
481+
"pymc : 5.22.0\n",
482+
"numpyro : 0.18.0\n",
483+
"blackjax: 0.0.0\n",
484+
"nutpie : 0.14.3\n",
582485
"\n",
583-
"numpy : 2.2.6\n",
584486
"pymc : 5.22.0\n",
585-
"matplotlib: 3.10.3\n",
586487
"arviz : 0.21.0\n",
488+
"platform : 1.0.8\n",
489+
"numpy : 2.2.6\n",
490+
"matplotlib: 3.10.3\n",
587491
"\n",
588492
"Watermark: 2.5.0\n",
589493
"\n"
@@ -592,7 +496,7 @@
592496
],
593497
"source": [
594498
"%load_ext watermark\n",
595-
"%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray"
499+
"%watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie"
596500
]
597501
},
598502
{

0 commit comments

Comments
 (0)