SciMLSensitivity.jl
Continuous-adjoint methods for diagonal-noise SDEs scale with the square of the number of dimensions
Hi, this line performs scalar indexing inside a pullback: https://github.com/SciML/SciMLSensitivity.jl/blob/1997fb1a2c288f3da37f61c7b0894eb4e42c5cd6/src/derivative_wrappers.jl#L899 As a result, you can't differentiate on a GPU in this case, because scalar indexing is not allowed there. Excerpt from the error:
Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
error(::String)@error.jl:35
assertscalar(::String)@GPUArraysCore.jl:103
getindex@indexing.jl:9[inlined]
adjoint@array.jl:44[inlined]
_pullback(::Zygote.Context{false}, ::typeof(getindex), ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Int64)@adjoint.jl:66
_pullback@derivative_wrappers.jl:899[inlined]
MWE:
using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, CUDA, SciMLSensitivity
rng = Xoshiro()
drift_net = Dense(2 => 2)
diffusion_net = Dense(2 => 2)
ps_drift_, st_drift = Lux.setup(rng, drift_net)
ps_diffusion_, st_diffusion = Lux.setup(rng, diffusion_net)
ps_ = ComponentArray((ps_drift=ps_drift_,ps_diffusion=ps_diffusion_)) |> Lux.gpu
function drift(u, ps, t)
    drift_net(u, ps.ps_drift, st_drift)[1]
end
function diffusion(u, ps, t)
    diffusion_net(u, ps.ps_diffusion, st_diffusion)[1]
end
u0 = [1f0, 1f0] |> Lux.gpu
tspan = (0f0, 1f0)
datasize = 10
solver = EulerHeun()
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
function loss(ps)
    problem = SDEProblem(drift, diffusion, u0, tspan, ps)
    solution = solve(problem, solver; sensealg=sensealg,
                     saveat=collect(range(tspan[1], tspan[end], datasize)),
                     dt=(tspan[end] / datasize))
    return sum(vec(solution |> Lux.gpu))
end
println(loss(ps_))
println(Zygote.gradient(ps -> loss(ps), ps_))
On CPU:
┌ Info: The GPU function is being called but the GPU is not accessible.
│ Defaulting back to the CPU. (No action is required if you want
└ to run on the CPU).
-0.445899
((ps_drift = (weight = Float32[11.908233 11.131403; 1.8911543 1.8107269], bias = Float32[10.83848; 1.7313616;;]), ps_diffusion = (weight = Float32[-22.582191 -22.426481; -0.6261419 -1.0105969], bias = Float32[-21.103664; -0.61384314;;])),)
On GPU:
41.25828
ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:103
[3] getindex
@ ~/.julia/packages/GPUArrays/TnEpb/src/host/indexing.jl:9 [inlined]
[4] adjoint
@ ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:44 [inlined]
[5] _pullback(__context__::Zygote.Context{false}, 568::typeof(getindex), x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, inds::Int64)
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66
[6] _pullback
@ ~/.julia/packages/SciMLSensitivity/E8w3Z/src/derivative_wrappers.jl:899 [inlined]
...
I don't immediately see how we can rewrite the block https://github.com/SciML/SciMLSensitivity.jl/blob/1997fb1a2c288f3da37f61c7b0894eb4e42c5cd6/src/derivative_wrappers.jl#L897C13-L909 without scalar indexing...
It's probably worth checking whether https://github.com/patrick-kidger/diffrax/blob/main/diffrax/adjoint.py and https://github.com/google-research/torchsde/blob/master/torchsde/_core/adjoint_sde.py avoid it.
Wait why is the indexing required there? Why not just compute all derivatives together, i.e.:
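# one pullback for the whole output instead of one per component
_dy, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
tmp1, tmp2 = back(λ)
if dgrad !== nothing
    if tmp2 !== nothing
        !isempty(dgrad) && (vec(dgrad) .= vec(tmp2))
    end
end
dλ !== nothing && (vec(dλ) .= vec(tmp1))
# write in place; a plain `dy = _dy` would only rebind the local name
dy !== nothing && (vec(dy) .= vec(_dy))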
?
because if the primal noise process has diagonal noise, the adjoint has commutative noise [see (14) in App. 9.5 of https://arxiv.org/pdf/2001.01328.pdf]
I get that, but I don't see why the piece of code right there needs to be indexed. That's exactly the same result as what I posted?
Maybe there is a trivial solution..
using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity
p = [1.5, 1.0, 3.0, 1.0]
m = 2
function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end
Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()
dW = rand(m)
dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)
# one pullback per output component
for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end
# a single pullback for the full output
dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
tmp1, tmp2 = back(λ)
julia> dgrad
4×2 Matrix{Float64}:
0.0813625 0.0
-0.0261179 0.0
0.0 -0.409558
0.0 0.168718
vs.
tmp2
4-element Vector{Float64}:
0.08136250711023468
-0.02611788331081627
-0.40955766401332616
0.16871806153418192
# how to multiply tmp2 with dW such that dgrad * dW == tmp2 (*) dW?
and
julia> dλ
2×2 Matrix{Float64}:
0.107301 0.188725
-0.0374919 -1.52155
julia> tmp1
2-element Vector{Float64}:
0.2960254183310152
-1.5590457490405374
# how to multiply tmp1 with dW such that dλ * dW == tmp1 (*) dW?
Zygote has a bug here that's easy to work around:
using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity, LinearAlgebra
p = [1.5, 1.0, 3.0, 1.0]
m = 2
function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end
Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()
dW = rand(m)
dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)
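# reference: one pullback per output component
for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end
# workaround: a single pullback, called once per column of Diagonal(λ)
dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end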
out = [back(x) for x in eachcol(Diagonal(λ))]
dgrad == stack(last.(out)) # true
dλ == stack(first.(out)) # true
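The reason the two agree can be sketched without any AD package: column i of Diagonal(λ) is λ[i] times the i-th unit vector, so feeding it to the pullback returns λ[i] times the gradient of output i. In the sketch below, the matrix `J` and the helper `vjp` are hypothetical stand-ins for the Jacobian and for Zygote's `back`; they are assumptions for illustration, not part of SciMLSensitivity.

```julia
using LinearAlgebra

λ = [0.2, 0.7, 1.1]

# Column i of Diagonal(λ) is λ[i] * e_i (the i-th unit vector scaled by λ[i]).
seeds = [collect(c) for c in eachcol(Diagonal(λ))]

# Hypothetical Jacobian standing in for ∂f/∂u; `vjp` mimics what `back` returns.
J = [1.0 2.0 0.0;
     0.0 3.0 4.0;
     5.0 0.0 6.0]
vjp(v) = J' * v

# One VJP per seed, collected column by column.
per_seed = reduce(hcat, [vjp(s) for s in seeds])

# Same matrix built from the gradient of each output, scaled by λ[i].
per_row = reduce(hcat, [λ[i] .* J[i, :] for i in eachindex(λ)])

println(per_seed ≈ per_row)
```

Note that this still costs one VJP per output dimension, which is the scaling problem discussed in this thread.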
whoaaaa nice!!! :0
Executing the MWE on SciMLSensitivity#master now yields this:
julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
[3] getindex
@ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
[4] generic_matvecmul!(C::Vector{Float32}, tA::Char, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:791
[5] mul!
@ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:115 [inlined]
[6] mul!
@ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:276 [inlined]
[7] *
@ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:105 [inlined]
[8] #1480
@ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/arraymath.jl:60 [inlined]
[9] unthunk
@ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
[10] unthunk
@ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:237 [inlined]
[11] wrap_chainrules_output
@ ~/.julia/packages/Zygote/JeHtr/src/compiler/chainrules.jl:110 [inlined]
...
[17] (::Zygote.Pullback{Tuple{typeof(diffusion), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Float32}, Tuple{Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Symbol}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#1990#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, NamedTuple{(), Tuple{}}}, Tuple{Zygote.ZBack{ChainRules.var"#times_pullback#1481"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.ZBack{Lux.var"#vec_pullback#193"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3754#back#1177"{Zygote.var"#1171#1175"{Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, 
Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, Symbol}}, Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:activation, Zygote.Context{false}, Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, typeof(identity)}}, Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, Symbol}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Lux.__apply_activation), typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}}}, Zygote.var"#1990#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}}}})(Δ::SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[18] #287
@ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:206 [inlined]
[19] #2173#back
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
[20] Pullback
@ ~/.julia/packages/SciMLBase/kTUaf/src/scimlfunctions.jl:2127 [inlined]
[21] Pullback
@ ~/.julia/packages/SciMLSensitivity/bCIak/src/derivative_wrappers.jl:911 [inlined]
There is some sort of C::Vector{Float32} involved which shouldn't be there I think.
If I set CUDA.allowscalar(true), it yields
julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: BoundsError: attempt to access 12×1 ComponentMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}} with indices 1:1:12×1:1:1 at index [13:24]
Stacktrace:
[1] copyto!(dest::ComponentMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}}, dstart::Int64, src::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, sstart::Int64, n::Int64)
@ Base ./abstractarray.jl:1137
[2] copyto!
@ ./abstractarray.jl:1121 [inlined]
[3] _typed_stack(::Colon, ::Type{Float32}, ::Type{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, A::Vector{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Aax::Tuple{Base.OneTo{Int64}})
@ Base ./abstractarray.jl:2803
[4] _typed_stack
@ ./abstractarray.jl:2793 [inlined]
[5] _stack
@ ./abstractarray.jl:2783 [inlined]
[6] _stack
@ ./abstractarray.jl:2775 [inlined]
[7] #stack#178
@ ./abstractarray.jl:2743 [inlined]
[8] stack(iter::Vector{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}})
@ Base ./abstractarray.jl:2743
[9] _jacNoise!(λ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, t::Float32, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{SciMLSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, RODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, 
SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, EulerHeun, StochasticDiffEq.LinearInterpolationData{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Vector{Float32}}, DiffEqBase.Stats, Nothing}, SciMLSensitivity.CheckpointSolution{RODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, NoiseWrapper{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, false}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, NoiseWrapper{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), 
false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, false}, SDEFunction{false, SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, EulerHeun, StochasticDiffEq.LinearInterpolationData{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Vector{Float32}}, DiffEqBase.Stats, Nothing}, Vector{Tuple{Float32, Float32}}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}, Nothing}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, ODEFunction{false, true, typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}}, isnoise::ZygoteVJP, dgrad::SubArray{Float32, 2, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}, UnitRange{Int64}}, false}, dλ::Nothing, dy::Nothing)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/bCIak/src/derivative_wrappers.jl:918
Same error on CPU:
julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: BoundsError: attempt to access 12×1 ComponentMatrix{Float32, Matrix{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}} with indices 1:1:12×1:1:1 at index [13:24]
Stacktrace:
[1] copyto!(dest::ComponentMatrix{Float32, Matrix{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}}, dstart::Int64, src::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, sstart::Int64, n::Int64)
@ Base ./abstractarray.jl:1137
[2] copyto!
@ ./abstractarray.jl:1121 [inlined]
[3] _typed_stack(::Colon, ::Type{Float32}, ::Type{ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, A::Vector{ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Aax::Tuple{Base.OneTo{Int64}})
@ Base ./abstractarray.jl:2803
[4] _typed_stack
@ ./abstractarray.jl:2793 [inlined]
[5] _stack
@ ./abstractarray.jl:2783 [inlined]
[6] _stack
...
maybe something wrong with the MWE? my component vector seems okay
I'm pretty sure I can solve this, it seems like stack on a ComponentVector isn't behaving as expected
Regarding the "Zygote has a bug here that's easy to work around" snippet above: it seems this workaround doesn't work on GPUs, as eachcol(Diagonal(...)) and every alternative I've tried don't work with CUDA.
Zygote has a bug here
What is the bug exactly? Can we fix it? This workaround is O(n^2)
< deleted because I had an incorrect theory here - see below >
hmmm BUT we need to perturb this correctly with the noise. I can't figure out how torchsde / diffrax are doing this right now...
In torchsde, they never actually define g in the adjoint SDE, only define g_prod which is the product between g and the noise. So compare the implementation for EulerHeun:
integrator.f(ftmp1,uprev,p,t)
integrator.g(gtmp1,uprev,p,t)
if is_diagonal_noise(integrator.sol.prob)
    @.. nrtmp=gtmp1*W.dW
else
    mul!(nrtmp,gtmp1,W.dW)
end
Python:
f, g_prod = self.sde.f_and_g_prod(t0, y0, I_k)
y_prime = y0 + g_prod
g_prod_prime = self.sde.g_prod(t1, y_prime, I_k)
y1 = y0 + dt * f + (g_prod + g_prod_prime) * 0.5
So Python has this more generic gprod and this is used to avoid computing the large matrix. Somehow Julia needs this as well, otherwise we can't train any diagonal SDEs efficiently. But the assumption that we compute g seems pretty deeply nestled into the library.
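To make the g_prod idea concrete, here is a minimal Julia sketch of one Euler-Heun step that mirrors the Python lines above and only ever asks for the product of g with the noise. All names (`euler_heun_step`, `g_prod`, `g_diag`) are hypothetical and are not StochasticDiffEq or SciMLSensitivity API; the linear drift and diffusion are made up for the example.

```julia
# Hypothetical sketch: one Euler-Heun step written against g_prod instead of g.
function euler_heun_step(f, g_prod, u, p, t, dt, dW)
    drift  = f(u, p, t)
    gp     = g_prod(u, p, t, dW)             # g(u) * dW at the current state
    u_pred = u .+ gp                         # predictor state
    gp_new = g_prod(u_pred, p, t + dt, dW)   # g(u_pred) * dW at the end of the step
    return u .+ dt .* drift .+ 0.5 .* (gp .+ gp_new)
end

# Toy problem (assumption): linear drift and diagonal, state-proportional noise.
f(u, p, t) = -p[1] .* u
g_diag(u, p, t) = p[2] .* u

# For diagonal noise the product is elementwise, so no m×m matrix is ever formed.
g_prod(u, p, t, dW) = g_diag(u, p, t) .* dW

u0 = [1.0, 2.0]
p  = [0.5, 0.1]
dW = [0.3, -0.2]
u1 = euler_heun_step(f, g_prod, u0, p, 0.0, 0.01, dW)
println(u1)
```

The point of the design is that an adjoint SDE could supply its own `g_prod` as a single VJP, whereas an interface that requires g forces the full matrix to be materialized first.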
update: there's a trivial solution!!
# how to multiply tmp2 with dW such that dgrad * dW == tmp2 (*) dW?
# how to multiply tmp1 with dW such that dλ * dW == tmp1 (*) dW?
Don't compute dgrad at all; move the multiplication with dW into the VJP. That's it.
my comments about paramnoisemixing are not important, noisemixing has nothing to do with this, it just works. but the solver implementation hurdle is still relevant.
using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity, LinearAlgebra
p = [1.5, 1.0, 3.0, 1.0]
m = 2
function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end
Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()
dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)
dy, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
out = [back(x) for x in eachcol(Diagonal(λ))]
dλ1 = stack(first.(out))
dgrad1 = stack(last.(out))
dW = rand(m)
println("Computed with a vjp for each dimension: $(dλ1 * dW)")
println("Computed with a vjp for each dimension: $(dgrad1 * dW)")
# fold dW into the differentiated function so that one VJP suffices
dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t) .* dW
end
out2 = back(λ)
resλ = first(out2)
resgrad = last(out2)
println("Computed with a single vjp: $resλ")
println("Computed with a single vjp: $resgrad")
println(dλ1 * dW ≈ resλ) # true
println(dgrad1 * dW ≈ resgrad) # true
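Why the two checks above hold can be written out in a few lines. The notation is introduced here only for illustration: J_u and J_p denote the Jacobians of f with respect to u and p at (y, p, t), e_i is the i-th unit vector, and ⊙ is the elementwise product.

```latex
\begin{aligned}
\text{per-dimension VJPs:}\quad
  \sum_{i=1}^{m} \mathrm{d}W_i \, J_u^{\top} (\lambda_i e_i)
  &= J_u^{\top} \Big( \sum_{i=1}^{m} \lambda_i \, \mathrm{d}W_i \, e_i \Big)
   = J_u^{\top} (\lambda \odot \mathrm{d}W), \\
\text{single VJP of } f \odot \mathrm{d}W:\quad
  \big( \operatorname{diag}(\mathrm{d}W) \, J_u \big)^{\top} \lambda
  &= J_u^{\top} \operatorname{diag}(\mathrm{d}W) \, \lambda
   = J_u^{\top} (\lambda \odot \mathrm{d}W).
\end{aligned}
```

The same argument with J_p in place of J_u gives the parameter part, so the m-column matrices never need to be formed when only their product with dW is required.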
this is still nontrivial to implement in Julia because of the solver design issue mentioned above
this of course gives us quite a performance boost:
m = 10000
(...)
function f(u, p, t)
    [p[1] * x - p[2] * t + p[3] * p[4] * t * x * x for x in u]
end
(...)
# this scales in O(m^2)
# precompile, then execute
[back(x) for x in eachcol(Diagonal(λ))]
@time out = [back(x) for x in eachcol(Diagonal(λ))]
(...)
# this scales in O(m)
back(λ)
@time out2 = back(λ)
println(dλ1 * dW ≈ resλ) # true
println(dgrad1 * dW ≈ resgrad) # true
=>>
29.061289 seconds (400.42 M allocations: 78.243 GiB, 21.32% gc time, 0.13% compilation time)
0.002238 seconds (40.05 k allocations: 8.241 MiB)
true
true
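For readers who want to reproduce the comparison without Zygote, the following sketch uses the same elementwise f with a hand-written VJP. `vjp_u` is a hypothetical stand-in for `back` (state part only), and the inputs are deterministic placeholders rather than the random values used above; no timings are claimed here.

```julia
# Elementwise model from the benchmark: output i depends only on u[i].
f(u, p, t) = [p[1] * x - p[2] * t + p[3] * p[4] * t * x * x for x in u]

# Hand-written VJP w.r.t. u (assumption: stands in for Zygote's `back`).
# d f_i / d u_i = p[1] + 2 * p[3] * p[4] * t * u[i]; all off-diagonal entries are zero.
vjp_u(v, u, p, t) = (p[1] .+ 2 .* p[3] .* p[4] .* t .* u) .* v

m  = 200
p  = [1.5, 1.0, 3.0, 1.0]
t  = 0.3
u  = collect(range(0.1, 1.0; length = m))
λ  = sin.(1:m)
dW = cos.(1:m)

# m VJPs, each seeded with λ[i] * e_i, then combined with dW: O(m^2) work overall.
cols    = [vjp_u(λ[i] .* ((1:m) .== i), u, p, t) for i in 1:m]
per_dim = reduce(hcat, cols) * dW

# One VJP with the noise folded into the seed: O(m) work.
single = vjp_u(λ .* dW, u, p, t)

println(per_dim ≈ single)
```

Seeding with λ .* dW is equivalent to differentiating f(u, p, t) .* dW with seed λ, which is what the Zygote version above does.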