Skip to content

Commit e3a56ed

Browse files
authored
Revive GPU tests (#430)
1 parent 7067a90 commit e3a56ed

File tree

5 files changed

+60
-33
lines changed

5 files changed

+60
-33
lines changed

.buildkite/pipeline.yml

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
env:
2+
# SECRET_CODECOV_TOKEN can be added here if needed for coverage reporting
3+
4+
steps:
5+
- label: "Julia v{{matrix.version}}, {{matrix.label}}"
6+
plugins:
7+
- JuliaCI/julia#v1:
8+
version: "{{matrix.version}}"
9+
# - JuliaCI/julia-coverage#v1:
10+
# dirs:
11+
# - src
12+
# - ext
13+
command: julia --eval='println(pwd()); println(readdir()); include("test/CUDA/cuda.jl")'
14+
agents:
15+
queue: "juliagpu"
16+
cuda: "*"
17+
if: build.message !~ /\[skip tests\]/
18+
timeout_in_minutes: 60
19+
env:
20+
LABEL: "{{matrix.label}}"
21+
TEST_TYPE: ext
22+
matrix:
23+
setup:
24+
version:
25+
- "1"
26+
- "1.10"
27+
label:
28+
- "cuda"

test/CUDA/Project.toml

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[deps]
2+
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
3+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
5+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
8+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+23-25
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
using ReTest
1+
using Pkg
2+
Pkg.activate(@__DIR__)
3+
Pkg.develop(; path=joinpath(@__DIR__, "..", ".."))
4+
5+
include(joinpath(@__DIR__, "..", "common.jl"))
6+
7+
using Test
28
using AdvancedHMC
39
using AdvancedHMC: DualValue, PhasePoint
410
using CUDA
@@ -22,31 +28,24 @@ using CUDA
2228
samples, stats = sample(hamiltonian, proposal, θ₀, n_samples)
2329
end
2430

25-
#=
26-
Broken! See https://github.com/JuliaTesting/ReTest.jl/issues/50
2731
@testset "PhasePoint GPU" begin
2832
for T in [Float32, Float64]
29-
init_z1() = PhasePoint(
30-
CuArray([T(NaN) T(NaN)]),
31-
CuArray([T(NaN) T(NaN)]),
32-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
33-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
34-
)
35-
init_z2() = PhasePoint(
36-
CuArray([T(Inf) T(Inf)]),
37-
CuArray([T(Inf) T(Inf)]),
38-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
39-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
40-
)
41-
42-
@test_logs (
43-
:warn,
44-
"The current proposal will be rejected due to numerical error(s).",
45-
) init_z1()
46-
@test_logs (
47-
:warn,
48-
"The current proposal will be rejected due to numerical error(s).",
49-
) init_z2()
33+
function init_z1()
34+
return PhasePoint(
35+
CuArray([T(NaN) T(NaN)]),
36+
CuArray([T(NaN) T(NaN)]),
37+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
38+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
39+
)
40+
end
41+
function init_z2()
42+
return PhasePoint(
43+
CuArray([T(Inf) T(Inf)]),
44+
CuArray([T(Inf) T(Inf)]),
45+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
46+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
47+
)
48+
end
5049

5150
z1 = init_z1()
5251
z2 = init_z2()
@@ -55,4 +54,3 @@ Broken! See https://github.com/JuliaTesting/ReTest.jl/issues/50
5554
@test z1.ℓκ.value == z2.ℓκ.value
5655
end
5756
end
58-
=#

test/Project.toml

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
44
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
55
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
6-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
76
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
87
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
98
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

test/runtests.jl

+1-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ const GROUP = get(ENV, "AHMC_TEST_GROUP", "AdvancedHMC")
1717
include("common.jl")
1818

1919
if GROUP == "All" || GROUP == "AdvancedHMC"
20-
using ReTest, CUDA
20+
using ReTest
2121

2222
include("aqua.jl")
2323
include("metric.jl")
@@ -33,12 +33,6 @@ if GROUP == "All" || GROUP == "AdvancedHMC"
3333
include("mcmcchains.jl")
3434
include("constructors.jl")
3535

36-
if CUDA.functional()
37-
include("cuda.jl")
38-
else
39-
@warn "Skipping GPU tests because no GPU available."
40-
end
41-
4236
Comonicon.@main function runtests(patterns...; dry::Bool=false)
4337
return retest(patterns...; dry=dry, verbose=Inf)
4438
end

0 commit comments

Comments
 (0)