SciMLSensitivity.jl

Continuous-adjoint methods for diagonal-noise SDEs scale with the square of the number of dimensions

Open linusheck opened this issue 2 years ago • 18 comments

Hi, this line performs scalar indexing inside a pullback: https://github.com/SciML/SciMLSensitivity.jl/blob/1997fb1a2c288f3da37f61c7b0894eb4e42c5cd6/src/derivative_wrappers.jl#L899 This means you can't differentiate on a GPU in this case, because scalar indexing of GPU arrays is disallowed. Excerpt from the error:

Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.

error(::String)@error.jl:35
assertscalar(::String)@GPUArraysCore.jl:103
getindex@indexing.jl:9[inlined]
adjoint@array.jl:44[inlined]
_pullback(::Zygote.Context{false}, ::typeof(getindex), ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Int64)@adjoint.jl:66
_pullback@derivative_wrappers.jl:899[inlined]

linusheck, Jul 28 '23

MWE:

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, CUDA, SciMLSensitivity

rng = Xoshiro()

drift_net = Dense(2 => 2)
diffusion_net = Dense(2 => 2)

ps_drift_, st_drift = Lux.setup(rng, drift_net)

ps_diffusion_, st_diffusion = Lux.setup(rng, diffusion_net)

ps_ = ComponentArray((ps_drift=ps_drift_,ps_diffusion=ps_diffusion_)) |> Lux.gpu

function drift(u, ps, t)
    drift_net(u, ps.ps_drift, st_drift)[1]
end
function diffusion(u, ps, t)
    diffusion_net(u, ps.ps_diffusion, st_diffusion)[1]
end

u0 = [1f0, 1f0] |> Lux.gpu

tspan = (0f0, 1f0)
datasize = 10

solver = EulerHeun()
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())

function loss(ps)
    problem = SDEProblem(drift, diffusion, u0, tspan, ps)
    solution = solve(problem, solver; sensealg=sensealg, saveat=collect(range(tspan[1], tspan[end], datasize)), dt=(tspan[end] / datasize))
    return sum(vec(solution |> Lux.gpu))
end

println(loss(ps_))
println(Zygote.gradient(ps -> loss(ps), ps_))

On CPU:

┌ Info: The GPU function is being called but the GPU is not accessible.
│ Defaulting back to the CPU. (No action is required if you want
└ to run on the CPU).
-0.445899
((ps_drift = (weight = Float32[11.908233 11.131403; 1.8911543 1.8107269], bias = Float32[10.83848; 1.7313616;;]), ps_diffusion = (weight = Float32[-22.582191 -22.426481; -0.6261419 -1.0105969], bias = Float32[-21.103664; -0.61384314;;])),)

On GPU:

41.25828
ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:103
  [3] getindex
    @ ~/.julia/packages/GPUArrays/TnEpb/src/host/indexing.jl:9 [inlined]
  [4] adjoint
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:44 [inlined]
  [5] _pullback(__context__::Zygote.Context{false}, 568::typeof(getindex), x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, inds::Int64)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66
  [6] _pullback
    @ ~/.julia/packages/SciMLSensitivity/E8w3Z/src/derivative_wrappers.jl:899 [inlined]
...

linusheck, Aug 02 '23

I don't immediately see how we can rewrite the block https://github.com/SciML/SciMLSensitivity.jl/blob/1997fb1a2c288f3da37f61c7b0894eb4e42c5cd6/src/derivative_wrappers.jl#L897C13-L909 without scalar indexing...

It's probably worth checking whether diffrax (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/adjoint.py) or torchsde (https://github.com/google-research/torchsde/blob/master/torchsde/_core/adjoint_sde.py) avoids it.

frankschae, Aug 02 '23

Wait, why is the indexing required there? Why not just compute all derivatives together, i.e.:

                _dy, back = Zygote.pullback(y, p) do u, p
                    f(u, p, t)
                end
                tmp1, tmp2 = back(λ)
                if dgrad !== nothing
                    if tmp2 !== nothing
                        !isempty(dgrad) && (vec(dgrad) .= vec(tmp2))
                    end
                end
                dλ !== nothing && (vec(dλ) .= vec(tmp1))
                dy !== nothing && (dy = _dy)

?

ChrisRackauckas, Aug 02 '23

Because if the primal SDE has diagonal noise, the adjoint SDE has commutative (rather than diagonal) noise [see (14) in App. 9.5 of https://arxiv.org/pdf/2001.01328.pdf].

frankschae, Aug 02 '23

I get that, but I don't see why that piece of code needs indexing. Isn't that exactly the same result as what I posted?

ChrisRackauckas, Aug 02 '23

Maybe there is a trivial solution...

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end
Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()

dW = rand(m)


dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)

for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
tmp1, tmp2 = back(λ)
julia> dgrad
4×2 Matrix{Float64}:
  0.0813625   0.0
 -0.0261179   0.0
  0.0        -0.409558
  0.0         0.168718

vs.

julia> tmp2
4-element Vector{Float64}:
  0.08136250711023468
 -0.02611788331081627
 -0.40955766401332616
  0.16871806153418192
  # how to multiply tmp2 with dW such that dgrad * dW  ==  tmp2 (*) dW?

and

julia> dλ
2×2 Matrix{Float64}:
  0.107301    0.188725
 -0.0374919  -1.52155

julia> tmp1
2-element Vector{Float64}:
  0.2960254183310152
 -1.5590457490405374
 
  # how to multiply tmp1 with dW such that dλ * dW  ==  tmp1 (*) dW?

frankschae, Aug 02 '23
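The mismatch in the example above can be stated compactly. Writing $J = \partial f/\partial p$ for the $2 \times 4$ parameter Jacobian of the toy f, each column of dgrad is a pullback seeded with $\lambda_i e_i$, while the single pullback is seeded with all of $\lambda$ at once:

$$
\mathrm{dgrad} = J^\top \operatorname{diag}(\lambda), \qquad
\mathtt{tmp2} = J^\top \lambda = J^\top \operatorname{diag}(\lambda)\,\mathbf{1} = \mathrm{dgrad}\,\mathbf{1}.
$$

So tmp2 holds the row sums of dgrad (first row: 0.0813625 + 0.0), i.e. dgrad applied to a vector of ones, whereas the diagonal-noise adjoint needs dgrad · dW. The same relation holds between dλ and tmp1 with $J = \partial f/\partial u$ (first row: 0.107301 + 0.188725 ≈ 0.29603).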

Zygote has a bug here that's easy to work around:

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity, LinearAlgebra

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end
Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()

dW = rand(m)


dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)

for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
out = [back(x) for x in eachcol(Diagonal(λ))]

dgrad == stack(last.(out)) # true
dλ == stack(first.(out)) # true

ChrisRackauckas, Aug 02 '23
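A dependency-free sketch of why the workaround reproduces the loop. The toy f's Jacobians are written out by hand here (an assumption standing in for Zygote), and `make_back` is a hypothetical closure that mimics the `back` returned by `Zygote.pullback(y, p)`. Feeding it the columns of `Diagonal(λ)` gives exactly the per-component seeds `λ[i] * e_i`, so the stacked results are `Ju' * Diagonal(λ)` and `Jp' * Diagonal(λ)`.

```julia
using LinearAlgebra, Random

# Toy drift from the thread: f(u, p, t) = [p1*u1 - p2*u1*u2*t, -p3*u2 + t*p4*u1*u2].
# Hand-written Jacobians w.r.t. u (2×2) and p (2×4), standing in for Zygote.
Ju(u, p, t) = [(p[1] - p[2]*u[2]*t) (-p[2]*u[1]*t);
               (t*p[4]*u[2])        (-p[3] + t*p[4]*u[1])]
Jp(u, p, t) = [u[1] (-u[1]*u[2]*t) 0.0     0.0;
               0.0  0.0            (-u[2]) (t*u[1]*u[2])]

# Hypothetical stand-in for Zygote's `back`: maps a cotangent w to (Ju'w, Jp'w).
make_back(u, p, t) = w -> (Ju(u, p, t)' * w, Jp(u, p, t)' * w)

Random.seed!(434988934)
p = [1.5, 1.0, 3.0, 1.0]
y, λ, t = rand(2), rand(2), rand()
back = make_back(y, p, t)

# Original loop: one pullback per output component i, seeded with λ[i] * e_i.
dλ_loop, dgrad_loop = zeros(2, 2), zeros(4, 2)
for i in 1:2
    seed = zeros(2); seed[i] = λ[i]
    a, b = back(seed)
    dλ_loop[:, i] .= a
    dgrad_loop[:, i] .= b
end

# Workaround: the seeds are exactly the columns of Diagonal(λ).
out = [back(x) for x in eachcol(Diagonal(λ))]
dλ_ws = reduce(hcat, first.(out))      # same as stack(first.(out)) on Julia ≥ 1.9
dgrad_ws = reduce(hcat, last.(out))
```

Each of the `m` pullbacks still touches a full length-`m` cotangent, which is where the quadratic cost comes from.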

whoaaaa nice!!! :0

linusheck, Aug 03 '23

Executing the MWE on SciMLSensitivity#master now yields this:

julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
  [3] getindex
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
  [4] generic_matvecmul!(C::Vector{Float32}, tA::Char, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:791
  [5] mul!
    @ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:115 [inlined]
  [6] mul!
    @ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:276 [inlined]
  [7] *
    @ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:105 [inlined]
  [8] #1480
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/arraymath.jl:60 [inlined]
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
 [10] unthunk
    @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:237 [inlined]
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/JeHtr/src/compiler/chainrules.jl:110 [inlined]
...
[17] (::Zygote.Pullback{Tuple{typeof(diffusion), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Float32}, Tuple{Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Symbol}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#1990#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, NamedTuple{(), Tuple{}}}, Tuple{Zygote.ZBack{ChainRules.var"#times_pullback#1481"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.ZBack{Lux.var"#vec_pullback#193"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3754#back#1177"{Zygote.var"#1171#1175"{Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, 
Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, Symbol}}, Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:activation, Zygote.Context{false}, Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, typeof(identity)}}, Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, Symbol}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Lux.__apply_activation), typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}}}, Zygote.var"#1990#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}}}})(Δ::SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [18] #287
    @ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:206 [inlined]
 [19] #2173#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [20] Pullback
    @ ~/.julia/packages/SciMLBase/kTUaf/src/scimlfunctions.jl:2127 [inlined]
 [21] Pullback
    @ ~/.julia/packages/SciMLSensitivity/bCIak/src/derivative_wrappers.jl:911 [inlined]

There is a C::Vector{Float32} (a CPU array) involved, which I think shouldn't be there.

With CUDA.allowscalar(true), it yields:

julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: BoundsError: attempt to access 12×1 ComponentMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}} with indices 1:1:12×1:1:1 at index [13:24]
Stacktrace:
  [1] copyto!(dest::ComponentMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}}, dstart::Int64, src::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, sstart::Int64, n::Int64)
    @ Base ./abstractarray.jl:1137
  [2] copyto!
    @ ./abstractarray.jl:1121 [inlined]
  [3] _typed_stack(::Colon, ::Type{Float32}, ::Type{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, A::Vector{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Aax::Tuple{Base.OneTo{Int64}})
    @ Base ./abstractarray.jl:2803
  [4] _typed_stack
    @ ./abstractarray.jl:2793 [inlined]
  [5] _stack
    @ ./abstractarray.jl:2783 [inlined]
  [6] _stack
    @ ./abstractarray.jl:2775 [inlined]
  [7] #stack#178
    @ ./abstractarray.jl:2743 [inlined]
  [8] stack(iter::Vector{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}})
    @ Base ./abstractarray.jl:2743
  [9] _jacNoise!(λ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, t::Float32, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{SciMLSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, RODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, 
SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, EulerHeun, StochasticDiffEq.LinearInterpolationData{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Vector{Float32}}, DiffEqBase.Stats, Nothing}, SciMLSensitivity.CheckpointSolution{RODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, NoiseWrapper{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, false}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, NoiseWrapper{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), 
false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, false}, SDEFunction{false, SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, EulerHeun, StochasticDiffEq.LinearInterpolationData{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Vector{Float32}}, DiffEqBase.Stats, Nothing}, Vector{Tuple{Float32, Float32}}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}, Nothing}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, ODEFunction{false, true, typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}}, isnoise::ZygoteVJP, dgrad::SubArray{Float32, 2, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}, UnitRange{Int64}}, false}, dλ::Nothing, dy::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/bCIak/src/derivative_wrappers.jl:918

linusheck, Aug 10 '23

Same error on CPU:

julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: BoundsError: attempt to access 12×1 ComponentMatrix{Float32, Matrix{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}} with indices 1:1:12×1:1:1 at index [13:24]
Stacktrace:
  [1] copyto!(dest::ComponentMatrix{Float32, Matrix{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}}, dstart::Int64, src::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, sstart::Int64, n::Int64)
    @ Base ./abstractarray.jl:1137
  [2] copyto!
    @ ./abstractarray.jl:1121 [inlined]
  [3] _typed_stack(::Colon, ::Type{Float32}, ::Type{ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, A::Vector{ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Aax::Tuple{Base.OneTo{Int64}})
    @ Base ./abstractarray.jl:2803
  [4] _typed_stack
    @ ./abstractarray.jl:2793 [inlined]
  [5] _stack
    @ ./abstractarray.jl:2783 [inlined]
  [6] _stack
...

Maybe something is wrong with the MWE? My ComponentVector seems okay.

linusheck, Aug 10 '23

I'm pretty sure I can solve this; it seems that stack on a vector of ComponentVectors isn't behaving as expected (it allocates a 12×1 ComponentMatrix where 12×2 is needed).

linusheck, Aug 12 '23

Seems like this workaround doesn't work on GPUs: eachcol(Diagonal(...)), and every alternative I've tried, fails with CUDA.

linusheck, Aug 12 '23

Zygote has a bug here

What is the bug exactly? Can we fix it? This workaround is O(n^2): it runs one pullback per state dimension, and each pullback costs at least O(n).

linusheck, Aug 12 '23

< deleted because I had an incorrect theory here - see below >

linusheck, Aug 12 '23

Hmm, but we still need to combine this correctly with the noise. I can't figure out how torchsde / diffrax do this right now...

linusheck, Aug 13 '23

torchsde never actually defines g for the adjoint SDE; it only defines g_prod, the product of g with the noise increment. Compare the EulerHeun implementations. Julia (StochasticDiffEq.jl):

integrator.f(ftmp1,uprev,p,t)
integrator.g(gtmp1,uprev,p,t)

if is_diagonal_noise(integrator.sol.prob)
  @.. nrtmp=gtmp1*W.dW
else
  mul!(nrtmp,gtmp1,W.dW)
end

Python:

f, g_prod = self.sde.f_and_g_prod(t0, y0, I_k)

y_prime = y0 + g_prod

g_prod_prime = self.sde.g_prod(t1, y_prime, I_k)

y1 = y0 + dt * f + (g_prod + g_prod_prime) * 0.5

So torchsde uses this more generic g_prod to avoid ever computing the large noise matrix. Julia needs something like this as well; otherwise we can't train diagonal-noise SDEs efficiently. But the assumption that g itself is computed seems deeply embedded in the library.

linusheck, Aug 14 '23
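A minimal sketch of the torchsde-style step transcribed to Julia, assuming a hypothetical `g_prod(y, p, t, dW)` callback that returns the noise term already multiplied by `dW`. This is not a StochasticDiffEq.jl API; it only mirrors the Python snippet above. For diagonal noise `g_prod` is an elementwise product, and for the adjoint SDE it would be the single VJP discussed in this thread, so the `m × m` noise matrix never has to be formed.

```julia
# Hypothetical g_prod-based Euler-Heun step, following the torchsde snippet line by line.
# g_prod(y, p, t, dW) must return (noise term at y) * dW; the noise matrix is never formed.
function euler_heun_gprod(f, g_prod, y0, p, t0, dt, dW)
    fy = f(y0, p, t0)
    gp = g_prod(y0, p, t0, dW)                   # g(y0) * dW
    y_prime = y0 .+ gp                           # predictor uses only the noise term
    gp_prime = g_prod(y_prime, p, t0 + dt, dW)   # g(y_prime) * dW
    return y0 .+ dt .* fy .+ 0.5 .* (gp .+ gp_prime)
end

# Example (assumed for illustration): diagonal noise, dy = -p1*y dt + p2*y ∘ dW
f(y, p, t) = -p[1] .* y
g_diag(y, p, t) = p[2] .* y                      # diagonal entries of g
g_prod(y, p, t, dW) = g_diag(y, p, t) .* dW      # diagonal noise: elementwise product

y1 = euler_heun_gprod(f, g_prod, [1.0, 2.0], [0.5, 0.1], 0.0, 0.1, [0.1, -0.2])
```

The solver only ever asks for the product, which is the interface change the comment above says the Julia solvers would need.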

update: there's a trivial solution!!

  # how to multiply tmp2 with dW such that dgrad * dW  ==  tmp2 (*) dW?
  # how to multiply tmp1 with dW such that dλ * dW  ==  tmp1 (*) dW?

Don't compute dgrad at all: move the multiplication with dW into the VJP. That's it.
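In the notation of the two quoted questions, with $J$ the Jacobian of $f$ (with respect to $p$ for dgrad, with respect to $u$ for dλ), the reason this works is:

$$
\mathrm{dgrad}\,dW = \sum_{i=1}^{m} dW_i\,\lambda_i\,\nabla_p f_i
= J^\top \operatorname{diag}(\lambda)\,dW
= J^\top (\lambda \odot dW)
= \bigl(\operatorname{diag}(dW)\,J\bigr)^\top \lambda .
$$

The last form is the pullback of $u \mapsto f(u,p,t) \odot dW$ seeded with $\lambda$, so one VJP replaces the $m$ seeded pullbacks and the $\dim(p) \times m$ matrix is never formed.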

My comments about paramnoisemixing are not important; noisemixing has nothing to do with this, it just works. But the solver implementation hurdle is still relevant.

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity, LinearAlgebra

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end

Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()

dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)

dy, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
out = [back(x) for x in eachcol(Diagonal(λ))]

dλ1 = stack(first.(out))
dgrad1 = stack(last.(out))

dW = rand(m)

println("Computed with a vjp for each dimension: $(dλ1 * dW)")
println("Computed with a vjp for each dimension: $(dgrad1 * dW)")

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t) .* dW
end

out2 = back(λ)

resλ = first(out2)
resgrad = last(out2)

println("Computed with a single vjp: $resλ")
println("Computed with a single vjp: $resgrad")

println(dλ1 * dW ≈ resλ) # true
println(dgrad1 * dW ≈ resgrad) # true

This is still nontrivial to implement in Julia because of the solver design issue mentioned above: the solvers compute g itself rather than a g_prod.

linusheck, Sep 09 '23
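The same check can be run without Zygote or a GPU. In this sketch the toy f's Jacobians are written out by hand (an assumption standing in for Zygote's pullback), and `vjp` is a hypothetical helper returning `(Ju' * w, Jp' * w)`. The O(m²) route materializes `Ju' * Diagonal(λ)` before multiplying by `dW`; the O(m) route is a single VJP seeded with `λ .* dW`, which is what pulling `λ` back through `f(u, p, t) .* dW` computes.

```julia
using LinearAlgebra, Random

# Hand-written Jacobians of the toy drift f(u, p, t) = [p1*u1 - p2*u1*u2*t, -p3*u2 + t*p4*u1*u2]
Ju(u, p, t) = [(p[1] - p[2]*u[2]*t) (-p[2]*u[1]*t);
               (t*p[4]*u[2])        (-p[3] + t*p[4]*u[1])]
Jp(u, p, t) = [u[1] (-u[1]*u[2]*t) 0.0     0.0;
               0.0  0.0            (-u[2]) (t*u[1]*u[2])]

# Hypothetical VJP helper: (Ju' * w, Jp' * w), what a single Zygote pullback returns
vjp(u, p, t, w) = (Ju(u, p, t)' * w, Jp(u, p, t)' * w)

Random.seed!(434988934)
p = [1.5, 1.0, 3.0, 1.0]
y, λ, t = rand(2), rand(2), rand()
dW = rand(2)

# O(m^2) route: materialize the per-component matrices, then multiply by dW
dλ1 = Ju(y, p, t)' * Diagonal(λ)      # 2×2, column i = λ[i] * ∇_u f_i
dgrad1 = Jp(y, p, t)' * Diagonal(λ)   # 4×2, column i = λ[i] * ∇_p f_i

# O(m) route: one VJP with cotangent λ .* dW
resλ, resgrad = vjp(y, p, t, λ .* dW)
```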

This, of course, gives us quite a performance boost:

m = 10000
(...)
function f(u, p, t)
    [p[1] * x - p[2] * t + p[3] * p[4] * t * x * x for x in u]
end

(...)

# this scales in O(m^2)
# precompile, then execute
[back(x) for x in eachcol(Diagonal(λ))]
@time out = [back(x) for x in eachcol(Diagonal(λ))]

(...)

# this scales in O(m)
back(λ)
@time out2 = back(λ)

println(dλ1 * dW ≈ resλ) # true
println(dgrad1 * dW ≈ resgrad) # true

Output:

 29.061289 seconds (400.42 M allocations: 78.243 GiB, 21.32% gc time, 0.13% compilation time)
  0.002238 seconds (40.05 k allocations: 8.241 MiB)
true
true

linusheck, Sep 09 '23