DiffEqFlux.jl
DiffEqFlux.jl copied to clipboard
Higher-order derivatives of FFJORD fail on the GPU
I need to calculate the Laplacian of the (log-)densities modelled by a normalizing flow with respect to the inputs. On the CPU I can, for example, use the following code (which works, but seems to scale poorly with the number of samples):
# Vector field of the flow: a 2 -> 32 (tanh) -> 2 MLP whose parameters are
# converted to Float32 by `f32`.
nn = Chain(
    Dense(2, 32, tanh),
    Dense(32, 2),
) |> f32
tspan = (0.0f0, 1.0f0)  # integration interval, in Float32 to match the parameters
ffjord_mdl = FFJORD(nn, tspan, Tsit5())  # FFJORD flow; the ODE is solved with Tsit5
# Log-density `logpx` of the samples in `x` under the flow. The other two outputs
# (λ₁, λ₂) are discarded. NOTE(review): presumably regularization terms; confirm
# against the FFJORD docs. Only `x` is passed here, so FFJORD fills in its own
# parameters and noise (assumption; the defaults are not visible in this snippet).
function loss(x)
    logpx, λ₁, λ₂ = ffjord_mdl(x)
    return logpx
end
# Diagonal of the Hessian of sum(logpx) with respect to every entry of `x`.
# Zygote.diaghessian pushes ForwardDiff dual numbers through a Zygote gradient in
# chunks, so the number of gradient passes presumably grows with length(x).
# Assuming each column's logpx depends only on that column, summing one column of
# the result gives that sample's Laplacian of log p(x).
function lapl(x)
    return Zygote.diaghessian(x->sum(loss(x)), x)
end
# 10 two-dimensional points drawn from a standard normal, one per column (2×10),
# matching the 2-input first layer of `nn`.
data_dist = Normal(0.0f0, 1.0f0)
train_data = rand(data_dist, 2, 10)
lapl(train_data)
However, when I attempt to run this code on the GPU:
# Same vector-field network as in the CPU example, but moved to the GPU with `gpu`
# instead of `f32`. The stack trace shows the resulting parameters as CuArray{Float32}.
nn = Chain(
    Dense(2, 32, tanh),
    Dense(32, 2),
) |> gpu
tspan = (0.0f0, 1.0f0)  # integration interval, in Float32
ffjord_mdl = FFJORD(nn, tspan, Tsit5())  # FFJORD flow; the ODE is solved with Tsit5
# GPU variant of `loss`. The parameters and the noise `e` are passed explicitly;
# `e` is a Float32 array of the same size as `x`, moved to the GPU.
# NOTE(review): `e` is redrawn on every call. Zygote.diaghessian presumably
# evaluates the gradient once per forward-mode chunk, so different chunks would see
# different noise. Confirm this is acceptable for the estimate.
function loss(x)
    e = randn(Float32, size(x)) |> gpu
    logpx, λ₁, λ₂ = ffjord_mdl(x, ffjord_mdl.p, e)
    return logpx
end
# Same Hessian-diagonal computation as on the CPU. This is the call that fails:
# arrays of ForwardDiff.Dual reach the cuDNN-backed `tanh` broadcast of the first
# Dense layer, and there is no `cudnnDataType` method for Dual element types.
function lapl(x)
    return Zygote.diaghessian(x->sum(loss(x)), x)
end
# 10 standard-normal 2-D points (one per column), moved to the GPU.
data_dist = Normal(0.0f0, 1.0f0)
train_data = gpu(rand(data_dist, 2, 10))
lapl(train_data)
I get the following error:
ERROR: MethodError: no method matching cudnnDataType(::Type{ForwardDiff.Dual{Nothing, Float32, 12}})
Closest candidates are:
cudnnDataType(::Type{Float16}) at /data/packages/CUDA/YpW0k/lib/cudnn/util.jl:7
cudnnDataType(::Type{Float32}) at /data/packages/CUDA/YpW0k/lib/cudnn/util.jl:8
cudnnDataType(::Type{Float64}) at /data/packages/CUDA/YpW0k/lib/cudnn/util.jl:9
...
Stacktrace:
[1] CUDA.CUDNN.cudnnTensorDescriptor(array::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}; format::CUDA.CUDNN.cudnnTensorFormat_t, dims::Vector{Int32})
@ CUDA.CUDNN /data/packages/CUDA/YpW0k/lib/cudnn/tensor.jl:9
[2] CUDA.CUDNN.cudnnTensorDescriptor(array::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ CUDA.CUDNN /data/packages/CUDA/YpW0k/lib/cudnn/tensor.jl:8
[3] cudnnActivationForward!(y::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, x::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}; o::Base.Iterators.Pairs{Symbol, CUDA.CUDNN.cudnnActivationMode_t, Tuple{Symbol}, NamedTuple{(:mode,), Tuple{CUDA.CUDNN.cudnnActivationMode_t}}})
@ CUDA.CUDNN /data/packages/CUDA/YpW0k/lib/cudnn/activation.jl:22
[4] (::NNlibCUDA.var"#64#68")(src::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, dst::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ NNlibCUDA /data/packages/NNlibCUDA/gWBCU/src/cudnn/activations.jl:10
[5] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(tanh), Tuple{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}})
@ NNlibCUDA /data/packages/NNlibCUDA/gWBCU/src/cudnn/activations.jl:30
[6] adjoint
@ /data/packages/Zygote/AlLTp/src/lib/broadcast.jl:102 [inlined]
[7] _pullback(__context__::Zygote.Context, 641::typeof(Base.Broadcast.broadcasted), 642::typeof(tanh), x::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65
[8] _pullback
@ /data/packages/Flux/BPPNj/src/layers/basic.jl:158 [inlined]
[9] _pullback(ctx::Zygote.Context, f::Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[10] _pullback
@ /data/packages/Flux/BPPNj/src/layers/basic.jl:47 [inlined]
[11] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[12] _pullback
@ /data/packages/Flux/BPPNj/src/layers/basic.jl:49 [inlined]
[13] _pullback(ctx::Zygote.Context, f::Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[14] _pullback
@ /data/packages/Zygote/AlLTp/src/compiler/interface.jl:34 [inlined]
[15] pullback(f::Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface.jl:40
[16] ffjord(u::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, p::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, t::Float32, re::Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, e::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}; regularize::Bool, monte_carlo::Bool)
@ DiffEqFlux /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:186
[17] ffjord_
@ /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:204 [inlined]
[18] ODEFunction
@ /data/packages/SciMLBase/x3z0g/src/scimlfunctions.jl:334 [inlined]
[19] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Nothing, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, Float32, Float32, Float32, Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}, ODESolution{ForwardDiff.Dual{Nothing, Float32, 12}, 3, Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}}, ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, 
Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}, Vector{Float32}, Vector{Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, 
Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{ForwardDiff.Dual{Nothing, Float32, 12}, ForwardDiff.Dual{Nothing, Float32, 12}, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Bool, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ForwardDiff.Dual{Nothing, Float32, 12}, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32})
@ OrdinaryDiffEq /data/packages/OrdinaryDiffEq/JsAS0/src/perform_step/low_order_rk_perform_step.jl:569
[20] __init(prob::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, 
qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:save_noise,), Tuple{Bool}}})
@ OrdinaryDiffEq /data/packages/OrdinaryDiffEq/JsAS0/src/solve.jl:456
[21] #__solve#493
@ /data/packages/OrdinaryDiffEq/JsAS0/src/solve.jl:4 [inlined]
[22] #solve_call#42
@ /data/packages/DiffEqBase/b1nST/src/solve.jl:61 [inlined]
[23] solve_up(prob::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, p::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, args::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; kwargs::Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:save_noise, :save_start, :save_end), Tuple{Bool, Bool, Bool}}})
@ DiffEqBase /data/packages/DiffEqBase/b1nST/src/solve.jl:87
[24] #solve#43
@ /data/packages/DiffEqBase/b1nST/src/solve.jl:73 [inlined]
[25] _concrete_solve_adjoint(::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::InterpolatingAdjoint{0, true, Val{:central}, Bool, Bool}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; save_start::Bool, save_end::Bool, saveat::Vector{Float32}, save_idxs::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqSensitivity /data/packages/DiffEqSensitivity/uakCr/src/concrete_solve.jl:151
[26] _concrete_solve_adjoint
@ /data/packages/DiffEqSensitivity/uakCr/src/concrete_solve.jl:131 [inlined]
[27] #_solve_adjoint#61
@ /data/packages/DiffEqBase/b1nST/src/solve.jl:347 [inlined]
[28] _solve_adjoint
@ /data/packages/DiffEqBase/b1nST/src/solve.jl:322 [inlined]
[29] #rrule#59
@ /data/packages/DiffEqBase/b1nST/src/solve.jl:310 [inlined]
[30] rrule
@ /data/packages/DiffEqBase/b1nST/src/solve.jl:310 [inlined]
[31] rrule
@ /data/packages/ChainRulesCore/7ZiwT/src/rules.jl:134 [inlined]
[32] chain_rrule
@ /data/packages/Zygote/AlLTp/src/compiler/chainrules.jl:216 [inlined]
[33] macro expansion
@ /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0 [inlined]
[34] _pullback
@ /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:9 [inlined]
[35] _apply
@ ./boot.jl:804 [inlined]
[36] adjoint
@ /data/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
[37] _pullback
@ /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[38] _pullback
@ /data/packages/DiffEqBase/b1nST/src/solve.jl:73 [inlined]
[39] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#43", ::InterpolatingAdjoint{0, true, Val{:central}, Bool, Bool}, ::Nothing, ::Nothing, ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(solve), ::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[40] _apply(::Function, ::Vararg{Any, N} where N)
@ Core ./boot.jl:804
[41] adjoint
@ /data/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
[42] _pullback
@ /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[43] _pullback
@ /data/packages/DiffEqBase/b1nST/src/solve.jl:68 [inlined]
[44] _pullback(::Zygote.Context, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:sensealg,), Tuple{InterpolatingAdjoint{0, true, Val{:central}, Bool, Bool}}}, ::typeof(solve), ::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[45] _apply(::Function, ::Vararg{Any, N} where N)
@ Core ./boot.jl:804
[46] adjoint
@ /data/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
[47] _pullback
@ /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[48] _pullback
@ /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:218 [inlined]
[49] _pullback(::Zygote.Context, ::DiffEqFlux.var"##forward_ffjord#56", ::Bool, ::Bool, ::typeof(DiffEqFlux.forward_ffjord), ::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[50] _pullback
@ /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:202 [inlined]
[51] _pullback(::Zygote.Context, ::typeof(DiffEqFlux.forward_ffjord), ::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[52] _apply(::Function, ::Vararg{Any, N} where N)
@ Core ./boot.jl:804
[53] adjoint
@ /data/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
[54] _pullback
@ /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[55] _pullback
@ /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:198 [inlined]
--- the last 5 lines are repeated 1 more time ---
[61] _pullback(::Zygote.Context, ::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[62] _pullback
@ ./REPL[6]:3 [inlined]
[63] _pullback(ctx::Zygote.Context, f::typeof(loss), args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[64] _pullback
@ ./REPL[7]:2 [inlined]
[65] _pullback(ctx::Zygote.Context, f::var"#1#2", args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[66] _pullback(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface.jl:34
[67] pullback(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface.jl:40
[68] gradient(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface.jl:75
[69] (::Zygote.var"#105#108"{Int64, Val{1}, var"#1#2", Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})(x::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/lib/grad.jl:272
[70] forward_diag(f::Zygote.var"#105#108"{Int64, Val{1}, var"#1#2", Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, #unused#::Val{12})
@ Zygote /data/packages/Zygote/AlLTp/src/lib/forward.jl:65
[71] forward_diag(f::Function, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote /data/packages/Zygote/AlLTp/src/lib/forward.jl:80
[72] #104
@ /data/packages/Zygote/AlLTp/src/lib/grad.jl:272 [inlined]
[73] ntuple
@ ./ntuple.jl:19 [inlined]
[74] diaghessian
@ /data/packages/Zygote/AlLTp/src/lib/grad.jl:269 [inlined]
[75] lapl(x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Main ./REPL[7]:2
[76] top-level scope
@ ./timing.jl:210 [inlined]
[77] top-level scope
@ ./REPL[11]:0
These are my package versions:
(@v1.6) pkg> status
Status `/data/environments/v1.6/Project.toml`
[6e4b80f9] BenchmarkTools v1.2.0
[052768ef] CUDA v3.5.0
[aae7a2af] DiffEqFlux v1.44.0
[41bf760c] DiffEqSensitivity v6.60.3
[0c46a032] DifferentialEquations v6.20.0
[31c24e10] Distributions v0.25.29
[587475ba] Flux v0.12.8
[a75be94c] GalacticOptim v2.2.0
[429524aa] Optim v1.5.0
[1dea7af3] OrdinaryDiffEq v5.67.0
[91a5bcdd] Plots v1.23.6
[e88e6eb3] Zygote v0.6.30
[de0858da] Printf
[10745b16] Statistics
Any help would be appreciated!
@DhairyaLGandhi, is there any good way around this?
Just curious whether there has been any movement on this, or whether there is an alternative way to compute the Laplacian using other functions or AD packages.
Not sure if it's related, but if I simply call the loss from this FFJORD code, I get a "scalar indexing on a GPU array" error. It seems to point to the `[:, :, end]` slice of the solution in `forward_ffjord`.
Code:
# Vector-field network and FFJORD flow on the GPU (same setup as the previous
# GPU example).
nn = Chain(
    Dense(2, 32, tanh),
    Dense(32, 2),
) |> gpu
tspan = (0.0f0, 1.0f0)  # integration interval, in Float32
ffjord_mdl = FFJORD(nn, tspan, Tsit5())  # FFJORD flow; the ODE is solved with Tsit5
# Log-density of `x` with explicit parameters and GPU-resident noise `e`.
# Calling this alone already fails with "Scalar indexing is disallowed": per the
# trace, the ODE solution (a VectorOfArray of CuArrays) is sliced with an
# [:, :, end]-style index, which falls back to element-wise getindex on the GPU array.
function loss(x)
    e = randn(Float32, size(x)) |> gpu
    logpx, λ₁, λ₂ = ffjord_mdl(x, ffjord_mdl.p, e)
    return logpx
end
# Hessian-diagonal helper kept from the earlier example. It is not called in this
# reproducer, because the failure already occurs in `loss`.
function lapl(x)
    return Zygote.diaghessian(x -> sum(loss(x)), x)
end
# NOTE: `rand(data_dist, 2)` yields a length-2 vector (a single 2-D point), not a
# 2×N matrix as in the earlier examples. The trace nevertheless shows the ODE state
# as a 2-D CuArray.
data_dist = Normal(0.0f0, 1.0f0)
train_data = gpu(rand(data_dist, 2))
loss(train_data)
Error:
ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:33
[2] assertscalar(op::String)
@ GPUArrays C:\Users\domin\.julia\packages\GPUArrays\gkF6S\src\host\indexing.jl:53
[3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
@ GPUArrays C:\Users\domin\.julia\packages\GPUArrays\gkF6S\src\host\indexing.jl:86
[4] getindex
@ C:\Users\domin\.julia\packages\RecursiveArrayTools\gr5FR\src\vector_of_array.jl:164 [inlined]
[5] macro expansion
@ .\multidimensional.jl:860 [inlined]
[6] macro expansion
@ .\cartesian.jl:64 [inlined]
[7] macro expansion
@ .\multidimensional.jl:855 [inlined]
[8] _unsafe_getindex!
@ .\multidimensional.jl:868 [inlined]
[9] _unsafe_getindex(::IndexCartesian, ::VectorOfArray{Float32, 3, Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, ::Base.Slice{Base.OneTo{Int64}}, ::Base.Slice{Base.OneTo{Int64}}, ::Int64)
@ Base .\multidimensional.jl:846
[10] _getindex
@ .\multidimensional.jl:832 [inlined]
[11] getindex
@ .\abstractarray.jl:1170 [inlined]
[12] getindex(::ODESolution{Float32, 3, Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, ODEProblem{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#63"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, DiffEqFlux.var"#ffjord_#63"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, 
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Vector{Float32}, Vector{Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, ::Colon, ::Colon, ::Int64)
@ SciMLBase C:\Users\domin\.julia\packages\SciMLBase\jj8Ix\src\solutions\solution_interface.jl:33
[13] forward_ffjord(n::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, e::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; regularize::Bool, monte_carlo::Bool)
@ DiffEqFlux C:\Users\domin\.julia\packages\DiffEqFlux\w4Zm0\src\ffjord.jl:219
[14] forward_ffjord(n::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, e::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ DiffEqFlux C:\Users\domin\.julia\packages\DiffEqFlux\w4Zm0\src\ffjord.jl:203
[15] (::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}})(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, N} where N; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqFlux C:\Users\domin\.julia\packages\DiffEqFlux\w4Zm0\src\ffjord.jl:199
[16] (::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}})(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, N} where N)
@ DiffEqFlux C:\Users\domin\.julia\packages\DiffEqFlux\w4Zm0\src\ffjord.jl:199
[17] loss(x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Main c:\Users\domin\Dropbox\code_practice\julia\diffeqflux\ffjord_test.jl:38
[18] top-level scope
@ c:\Users\domin\Dropbox\code_practice\julia\diffeqflux\ffjord_test.jl:48
in expression starting at c:\Users\domin\Dropbox\code_practice\julia\diffeqflux\ffjord_test.jl:48
Environment status:
[052768ef] CUDA v3.6.4
[aae7a2af] DiffEqFlux v1.44.1
[0c46a032] DifferentialEquations v7.1.0
[31c24e10] Distributions v0.25.38
https://github.com/SciML/DiffEqFlux.jl/pull/614 will probably fix this once it is finished.
A small update: https://github.com/FluxML/NNlibCUDA.jl/pull/48 fixes the original bug reported in this issue. However, another bug remains in the `diaghessian` call, and it now looks Zygote-related. The scalar indexing in the forward call of the loss also remains.