Skip to content

Commit 485c26e

Browse files
authored
Update of GLM model selection notebook to v5 (#761)
* Update of GLM model selection notebook to v5 * Updated authorship * Remove stray blank cell
1 parent 9857594 commit 485c26e

File tree

4 files changed

+686
-810
lines changed

4 files changed

+686
-810
lines changed

examples/generalized_linear_models/GLM-model-selection.ipynb

Lines changed: 627 additions & 792 deletions
Large diffs are not rendered by default.

examples/generalized_linear_models/GLM-model-selection.myst.md

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ import bambi as bmb
2525
import matplotlib.pyplot as plt
2626
import numpy as np
2727
import pandas as pd
28-
import pymc3 as pm
28+
import pymc as pm
2929
import seaborn as sns
3030
import xarray as xr
3131
3232
from ipywidgets import fixed, interactive
3333
34-
print(f"Running on PyMC3 v{pm.__version__}")
34+
print(f"Running on PyMC v{pm.__version__}")
3535
```
3636

3737
```{code-cell} ipython3
@@ -44,7 +44,7 @@ plt.rcParams["figure.constrained_layout.use"] = False
4444
```
4545

4646
## Introduction
47-
A fairly minimal reproducible example of Model Selection using WAIC, and LOO as currently implemented in PyMC3.
47+
A fairly minimal reproducible example of Model Selection using WAIC, and LOO as currently implemented in PyMC.
4848

4949
This example creates two toy datasets under linear and quadratic models, and then tests the fit of a range of polynomial linear models upon those datasets by using Widely Applicable Information Criterion (WAIC), and leave-one-out (LOO) cross-validation using Pareto-smoothed importance sampling (PSIS).
5050

@@ -198,12 +198,18 @@ def plot_posterior_cr(models, idatas, rawdata, xlims, datamodelnm="linear", mode
198198
# Get traces and calc posterior prediction for npoints in x
199199
npoints = 100
200200
mdl = models[modelnm]
201-
trc = idatas[modelnm].posterior.copy().drop_vars("y_sigma")
202-
da = xr.concat([var for var in trc.values()], dim="order")
201+
trc = idatas[modelnm].posterior.copy()
203202
204-
ordr = int(modelnm[-1:])
203+
# Extract variables and stack them in correct order
204+
vars_to_concat = []
205+
for var in ["Intercept", "x"] + [f"np.power(x, {i})" for i in range(2, int(modelnm[-1:]) + 1)]:
206+
if var in trc:
207+
vars_to_concat.append(trc[var])
208+
da = xr.concat(vars_to_concat, dim="order")
209+
210+
ordr = len(vars_to_concat)
205211
x = xr.DataArray(np.linspace(xlims[0], xlims[1], npoints), dims=["x_plot"])
206-
pwrs = xr.DataArray(np.arange(ordr + 1), dims=["order"])
212+
pwrs = xr.DataArray(np.arange(ordr), dims=["order"])
207213
X = x**pwrs
208214
cr = xr.dot(X, da, dims="order")
209215
@@ -337,7 +343,7 @@ $$y = a + bx + \epsilon$$
337343

338344
+++
339345

340-
### Define model using explicit PyMC3 method
346+
### Define model using explicit PyMC method
341347

342348
```{code-cell} ipython3
343349
with pm.Model() as mdl_ols:
@@ -417,7 +423,7 @@ def create_poly_modelspec(k=1):
417423
def run_models(df, upper_order=5):
418424
"""
419425
Convenience function:
420-
Fit a range of pymc3 models of increasing polynomial complexity.
426+
Fit a range of pymc models of increasing polynomial complexity.
421427
Suggest limit to max order 5 since calculation time is exponential.
422428
"""
423429
@@ -432,7 +438,9 @@ def run_models(df, upper_order=5):
432438
models[nm] = bmb.Model(
433439
fml, df, priors={"intercept": bmb.Prior("Normal", mu=0, sigma=100)}, family="gaussian"
434440
)
435-
results[nm] = models[nm].fit(draws=2000, tune=1000, init="advi+adapt_diag")
441+
results[nm] = models[nm].fit(
442+
draws=2000, tune=1000, init="advi+adapt_diag", idata_kwargs={"log_likelihood": True}
443+
)
436444
437445
return models, results
438446
```
@@ -499,11 +507,11 @@ dfwaic_quad
499507
_, axs = plt.subplots(1, 2)
500508
501509
ax = axs[0]
502-
az.plot_compare(dfwaic_lin, ax=ax)
510+
az.plot_compare(dfwaic_lin, ax=ax, legend=False)
503511
ax.set_title("Linear data")
504512
505513
ax = axs[1]
506-
az.plot_compare(dfwaic_quad, ax=ax)
514+
az.plot_compare(dfwaic_quad, ax=ax, legend=False)
507515
ax.set_title("Quadratic data");
508516
```
509517

@@ -545,11 +553,11 @@ dfloo_quad
545553
_, axs = plt.subplots(1, 2)
546554
547555
ax = axs[0]
548-
az.plot_compare(dfloo_lin, ax=ax)
556+
az.plot_compare(dfloo_lin, ax=ax, legend=False)
549557
ax.set_title("Linear data")
550558
551559
ax = axs[1]
552-
az.plot_compare(dfloo_quad, ax=ax)
560+
az.plot_compare(dfloo_quad, ax=ax, legend=False)
553561
ax.set_title("Quadratic data");
554562
```
555563

@@ -601,6 +609,7 @@ spiegelhalter2002bayesian
601609
* Re-executed by Alex Andorra and Michael Osthege on June, 2020 ([pymc#3955](https://github.com/pymc-devs/pymc/pull/3955))
602610
* Updated by Raul Maldonado on March, 2021 ([pymc-examples#24](https://github.com/pymc-devs/pymc-examples/pull/24))
603611
* Updated by Abhipsha Das and Oriol Abril on June, 2021 ([pymc-examples#173](https://github.com/pymc-devs/pymc-examples/pull/173))
612+
* Updated by Chris Fonnesbeck on December, 2024 ([pymc-examples#761](https://github.com/pymc-devs/pymc-examples/pull/761))
604613

605614
+++
606615

@@ -613,7 +622,3 @@ spiegelhalter2002bayesian
613622

614623
:::{include} ../page_footer.md
615624
:::
616-
617-
```{code-cell} ipython3
618-
619-
```

pixi.lock

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ nutpie = ">=0.13.2,<0.14"
2727
numba = ">=0.60.0,<0.61"
2828
scikit-learn = ">=1.5.2,<2"
2929
blackjax = ">=1.2.3,<2"
30+
bambi = ">=0.15.0,<0.16"
3031

3132
[pypi-dependencies]
3233
pymc-experimental = ">=0.1.2, <0.2"

0 commit comments

Comments
 (0)