Open
Description
There was a slight performance regression caused by switching to `LazyBufferCache` in the Enzyme `vecjacobian!` method. The change was necessary to work around issues caused by EnzymeAD/Enzyme.jl#2382; once that issue is fixed, we should go back to using `FixedSizeDiffCache`. The script below benchmarks the affected code path (`QuadratureAdjoint` with `EnzymeVJP`).
using Zygote, SciMLSensitivity
using OrdinaryDiffEq, ForwardDiff, Test
using BenchmarkTools
using Random
# Parameter vector for the exponential-growth benchmark. A fixed-seed RNG makes every run
# (e.g. before/after the cache change) benchmark the same values; the step counts of the
# adaptive solves, and therefore the timings, can depend on `p`.
p = rand(Xoshiro(1234), 3)
"""
    dudt(u, p, t)

Out-of-place right-hand side of the linear ODE `du/dt = u .* p`: return the elementwise
product of the state `u` and the rate vector `p`. The time `t` is not used.
"""
dudt(u, p, t) = u .* p
"""
    loss(p, sensealg)

Solve `dudt` from `u0 = [3.0, 2.0, 1.0]` over `t ∈ (0, 10)` with `ImplicitEuler`, saving
every 0.1, and return the sum of squares of all saved states. `sensealg` is forwarded to
`solve` and selects the sensitivity method used when the loss is differentiated.
"""
function loss(p, sensealg)
    u0 = [3.0, 2.0, 1.0]
    tspan = (0.0, 10.0)
    problem = ODEProblem(dudt, u0, tspan, p)
    solution = solve(problem, ImplicitEuler(); dt = 0.01, saveat = 0.1,
                     sensealg = sensealg, abstol = 1e-5, reltol = 1e-5)
    saved = Array(solution)
    return sum(abs2, saved)
end
# Time the Zygote gradient of `loss` through QuadratureAdjoint with an Enzyme VJP.
# The global `p` is interpolated with `$` so BenchmarkTools treats it as a local value
# rather than an untyped global lookup, which would add dispatch overhead to the timing.
@btime Zygote.gradient(θ -> loss(θ, QuadratureAdjoint(autojacvec = EnzymeVJP())), $p)
"""
    lv(du, u, p, t)

In-place Lotka–Volterra right-hand side for the two states in `u`, written into `du`.
`p[1]` and `p[2]` are the rates of the first equation; the decay rate of the second
state is fixed at 3. `t` is not used. Like the assignment it ends with, the function
returns the value written to `du[2]`.
"""
function lv(du, u, p, t)
    α = p[1]
    β = p[2]
    du[1] = α * u[1] - β * u[1] * u[2]
    du[2] = -3 * u[2] + u[1] * u[2]
end
"""
    loss_lv(p, sensealg)

Solve the Lotka–Volterra system `lv` from `u0 = [1.0, 2.0]` over `t ∈ (0, 10)` with
`ImplicitEuler`, saving every 0.1, and return the sum of squares of all saved states.
`sensealg` is forwarded to `solve` to choose the sensitivity method for differentiation.
"""
function loss_lv(p, sensealg)
    u0 = [1.0, 2.0]
    tspan = (0.0, 10.0)
    problem = ODEProblem(lv, u0, tspan, p)
    solution = solve(problem, ImplicitEuler(); dt = 0.01, saveat = 0.1,
                     sensealg = sensealg, abstol = 1e-5, reltol = 1e-5)
    saved = Array(solution)
    return sum(abs2, saved)
end
# Time the Zygote gradient of `loss_lv` through QuadratureAdjoint with an Enzyme VJP.
# The parameter vector is interpolated with `$(...)` so it is built once, outside the
# timed expression, instead of being re-allocated on every evaluation.
@btime Zygote.gradient(θ -> loss_lv(θ, QuadratureAdjoint(autojacvec = EnzymeVJP())),
                       $([1.0, 2.0]))
"""
    vanderpol(du, u, p, t)

In-place Van der Pol oscillator right-hand side with parameter `μ = p[1]`:
`du[1] = u[2]` and `du[2] = μ * ((1 - u[1]^2) * u[2] - u[1])`. `t` is not used. Like the
assignment it ends with, the function returns the value written to `du[2]`.
"""
function vanderpol(du, u, p, t)
    du[1] = u[2]
    μ = p[1]
    du[2] = μ * ((1 - u[1]^2) * u[2] - u[1])
end
"""
    loss_vp(p, sensealg)

Solve the Van der Pol system `vanderpol` from `u0 = [1.0, 2.0]` over `t ∈ (0, 10)` with
`Rodas5`, saving every 0.1, and return the sum of squares of all saved states.
`sensealg` is forwarded to `solve` to choose the sensitivity method for differentiation.
"""
function loss_vp(p, sensealg)
    u0 = [1.0, 2.0]
    tspan = (0.0, 10.0)
    problem = ODEProblem(vanderpol, u0, tspan, p)
    solution = solve(problem, Rodas5(); dt = 0.01, saveat = 0.1,
                     sensealg = sensealg, abstol = 1e-5, reltol = 1e-5)
    saved = Array(solution)
    return sum(abs2, saved)
end
# Time the Zygote gradient of `loss_vp` through QuadratureAdjoint with an Enzyme VJP.
# The parameter vector is interpolated with `$(...)` so it is built once, outside the
# timed expression, instead of being re-allocated on every evaluation.
@btime Zygote.gradient(θ -> loss_vp(θ, QuadratureAdjoint(autojacvec = EnzymeVJP())),
                       $([20.0]))
`@btime` results:
Metadata
Metadata
Assignees
Labels
No labels