Skip to content

Commit 2ada47c

Browse files
ErikQQYyebaigithub-actions[bot]
authored
Unify argument order in phasepoint and transition (#435)
* Unify argument order in phasepoint and transition * Use Test in quality tests * Revert change to Test * Try with Test * Try with Test * Use Test namespace * Retrigger CUDA test * Skip CUDA tests when no CUDA devices are found. (#436) * Update cuda.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Revert phasepoint unifying * Bump compat for docs * Update changelog * Better changelog --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 01d31dd commit 2ada47c

File tree

9 files changed

+59
-51
lines changed

9 files changed

+59
-51
lines changed

HISTORY.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# AdvancedHMC Changelog
22

3+
## 0.8.0
4+
5+
- To make an MCMC transtion from phasepoint `z` using trajectory `τ`(or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`(if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).
6+
37
## v0.7.1
48

59
- README has been simplified, many docs transfered to docs: https://turinglang.org/AdvancedHMC.jl/dev/.

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.7.1"
3+
version = "0.8.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
55

66
[compat]
7-
AdvancedHMC = "0.7"
7+
AdvancedHMC = "0.8"
88
Documenter = "1"
99
DocumenterCitations = "1"

src/sampler.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function transition(
5454
(; refreshment, τ) = κ
5555
@set! τ.integrator = jitter(rng, τ.integrator)
5656
z = refresh(rng, refreshment, h, z)
57-
return transition(rng, τ, h, z)
57+
return transition(rng, h, τ, z)
5858
end
5959

6060
function Adaptation.adapt!(

src/trajectory.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,10 @@ $(SIGNATURES)
244244
245245
Make a MCMC transition from phase point `z` using the trajectory `τ` under Hamiltonian `h`.
246246
247-
NOTE: This is a RNG-implicit fallback function for `transition(Random.default_rng(), τ, h, z)`
247+
NOTE: This is a RNG-implicit fallback function for `transition(Random.default_rng(), h, τ, z)`
248248
"""
249-
function transition(τ::Trajectory, h::Hamiltonian, z::PhasePoint)
250-
return transition(Random.default_rng(), τ, h, z)
249+
function transition(h::Hamiltonian, τ::Trajectory, z::PhasePoint)
250+
return transition(Random.default_rng(), h, τ, z)
251251
end
252252

253253
###
@@ -256,8 +256,8 @@ end
256256

257257
function transition(
258258
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
259-
τ::Trajectory{TS,I,TC},
260259
h::Hamiltonian,
260+
τ::Trajectory{TS,I,TC},
261261
z::PhasePoint,
262262
) where {TS<:AbstractTrajectorySampler,I,TC<:StaticTerminationCriterion}
263263
H0 = energy(z)
@@ -665,7 +665,7 @@ function build_tree(
665665
end
666666

667667
function transition(
668-
rng::AbstractRNG, τ::Trajectory{TS,I,TC}, h::Hamiltonian, z0::PhasePoint
668+
rng::AbstractRNG, h::Hamiltonian, τ::Trajectory{TS,I,TC}, z0::PhasePoint
669669
) where {
670670
TS<:AbstractTrajectorySampler,I<:AbstractIntegrator,TC<:DynamicTerminationCriterion
671671
}

test/CUDA/cuda.jl

+41-38
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,50 @@ using LogDensityProblems
1111
include(joinpath(@__DIR__, "..", "common.jl"))
1212

1313
@testset "AdvancedHMC GPU" begin
14-
n_chains = 1000
15-
n_samples = 1000
16-
dim = 5
17-
18-
T = Float32
19-
m, s, θ₀ = zeros(T, dim), ones(T, dim), rand(T, dim, n_chains)
20-
m, s, θ₀ = CuArray(m), CuArray(s), CuArray(θ₀)
21-
22-
target = Gaussian(m, s)
23-
metric = UnitEuclideanMetric(T, size(θ₀))
24-
ℓπ, ∇ℓπ = get_ℓπ(target), get_∇ℓπ(target)
25-
hamiltonian = Hamiltonian(metric, ℓπ, ∇ℓπ)
26-
integrator = Leapfrog(one(T) / 5)
27-
proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5)))
28-
29-
samples, stats = sample(hamiltonian, proposal, θ₀, n_samples)
14+
if CUDA.functional()
15+
n_chains = 1000
16+
n_samples = 1000
17+
dim = 5
18+
T = Float32
19+
m, s, θ₀ = zeros(T, dim), ones(T, dim), rand(T, dim, n_chains)
20+
m, s, θ₀ = CuArray(m), CuArray(s), CuArray(θ₀)
21+
target = Gaussian(m, s)
22+
metric = UnitEuclideanMetric(T, size(θ₀))
23+
ℓπ, ∇ℓπ = get_ℓπ(target), get_∇ℓπ(target)
24+
hamiltonian = Hamiltonian(metric, ℓπ, ∇ℓπ)
25+
integrator = Leapfrog(one(T) / 5)
26+
proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5)))
27+
samples, stats = sample(hamiltonian, proposal, θ₀, n_samples)
28+
else
29+
println("GPU tests are skipped because no CUDA devices are found.")
30+
end
3031
end
3132

3233
@testset "PhasePoint GPU" begin
33-
for T in [Float32, Float64]
34-
function init_z1()
35-
return PhasePoint(
36-
CuArray([T(NaN) T(NaN)]),
37-
CuArray([T(NaN) T(NaN)]),
38-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
39-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
40-
)
34+
if CUDA.functional()
35+
for T in [Float32, Float64]
36+
function init_z1()
37+
return PhasePoint(
38+
CuArray([T(NaN) T(NaN)]),
39+
CuArray([T(NaN) T(NaN)]),
40+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
41+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
42+
)
43+
end
44+
function init_z2()
45+
return PhasePoint(
46+
CuArray([T(Inf) T(Inf)]),
47+
CuArray([T(Inf) T(Inf)]),
48+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
49+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
50+
)
51+
end
52+
z1 = init_z1()
53+
z2 = init_z2()
54+
@test z1.ℓπ.value == z2.ℓπ.value
55+
@test z1.ℓκ.value == z2.ℓκ.value
4156
end
42-
function init_z2()
43-
return PhasePoint(
44-
CuArray([T(Inf) T(Inf)]),
45-
CuArray([T(Inf) T(Inf)]),
46-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
47-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
48-
)
49-
end
50-
51-
z1 = init_z1()
52-
z2 = init_z2()
53-
54-
@test z1.ℓπ.value == z2.ℓπ.value
55-
@test z1.ℓκ.value == z2.ℓκ.value
57+
else
58+
println("GPU tests are skipped because no CUDA devices are found.")
5659
end
5760
end

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2020
ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
2121
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2222
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
23+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2324
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2425
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2526

test/quality.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
using AdvancedHMC
2-
using ReTest
2+
using Test: Test
33
using Aqua: Aqua
44
using JET
55
using ForwardDiff
66

7-
@testset "Aqua" begin
7+
Test.@testset "Aqua" begin
88
Aqua.test_all(AdvancedHMC)
99
end
1010

11-
@testset "JET" begin
11+
Test.@testset "JET" begin
1212
JET.test_package(AdvancedHMC; target_defined_modules=true)
1313
end

test/trajectory.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,11 @@ end
129129
for τ_test in [τ, τ_with_jittered_lf], seed in [1234, 5678, 90]
130130
rng = MersenneTwister(seed)
131131
z = AdvancedHMC.phasepoint(h, θ_init, r_init)
132-
z1′ = AdvancedHMC.transition(rng, τ_test, h, z).z
132+
z1′ = AdvancedHMC.transition(rng, h, τ_test, z).z
133133

134134
rng = MersenneTwister(seed)
135135
z = AdvancedHMC.phasepoint(h, θ_init, r_init)
136-
z2′ = AdvancedHMC.transition(rng, τ_test, h, z).z
136+
z2′ = AdvancedHMC.transition(rng, h, τ_test, z).z
137137

138138
@test z1′.θ == z2′.θ
139139
@test z1′.r == z2′.r

0 commit comments

Comments
 (0)