
Commit 4affc28

Authored by JaimeRZP, yebai, CompatHelper Julia, harisorgn, and torfjelde
MH Constructor (#2037)
* first draft
* abstractcontext + tests
* bug
* externalsampler() in tests
* NamedTuple problems
* moving stuff to DynamicPPL PR
* using new DynamicPPL PR
* mistakenly removed line
* specific constructors
* no StaticMH / RWMH
* Bump bijectors compat (#2052)
* CompatHelper: bump compat for Bijectors to 0.13, (keep existing compat)
* Update Project.toml
* Replacement for #2039 (#2040)
* Fix testset for external samplers
* Update abstractmcmc.jl
* Update test/contrib/inference/abstractmcmc.jl
Co-authored-by: Tor Erlend Fjelde <[email protected]>
* Update test/contrib/inference/abstractmcmc.jl
Co-authored-by: Tor Erlend Fjelde <[email protected]>
* Update FillArrays compat to 1.4.1 (#2035)
* Update FillArrays compat to 1.4.0
* Update test compat
* Try to enable ReverseDiff tests
* Update Project.toml
* Update Project.toml
* Bump version
* Revert dependencies on FillArrays (#2042)
* Update Project.toml
* Update Project.toml
* Fix redundant definition of `getstats` (#2044)
* Fix redundant definition of `getstats`
* Update Inference.jl
* Revert "Update Inference.jl"
This reverts commit e4f51c2.
* Bump version
---------
Co-authored-by: Hong Ge <[email protected]>
* Transfer some test utility functions into DynamicPPL (#2049)
* Update OptimInterface.jl
* Only run optimisation tests in numerical stage.
* fix function lookup after moving functions
---------
Co-authored-by: Xianda Sun <[email protected]>
* Move Optim support to extension (#2051)
* Move Optim support to extension
* More imports
* Update Project.toml
---------
Co-authored-by: Hong Ge <[email protected]>
---------
Co-authored-by: CompatHelper Julia <[email protected]>
Co-authored-by: haris organtzidis <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Cameron Pfiffer <[email protected]>
* Bugfixes.
* Add TODO.
* Update mh.jl
* Update Inference.jl
* Removed obsolete exports.
* removed unnecessary import of extract_priors
* added missing ) in MH tests
* fixed incorrect references to AdvancedMH in tests
* improve ESLogDensityFunction
* remove hardcoding of SimpleVarInfo
* added FIXME comment
* minor style changes
* fixed issues with MH with RandomWalkProposal being used as an external sampler
* fixed accidental typo
* move definitions of unflatten for NamedTuple
* improved TODO
* Update Project.toml
---------
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: CompatHelper Julia <[email protected]>
Co-authored-by: haris organtzidis <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Cameron Pfiffer <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
1 parent d8beaf0 commit 4affc28
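
For orientation, a minimal usage sketch of the model-based MH constructor this commit adds, driven through externalsampler. The toy model below is an illustrative assumption (it is not part of this commit); the constructor keyword and the externalsampler wrapper mirror the test additions further down.

using Turing
using AdvancedMH

# Illustrative toy model (assumed for this sketch, not taken from the diff).
@model function toy(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x .~ Normal(m, sqrt(s))
end

model = toy([1.5, 2.0])

# Build an AdvancedMH.MetropolisHastings sampler from the model's priors
# (the new `MH(model; proposal_type)` constructor) and run it through
# Turing's external-sampler interface.
spl = externalsampler(MH(model, proposal_type=AdvancedMH.RandomWalkProposal))
chain = sample(model, spl, 1_000)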

File tree

4 files changed: +39 -1 lines

src/Turing.jl
src/inference/Inference.jl
src/inference/mh.jl
test/inference/mh.jl

src/Turing.jl (-1)

@@ -73,7 +73,6 @@ export @model, # modelling
     Prior, # Sampling from the prior
 
     MH, # classic sampling
-    RWMH,
     Emcee,
     ESS,
     Gibbs,

src/inference/Inference.jl (+18)

@@ -99,6 +99,24 @@ Wrap a sampler so it can be used as an inference algorithm.
 """
 externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler)
 
+"""
+    ESLogDensityFunction
+
+A log density function for the External sampler.
+
+"""
+const ESLogDensityFunction{M<:Model,S<:Sampler{<:ExternalSampler},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.DefaultContext}
+function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple)
+    return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x))
+end
+
+# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
+function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
+    set_namedtuple!(deepcopy(vi), θ)
+    return vi
+end
+DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation)
+
 # Algorithm for sampling from the prior
 struct Prior <: InferenceAlgorithm end
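
As a rough illustration of the NamedTuple path above: `unflatten` on a `SimpleVarInfo` simply swaps in the named values, after which the joint log density can be evaluated (that is all the external-sampler `logdensity` method does internally). The toy model and parameter values below are assumptions for the sketch, not part of the diff.

using Turing
using DynamicPPL

# Illustrative model (assumed for this sketch, not taken from the diff).
@model function toy(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x .~ Normal(m, sqrt(s))
end

model = toy([1.5, 2.0])

# `unflatten` with a NamedTuple returns a SimpleVarInfo carrying those values;
# `logjoint` then re-evaluates the model at them.
svi = DynamicPPL.SimpleVarInfo((s = 1.0, m = 0.5))
svi = DynamicPPL.unflatten(svi, (s = 2.0, m = 0.0))
DynamicPPL.logjoint(model, svi)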

src/inference/mh.jl (+15)

@@ -188,6 +188,20 @@ function MH(space...)
     return MH{tuple(syms...), typeof(proposals)}(proposals)
 end
 
+# Some of the proposals require working in unconstrained space.
+transform_maybe(proposal::AMH.Proposal) = proposal
+function transform_maybe(proposal::AMH.RandomWalkProposal)
+    return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal))
+end
+
+function MH(model::Model; proposal_type=AMH.StaticProposal)
+    priors = DynamicPPL.extract_priors(model)
+    props = Tuple([proposal_type(prop) for prop in values(priors)])
+    vars = Tuple(map(Symbol, collect(keys(priors))))
+    priors = map(transform_maybe, NamedTuple{vars}(props))
+    return AMH.MetropolisHastings(priors)
+end
+
 #####################
 # Utility functions #
 #####################
@@ -346,6 +360,7 @@ end
 function should_link(varinfo, sampler, proposal::AdvancedMH.RandomWalkProposal)
     return true
 end
+# FIXME: This won't be hit unless `vals` are all exactly the same concrete type of `AdvancedMH.RandomWalkProposal`!
 function should_link(
     varinfo,
     sampler,
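
To make the new constructor concrete, here is a hand-rolled sketch of roughly what `MH(model)` assembles for a two-parameter model: one proposal per prior, keyed by the variable's symbol, with random-walk proposals pushed to unconstrained space via Bijectors (the `transform_maybe` step above). The model and parameter names are illustrative assumptions, not part of the diff.

using Turing
using AdvancedMH
using Bijectors
using DynamicPPL

# Illustrative model (assumed for this sketch, not taken from the diff).
@model function toy(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x .~ Normal(m, sqrt(s))
end

model = toy([1.5, 2.0])

priors = DynamicPPL.extract_priors(model)           # VarName => prior distribution
syms   = Tuple(map(Symbol, collect(keys(priors))))  # e.g. (:s, :m)
# Random-walk proposals in unconstrained space, one per prior.
props  = Tuple(AdvancedMH.RandomWalkProposal(Bijectors.transformed(d)) for d in values(priors))
spl    = AdvancedMH.MetropolisHastings(NamedTuple{syms}(props))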

test/inference/mh.jl (+6)

@@ -17,6 +17,12 @@
 
         s4 = Gibbs(MH(:m), MH(:s))
         c4 = sample(gdemo_default, s4, N)
+
+        s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal))
+        c5 = sample(gdemo_default, s5, N)
+
+        s6 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.StaticProposal))
+        c6 = sample(gdemo_default, s6, N)
     end
     @numerical_testset "mh inference" begin
         Random.seed!(125)
