ChainRules.jl
Dict gradients leading to addition error after broadcasting change
As discussed on Slack, there is an issue with dictionaries appearing in gradients. The following is as minimal an example as I could make.
This requires Molly master, Zygote master and ChainRules 1.44.2; I am using Julia 1.7.2. The file `ala5.pdb`, pasted below, should be placed in the current directory.
# Minimal reproducer for the Dict-gradient failure: build a Molly System for the
# five-residue alanine peptide in ala5.pdb and differentiate an RMSD-based loss with
# respect to a Dict of parameters.
using Molly, Zygote
# Force-field XML files shipped in Molly's `data` directory.
data_dir = joinpath(dirname(pathof(Molly)), "..", "data")
# units=false: plain unitless numbers (the stacktraces show Float64 fields with NoDims units).
ff = OpenMMForceField(
joinpath(data_dir, "force_fields", "ff99SBildn.xml"),
joinpath(data_dir, "force_fields", "his.xml");
units=false,
)
# 500 x 500 x 500 cubic box (length unit not stated here; presumably nm — confirm).
# gpu_diff_safe=true presumably selects Molly's differentiable code path — confirm in Molly docs.
# implicit_solvent="gbn2" adds a GBN2 implicit-solvent general interaction
# (ImplicitSolventGBN2 in the second stacktrace).
sys = System(
"ala5.pdb",
ff;
boundary=CubicBoundary(500.0, 500.0, 500.0),
units=false,
gpu_diff_safe=true,
implicit_solvent="gbn2",
)
# Reference coordinates for the RMSD loss, copied so they are independent of sys.coords.
starting_coords = copy(sys.coords)
# Langevin integrator: time step 0.001, temperature 300.0, friction 1.0 (unitless, as above).
sim = Langevin(dt=0.001, temperature=300.0, friction=1.0)
# Parameters to differentiate with respect to (presumably the 1-4 LJ and Coulomb
# scaling weights, going by the key names — confirm). With both keys present the
# gradient fails; removing either key makes it succeed.
params_dic = Dict(
"inter_LJ_weight_14" => 0.5,
"inter_CO_weight_14" => 0.5,
)
# Loss for the reproducer: rebuild a System from the parameters in params_dic, run
# 20 steps of `sim`, and return the RMSD between the final and starting coordinates.
# Reads the globals `sys`, `sim` and `starting_coords` defined above.
function loss(params_dic)
# Atoms plus pairwise, specific and general interactions built from sys and params_dic.
# The failing broadcasts are inside this call (Molly gradients.jl:98 and :108 in the
# stacktraces).
atoms, pis, sis, gis = inject_gradients(sys, params_dic)
sys2 = System(
atoms=atoms,
pairwise_inters=pis,
specific_inter_lists=sis,
general_inters=gis,
# Copy so that simulate! (which mutates sys2) cannot overwrite the reference coordinates.
coords=copy(starting_coords),
boundary=sys.boundary,
neighbor_finder=sys.neighbor_finder,
force_units=NoUnits,
energy_units=NoUnits,
gpu_diff_safe=true,
)
simulate!(sys2, sim, 20)
return rmsd(sys2.coords, starting_coords)
end
# Forward evaluation; the reported errors come from the gradient call below.
loss(params_dic)
# Gradient w.r.t. the Dict. On ChainRules 1.44.2 this throws
# `MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})`.
gradient(loss, params_dic)
The ala5.pdb file:
REMARK 1 CREATED WITH OPENMM 7.7, 2022-03-17
ATOM 1 N ALA A 1 -0.677 -1.230 -0.491 1.00 0.00 N
ATOM 2 H1 ALA A 1 -1.672 -1.326 0.175 1.00 0.00 H
ATOM 3 H2 ALA A 1 -0.205 -2.312 -0.284 1.00 0.00 H
ATOM 4 H3 ALA A 1 -1.142 -1.396 -1.586 1.00 0.00 H
ATOM 5 CA ALA A 1 -0.001 0.064 -0.491 1.00 0.00 C
ATOM 6 HA ALA A 1 -0.307 0.761 -1.410 1.00 0.00 H
ATOM 7 C ALA A 1 1.499 -0.110 -0.491 1.00 0.00 C
ATOM 8 O ALA A 1 2.233 0.524 -1.257 1.00 0.00 O
ATOM 9 CB ALA A 1 -0.509 0.856 0.727 1.00 0.00 C
ATOM 10 HB1 ALA A 1 -1.630 1.260 0.586 1.00 0.00 H
ATOM 11 HB2 ALA A 1 0.147 1.855 0.821 1.00 0.00 H
ATOM 12 HB3 ALA A 1 -0.513 0.440 1.850 1.00 0.00 H
ATOM 13 N ALA A 2 2.031 -0.947 0.335 1.00 0.00 N
ATOM 14 H ALA A 2 1.491 -1.234 1.355 1.00 0.00 H
ATOM 15 CA ALA A 2 3.481 -1.115 0.335 1.00 0.00 C
ATOM 16 HA ALA A 2 3.979 -0.110 0.741 1.00 0.00 H
ATOM 17 C ALA A 2 3.979 -1.516 -1.034 1.00 0.00 C
ATOM 18 O ALA A 2 4.951 -0.967 -1.565 1.00 0.00 O
ATOM 19 CB ALA A 2 3.832 -2.145 1.422 1.00 0.00 C
ATOM 20 HB1 ALA A 2 3.242 -2.174 2.466 1.00 0.00 H
ATOM 21 HB2 ALA A 2 3.903 -3.307 1.138 1.00 0.00 H
ATOM 22 HB3 ALA A 2 4.951 -1.909 1.785 1.00 0.00 H
ATOM 23 N ALA A 3 3.371 -2.461 -1.667 1.00 0.00 N
ATOM 24 H ALA A 3 2.703 -3.280 -1.122 1.00 0.00 H
ATOM 25 CA ALA A 3 3.852 -2.848 -2.990 1.00 0.00 C
ATOM 26 HA ALA A 3 4.957 -3.295 -2.907 1.00 0.00 H
ATOM 27 C ALA A 3 3.863 -1.666 -3.929 1.00 0.00 C
ATOM 28 O ALA A 3 4.836 -1.407 -4.647 1.00 0.00 O
ATOM 29 CB ALA A 3 2.962 -3.999 -3.492 1.00 0.00 C
ATOM 30 HB1 ALA A 3 3.402 -4.309 -4.563 1.00 0.00 H
ATOM 31 HB2 ALA A 3 1.783 -3.954 -3.696 1.00 0.00 H
ATOM 32 HB3 ALA A 3 3.081 -5.011 -2.859 1.00 0.00 H
ATOM 33 N ALA A 4 2.825 -0.902 -3.984 1.00 0.00 N
ATOM 34 H ALA A 4 1.758 -1.411 -3.866 1.00 0.00 H
ATOM 35 CA ALA A 4 2.836 0.242 -4.892 1.00 0.00 C
ATOM 36 HA ALA A 4 2.957 -0.134 -6.020 1.00 0.00 H
ATOM 37 C ALA A 4 4.002 1.154 -4.597 1.00 0.00 C
ATOM 38 O ALA A 4 4.737 1.590 -5.492 1.00 0.00 O
ATOM 39 CB ALA A 4 1.477 0.951 -4.765 1.00 0.00 C
ATOM 40 HB1 ALA A 4 1.191 1.837 -4.012 1.00 0.00 H
ATOM 41 HB2 ALA A 4 1.384 1.543 -5.807 1.00 0.00 H
ATOM 42 HB3 ALA A 4 0.476 0.293 -4.808 1.00 0.00 H
ATOM 43 N ALA A 5 4.239 1.491 -3.374 1.00 0.00 N
ATOM 44 H ALA A 5 3.314 2.039 -2.863 1.00 0.00 H
ATOM 45 CA ALA A 5 5.366 2.373 -3.090 1.00 0.00 C
ATOM 46 HA ALA A 5 5.290 3.438 -3.625 1.00 0.00 H
ATOM 47 C ALA A 5 6.657 1.787 -3.611 1.00 0.00 C
ATOM 48 O ALA A 5 6.692 0.703 -4.205 1.00 0.00 O
ATOM 49 CB ALA A 5 5.395 2.620 -1.571 1.00 0.00 C
ATOM 50 HB1 ALA A 5 4.470 3.196 -1.074 1.00 0.00 H
ATOM 51 HB2 ALA A 5 6.276 3.405 -1.348 1.00 0.00 H
ATOM 52 HB3 ALA A 5 5.721 1.705 -0.874 1.00 0.00 H
ATOM 53 OXT ALA A 5 7.648 2.530 -3.396 1.00 0.00 O
TER 54 ALA A 5
END
With ChainRules up to 1.42.0 this works. On 1.43.0–1.44.1 it fails with a different error (fixed by https://github.com/JuliaDiff/ChainRules.jl/pull/661), and on 1.44.2 it fails as follows:
ERROR: LoadError: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at ~/soft/julia/julia-1.7.2/share/julia/base/operators.jl:655
+(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
+(::Dict, ::ChainRulesCore.Tangent{P}) where P at ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:145
...
Stacktrace:
[1] add_sum(x::Dict{Any, Any}, y::Dict{Any, Any})
@ Base ./reduce.jl:24
[2] _mapreduce
@ ./reduce.jl:410 [inlined]
[3] _mapreduce_dim
@ ./reducedim.jl:330 [inlined]
[4] #mapreduce#725
@ ./reducedim.jl:322 [inlined]
[5] mapreduce
@ ./reducedim.jl:322 [inlined]
[6] #_sum#735
@ ./reducedim.jl:894 [inlined]
[7] _sum
@ ./reducedim.jl:894 [inlined]
[8] #_sum#734
@ ./reducedim.jl:893 [inlined]
[9] _sum
@ ./reducedim.jl:893 [inlined]
[10] #sum#732
@ ./reducedim.jl:889 [inlined]
[11] sum
@ ./reducedim.jl:889 [inlined]
[12] unbroadcast
@ ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:348 [inlined]
[13] map(f::typeof(ChainRules.unbroadcast), t::Tuple{Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{Dict{String, Float64}}}, s::Tuple{Vector{ChainRulesCore.Tangent}, Vector{Dict{Any, Any}}})
@ Base ./tuple.jl:247
[14] (::ChainRules.var"#back_generic#1708"{typeof(Molly.inject_interaction), Tuple{Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}, Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}}, Tuple{Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{Dict{String, Float64}}}})(dys::ChainRulesCore.Tangent{Any, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}})
@ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:134
[15] ZBack
@ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:206 [inlined]
[16] (::Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.ZBack{ChainRules.var"#back_generic#1708"{typeof(Molly.inject_interaction), Tuple{Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}, Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}}, Tuple{Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{Dict{String, Float64}}}}}})(Δ::Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}})
@ Zygote ~/.julia/dev/Zygote/src/lib/lib.jl:206
[17] (::Zygote.var"#1914#back#210"{Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.ZBack{ChainRules.var"#back_generic#1708"{typeof(Molly.inject_interaction), Tuple{Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}, Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}}, Tuple{Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{Dict{String, Float64}}}}}}})(Δ::Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[18] Pullback
@ ./broadcast.jl:1303 [inlined]
[19] (::typeof(∂(broadcasted)))(Δ::Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[20] Pullback
@ ~/.julia/dev/Molly/src/gradients.jl:98 [inlined]
[21] (::typeof(∂(inject_gradients)))(Δ::Tuple{Vector{Atom{Float64, Float64, Float64, Float64}}, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{NamedTuple{(:is, :js, :types, :inters), Tuple{Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}}, Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[22] Pullback
@ ~/.julia/dev/Molly/src/gradients.jl:92 [inlined]
[23] (::typeof(∂(inject_gradients)))(Δ::Tuple{Vector{Atom{Float64, Float64, Float64, Float64}}, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{NamedTuple{(:is, :js, :types, :inters), Tuple{Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}}, Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[24] Pullback
@ ~/dms/molly_dev/grad_err.jl:31 [inlined]
[25] (::typeof(∂(loss)))(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[26] (::Zygote.var"#60#61"{typeof(∂(loss))})(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
[27] gradient(f::Function, args::Dict{String, Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
[28] top-level scope
@ ~/dms/molly_dev/grad_err.jl:50
Adding `ChainRulesCore.@opt_out rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), ::Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N}` to Molly, as suggested by @mcabbott, gives a different error:
ERROR: LoadError: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/dev/Zygote/src/lib/lib.jl:327
[3] (::Zygote.var"#1948#back#224"{Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false}})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ./broadcast.jl:170 [inlined]
[5] (::typeof(∂(Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}})))(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[6] Pullback
@ ./broadcast.jl:179 [inlined]
[7] (::typeof(∂(Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}})))(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[8] Pullback
@ ./broadcast.jl:179 [inlined]
[9] (::typeof(∂(Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}})))(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[10] Pullback
@ ./broadcast.jl:1305 [inlined]
[11] (::Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, typeof(∂(broadcasted))})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/dev/Zygote/src/lib/lib.jl:206
[12] (::Zygote.var"#1914#back#210"{Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, typeof(∂(broadcasted))}})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[13] Pullback
@ ./broadcast.jl:1303 [inlined]
[14] (::typeof(∂(broadcasted)))(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/dev/Molly/src/gradients.jl:108 [inlined]
[16] (::typeof(∂(inject_gradients)))(Δ::Tuple{Vector{Atom{Float64, Float64, Float64, Float64}}, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{NamedTuple{(:is, :js, :types, :inters), Tuple{Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}}, Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[17] Pullback
@ ~/.julia/dev/Molly/src/gradients.jl:92 [inlined]
[18] (::typeof(∂(inject_gradients)))(Δ::Tuple{Vector{Atom{Float64, Float64, Float64, Float64}}, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{NamedTuple{(:is, :js, :types, :inters), Tuple{Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}}, Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[19] Pullback
@ ~/dms/molly_dev/grad_err.jl:31 [inlined]
[20] (::typeof(∂(loss)))(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[21] (::Zygote.var"#60#61"{typeof(∂(loss))})(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
[22] gradient(f::Function, args::Dict{String, Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
[23] top-level scope
@ ~/dms/molly_dev/grad_err.jl:50
Commenting out either the `"inter_LJ_weight_14" => 0.5,` line or the `"inter_CO_weight_14" => 0.5,` line makes it work. Presumably this is because, with only one gradient entry, two dictionaries never have to be added together.
I can reproduce the errors shown. Note that Molly only seems to work for me on Julia 1.7, not on Julia 1.8 (but that may be an Apple M1 problem).
Without the opt-out, `@debug` prints some messages. The shortcut in `rrule_via_ad` does not seem to be involved in calling these rules:
(jl_GvojyS) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GvojyS/Project.toml`
[082447d4] ChainRules v1.44.2
[aa0f7f06] Molly v0.13.0
[e88e6eb3] Zygote v0.6.43
julia> ENV["JULIA_DEBUG"] = ChainRules;
# The "rrule shortcut" @info below is never printed, i.e. this shortcut is not in fact taken:
julia> @eval Zygote function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...)
# Debugging override evaluated inside the Zygote module. It follows the structure of
# Zygote's own rrule_via_ad, but logs "rrule shortcut" whenever an existing rrule is
# returned instead of tracing through Zygote.
# NOTE(review): assumed to otherwise match the installed Zygote (v0.6.43) — verify.
# first check whether there is an `rrule` which handles this directly
direct = rrule(config, f_args...; kwargs...)
# The function being differentiated; only used in the log message below.
f = f_args[1]
direct === nothing || (@info "rrule shortcut" f; return direct)
# create a closure to work around _pullback not accepting kwargs
# but avoid creating a closure unnecessarily (pullbacks of closures do not infer)
y, pb = if !isempty(kwargs)
kwf() = first(f_args)(Base.tail(f_args)...; kwargs...)
_y, _pb = _pullback(config.context, kwf)
# kwf takes no arguments, so its pullback yields only the closure's own tangent; that
# tangent's f_args field holds the gradients of the captured positional arguments.
_y, Δ -> first(_pb(Δ)).f_args # `first` should be `only`
else
_pullback(config.context, f_args...)
end
# Presumably (going by the helper names): convert the incoming ChainRules tangent for
# Zygote, run the pullback, and convert the result back into ChainRules differentials.
ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
return y, ad_pullback
end;
# Note: the T == Bool path is taken many times, so it deliberately has no @info.
julia> @eval Zygote @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
# Debugging override of Zygote's adjoint for array-style broadcasting. For each
# broadcast function f, the two @info calls report which path is taken: the
# broadcast_forward path (dual numbers, judging by the _dual_* checks) or the
# per-element pullback path.
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
# Bool results get no gradient; not logged because this path is hit very often.
return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
@info "Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward" f
return broadcast_forward(f, args...)
end
# Number of gradient slots per element. Slot 1 is f's own gradient (see dxs[1] below),
# so this is presumably length(args) + 1.
len = inclen(args)
@info "Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks" f
# Each element of y∂b is an (output, pullback) pair.
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(first, y∂b)
function ∇broadcasted(ȳ)
# Apply each element's pullback to its own cotangent.
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
# Regroup by slot: dxs[i] gathers the i-th gradient from every element.
dxs = ntuple(len) do i
collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
# Gradients for (style, f, args...):
#   - nothing for the broadcast style;
#   - an accumulated gradient for f;
#   - each argument's per-element gradients passed through unbroadcast.
# Compare the `unbroadcast` -> `sum` frames in the ChainRules stacktrace, where adding
# two Dict gradients is the step that fails.
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
return y, ∇broadcasted
end
# Random easy test
julia> gradient(xs -> sum((x -> sin(x)).(xs)), [1,2,3]/4)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward
└ f = #6 (generic function with 1 method)
([0.9689124217106447, 0.8775825618903728, 0.7316888688738209],)
# MWE from above
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Debug: split broadcasting generic
│ f = inject_interaction_list (generic function with 4 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
ERROR: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})
(Same error as above.)
With the opt-out, it is the second broadcast above, the one with inject_interaction, that fails:
julia> ChainRulesCore.@opt_out rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), ::Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N}
julia> gradient(xs -> sum((x -> sin(x)).(xs)), [1,2,3]/4)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward
└ f = #18 (generic function with 1 method)
([0.9689124217106447, 0.8775825618903728, 0.7316888688738209],)
julia> gradient((xs, y) -> sum((x -> sin(x/y)).(xs)), [1,2,3]/4, 5/6)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = #22 (generic function with 1 method)
([1.146403786950727, 0.9904027378916139, 0.7459319619247974], -1.609501544552504)
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
ERROR: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/lib/lib.jl:326
So is this plausibly a failure of the opt-out mechanism? Above, the rrule is used; here, Zygote instead seems not to find the @adjoint rule at all.
Now I see. The ChainRules (CR) rule accepts any BroadcastStyle, so that it handles tuples too, while the Zygote rule is restricted to AbstractArrayStyle. The cases where the CR rule is called all have Broadcast.Style{Tuple}():
julia> @eval ChainRules function rrule(cfg::RCR, ::typeof(broadcasted), style::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N}
@debug "called the rrule!" f style
T = Broadcast.combine_eltypes(f, args)
if T === Bool # TODO use nondifftype here
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
@debug("split broadcasting trivial", f, T)
bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...)
return f.(args...), bc_trivial_back
elseif T <: Number && may_bc_derivatives(T, f, args...)
# 2: Fast path: use arguments & result to find derivatives.
return split_bc_derivatives(f, args...)
elseif T <: Number && may_bc_forwards(cfg, f, args...)
# 3: Future path: use `frule_via_ad`?
return split_bc_forwards(cfg, f, args...)
else
# 4: Slow path: collect all the pullbacks & apply them later.
return split_bc_pullbacks(cfg, f, args...)
end
end
rrule (generic function with 1065 methods)
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
┌ Debug: called the rrule!
│ f = inject_interaction (generic function with 7 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Debug: called the rrule!
│ f = inject_interaction_list (generic function with 4 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction_list (generic function with 4 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Debug: called the rrule!
│ f = inject_interaction (generic function with 7 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
ERROR: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})
This is not solved by changing the rrule's signature to match the @adjoint rule and reject Broadcast.Style{Tuple}. With https://github.com/JuliaDiff/ChainRules.jl/commit/6e383c10f8dfc0731be64a51bc3a33a6c7d21f5b you get the same error as with the @opt_out:
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
ERROR: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
This is a Zygote bug. I wish I could transfer this issue there.