Zygote.jl
Implicit params: no gradient for `Array` element of `Vector{AbstractArray}` when parent `Vector` used in AD
I'm seeing incorrect gradients when I use a vector of matrices as a model in Flux. The implicit gradients come back as `nothing` (effectively zero), while the structural gradients appear to be computed correctly.
MWE:
using Flux
struct VecofMat{T}
    W::T
end
Flux.@functor VecofMat
function (model::VecofMat)(x)
    mapreduce(m -> m * x, +, model.W)
end
function implicit_grads(model, x)
    gs = gradient(() -> sum(model(x)), params(model))
    return gs.grads
end
function structural_grads(model, x)
    gs = gradient(model -> sum(model(x)), model)
    return gs[1]
end
x = rand(2)
model = VecofMat([rand(2,2) for i = 1:3])
@show implicit_grads(model, x)
@show structural_grads(model, x)
Outputs:
julia> implicit_grads(model, x)
IdDict{Any,Any} with 3 entries:
[0.363913 0.0393756; 0.827627 0.923598] => nothing
[0.410637 0.17915; 0.224523 0.402422] => nothing
[0.129872 0.958427; 0.827143 0.950434] => nothing
julia> structural_grads(model, x)
(W = [[0.006743089207722486 0.5560052133967388; 0.006743089207722486 0.5560052133967388], [0.006743089207722486 0.5560052133967388; 0.006743089207722486 0.5560052133967388], [0.006743089207722486 0.5560052133967388; 0.006743089207722486 0.5560052133967388]],)
I'm about 80% sure this is the same issue as https://discourse.julialang.org/t/flux-jl-inconsistent-training-on-custom-architecture/63326/10?u=touchesir (Zygote not being able to track elements of a `Vector{AbstractArray}`). I'm re-posting the MWE I wrote there and moving this to Zygote so it gets on the radar of the AD people.
julia> using Zygote, LinearAlgebra
julia> X = [rand(2, 2)]
1-element Vector{Matrix{Float64}}:
[0.6202008269430048 0.8555679159356662; 0.8423362289177463 0.09771425479926421]
# 1. this doesn't work
julia> gradient(() -> norm(X[1]), Params(X)).grads
IdDict{Any, Any} with 2 entries:
[0.620201 0.855568; 0.842336 0.0977143] => nothing # gradient wrt. X[1]
:(Main.X) => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]
# 2. but this does
julia> gradient(() -> norm(X[1]), Params([X])).grads
IdDict{Any, Any} with 2 entries:
:(Main.X) => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]
[[0.620201 0.855568; 0.842336 0.0977143]] => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]
# 3. as does this
julia> x₁ = X[1]
2×2 Matrix{Float64}:
0.620201 0.855568
0.842336 0.0977143
julia> gradient(() -> norm(x₁), Params(X)).grads
IdDict{Any, Any} with 2 entries:
[0.620201 0.855568; 0.842336 0.0977143] => [0.45775 0.631467; 0.621701 0.0721198]
:(Main.x₁) => [0.45775 0.631467; 0.621701 0.0721198]
# 4. and this (note the explicit instead of implicit parameters:
# we pass X directly and use it instead of params(X)).
# This works with full Flux models too!
julia> gradient(x -> norm(x[1]), X)[1]
1-element Vector{Union{Nothing, Matrix{Float64}}}:
[0.4577503206426886 0.6314672132598408; 0.6217013298363201 0.07211975463850043]
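As a sanity check on the numbers above, the gradient of the Frobenius norm of a nonzero matrix `A` is `A / norm(A)`, so the expected value can be computed by hand and compared with what the explicit-parameter call in ex. 4 returns. The sketch below reuses the matrix printed in the transcript; the names `A`, `expected` and `g` are only for illustration.

```julia
using Zygote, LinearAlgebra

# The matrix printed in the transcript above, wrapped in a Vector as in the MWE.
A = [0.6202008269430048 0.8555679159356662;
     0.8423362289177463 0.09771425479926421]
X = [A]

# Closed form: d/dA norm(A) = A / norm(A) for the Frobenius norm.
expected = A ./ norm(A)

# Explicit-parameter gradient, same call as ex. 4.
g = gradient(x -> norm(x[1]), X)[1]

# The AD result should agree with the closed form.
@assert g[1] ≈ expected
```

The top-left entry of `expected` is roughly 0.45775, which agrees with the values shown in ex. 2 through 4, so those three results are the correct gradient and only ex. 1 is missing it.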
You need to wrap the parameter in question in a vector. In this case the correct representation would be a vector of vectors of matrices.
That's what I showed with ex. 2 above. The question is whether Zygote should be able to figure out that `getindex` on the parent vector returns the child array in the params collection. I would say this should be consistent with other indexable datatypes, cf.:
julia> X = (rand(2, 2),);
julia> gradient(() -> norm(X[1]), Params(X)).grads
IdDict{Any, Any} with 1 entry:
[0.411909 0.505953; 0.154458 0.240884] => [0.578184 0.710191; 0.216809 0.338122]
julia> X = (; foo=rand(2, 2));
julia> gradient(() -> norm(X[1]), Params(X)).grads
IdDict{Any, Any} with 1 entry:
[0.738327 0.943921; 0.906215 0.620752] => [0.454179 0.580649; 0.557455 0.381853]
julia> gradient(() -> norm(X.foo), Params(X)).grads
IdDict{Any, Any} with 1 entry:
[0.738327 0.943921; 0.906215 0.620752] => [0.454179 0.580649; 0.557455 0.381853]
Note that the Params are the same in all 3 cases (vector, tuple and named tuple). However, Zygote can figure out a gradient for the latter two and not for the vector. If we say this is a WONTFIX, we should provide a clear explanation as to why this inconsistency exists. Otherwise I think it's a bug.
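Until the behaviour is settled, one possible workaround is to take the explicit gradient with respect to the parent vector and then key each child array to its gradient by object identity, which gives a lookup similar in spirit to `gs.grads`. This is only a sketch: the name `child_grads` is hypothetical and not part of Zygote, and it assumes the explicit gradient has one entry per child, as in ex. 4.

```julia
using Zygote, LinearAlgebra

X = [rand(2, 2), rand(2, 2)]

# Explicit gradient with respect to the parent vector (same pattern as ex. 4).
g = gradient(xs -> norm(xs[1]), X)[1]

# Hypothetical helper table: pair every child array with its gradient by identity.
# Children the loss never touches may come back as `nothing`.
child_grads = IdDict{Any,Any}()
for (x, dx) in zip(X, g)
    child_grads[x] = dx
end

# Look up the gradient for the first child, keyed by the array object itself.
@assert child_grads[X[1]] ≈ X[1] ./ norm(X[1])
```

Keying by identity rather than by value matters here because two children could hold equal numbers while still being distinct parameters.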