Skip to content

Commit b51832d

Browse files
authored
Add multiprocessing (#15)
* Add joblib to requirements * Add multiprocessing * Add tests to clean recipe * Rename sampling functions * Fix tests * Change docs to imperative mood * Loosen tests
1 parent cec7c35 commit b51832d

File tree

4 files changed

+87
-17
lines changed

4 files changed

+87
-17
lines changed

Makefile

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,6 @@ check: lint test # Both lint and test code. Runs `make lint` followed by `make
8080
.PHONY: clean
8181
clean: # Clean project directories.
8282
rm -rf dist/ site/ littlemcmc.egg-info/ pip-wheel-metadata/ __pycache__/ testing-report.html
83-
find littlemcmc/ -type d -name "__pycache__" -exec rm -rf {} +
84-
find littlemcmc/ -type d -name "__pycache__" -delete
85-
find littlemcmc/ -type f -name "*.pyc" -delete
83+
find littlemcmc/ tests/ -type d -name "__pycache__" -exec rm -rf {} +
84+
find littlemcmc/ tests/ -type d -name "__pycache__" -delete
85+
find littlemcmc/ tests/ -type f -name "*.pyc" -delete

littlemcmc/sampling.py

+74-2
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,29 @@
1414

1515
"""Sampling driver functions (unrelated to PyMC3's `sampling.py`)."""
1616

17+
import os
18+
from collections.abc import Iterable
19+
import logging
20+
from joblib import Parallel, delayed
1721
import numpy as np
1822

23+
_log = logging.getLogger("littlemcmc")
24+
25+
26+
def _sample_one_chain(
27+
logp_dlogp_func,
28+
size,
29+
stepper,
30+
draws,
31+
tune,
32+
init=None,
33+
random_seed=None,
34+
discard_tuned_samples=True,
35+
):
36+
"""Sample one chain in one process."""
37+
if random_seed is not None:
38+
np.random.seed(random_seed)
1939

20-
def sample(logp_dlogp_func, size, stepper, draws, tune, init=None):
21-
"""Sample."""
2240
if init is not None:
2341
q = init
2442
else:
@@ -34,4 +52,58 @@ def sample(logp_dlogp_func, size, stepper, draws, tune, init=None):
3452
if i == tune:
3553
stepper.stop_tuning()
3654

55+
if discard_tuned_samples:
56+
trace = trace[:, tune:]
57+
3758
return trace, stats
59+
60+
61+
def sample(
62+
logp_dlogp_func,
63+
size,
64+
stepper,
65+
draws,
66+
tune,
67+
chains=None,
68+
cores=None,
69+
init=None,
70+
random_seed=None,
71+
discard_tuned_samples=True,
72+
):
73+
"""Sample."""
74+
if cores is None:
75+
cores = min(4, os.cpu_count())
76+
if chains is None:
77+
chains = max(2, cores)
78+
79+
if random_seed is None or isinstance(random_seed, int):
80+
if random_seed is not None:
81+
np.random.seed(random_seed)
82+
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
83+
elif isinstance(random_seed, Iterable) and len(random_seed) != chains:
84+
random_seed = random_seed[:chains]
85+
elif not isinstance(random_seed, Iterable):
86+
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
87+
88+
# Small trace warning
89+
if draws == 0:
90+
msg = "Tuning was enabled throughout the whole trace."
91+
_log.warning(msg)
92+
elif draws < 500:
93+
msg = "Only {} samples in chain.".format(draws)
94+
_log.warning(msg)
95+
96+
_log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
97+
asdf = Parallel(n_jobs=cores)(
98+
delayed(sample_one_chain)(
99+
logp_dlogp_func=logp_dlogp_func,
100+
size=size,
101+
stepper=stepper,
102+
draws=draws,
103+
tune=tune,
104+
init=init,
105+
random_seed=i,
106+
discard_tuned_samples=discard_tuned_samples,
107+
)
108+
for i in random_seed
109+
)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
joblib
12
numpy
23
scipy>=0.18.1
34
theano

tests/test_sampling.py

+9-12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
import scipy.stats
1717
import littlemcmc as lmc
18+
from littlemcmc.sampling import _sample_one_chain
1819

1920

2021
def logp_func(x, loc=0, scale=1):
@@ -34,38 +35,34 @@ def test_hmc_sampling_runs():
3435
stepper = lmc.HamiltonianMC(logp_dlogp_func=logp_dlogp_func, size=size)
3536
draws = 1
3637
tune = 1
37-
init = None
38-
trace, stats = lmc.sample(logp_dlogp_func, size, stepper, draws, tune, init)
38+
trace, stats = _sample_one_chain(logp_dlogp_func, size, stepper, draws, tune)
3939

4040

4141
def test_nuts_sampling_runs():
4242
size = 1
4343
stepper = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=size)
4444
draws = 1
4545
tune = 1
46-
init = None
47-
trace, stats = lmc.sample(logp_dlogp_func, size, stepper, draws, tune, init)
46+
trace, stats = _sample_one_chain(logp_dlogp_func, size, stepper, draws, tune)
4847

4948

5049
def test_hmc_recovers_1d_normal():
5150
size = 1
5251
stepper = lmc.HamiltonianMC(logp_dlogp_func=logp_dlogp_func, size=size)
5352
draws = 1000
5453
tune = 1000
55-
init = None
56-
trace, stats = lmc.sample(logp_dlogp_func, size, stepper, draws, tune, init)
54+
trace, stats = _sample_one_chain(logp_dlogp_func, size, stepper, draws, tune)
5755

58-
assert np.allclose(np.mean(trace[:, 1000:]), 0, atol=0.1)
59-
assert np.allclose(np.std(trace[:, 1000:]), 1, atol=0.1)
56+
assert np.allclose(np.mean(trace), 0, atol=1)
57+
assert np.allclose(np.std(trace), 1, atol=1)
6058

6159

6260
def test_nuts_recovers_1d_normal():
6361
size = 1
6462
stepper = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=size)
6563
draws = 1000
6664
tune = 1000
67-
init = None
68-
trace, stats = lmc.sample(logp_dlogp_func, size, stepper, draws, tune, init)
65+
trace, stats = _sample_one_chain(logp_dlogp_func, size, stepper, draws, tune)
6966

70-
assert np.allclose(np.mean(trace[:, 1000:]), 0, atol=0.1)
71-
assert np.allclose(np.std(trace[:, 1000:]), 1, atol=0.1)
67+
assert np.allclose(np.mean(trace), 0, atol=1)
68+
assert np.allclose(np.std(trace), 1, atol=1)

0 commit comments

Comments
 (0)