Zygote.jl

Implicit params: no gradient for `Array` element of `Vector{AbstractArray}` when parent `Vector` used in AD

Open arlk opened this issue 3 years ago • 3 comments

I’m seeing incorrectly computed gradients when I use a vector of matrices as a model in Flux. The implicit gradients come out as `nothing` (i.e. zero), but the structural gradients seem to be computed correctly.

MWE:

using Flux

struct VecofMat{T}
    W::T
end

Flux.@functor VecofMat

function (model::VecofMat)(x)
    mapreduce(m->m*x, +, model.W)
end

function implicit_grads(model, x)
    gs = gradient(()->sum(model(x)), params(model))
    return gs.grads
end

function structural_grads(model, x)
    gs = gradient((model)->sum(model(x)), model)
    return gs[1]
end

x = rand(2)
model = VecofMat([rand(2,2) for i = 1:3])
@show implicit_grads(model, x)
@show structural_grads(model, x)

Outputs:

julia> implicit_grads(model, x)
IdDict{Any,Any} with 3 entries:
  [0.363913 0.0393756; 0.827627 0.923598] => nothing
  [0.410637 0.17915; 0.224523 0.402422]   => nothing
  [0.129872 0.958427; 0.827143 0.950434]  => nothing

julia> structural_grads(model, x)
(W = [[0.006743089207722486 0.5560052133967388; 0.006743089207722486 0.5560052133967388], [0.006743089207722486 0.5560052133967388; 0.006743089207722486 0.5560052133967388], [0.006743089207722486 0.5560052133967388; 0.006743089207722486 0.5560052133967388]],)

arlk, Mar 12 '21

I'm about 80% sure this is the same issue as https://discourse.julialang.org/t/flux-jl-inconsistent-training-on-custom-architecture/63326/10?u=touchesir (Zygote not being able to track elements of a Vector{AbstractArray}). Re-posting the MWE I wrote there and moving this to Zygote so it gets on the radar of AD people.

julia> using Zygote, LinearAlgebra

julia> X = [rand(2, 2)]
1-element Vector{Matrix{Float64}}:
 [0.6202008269430048 0.8555679159356662; 0.8423362289177463 0.09771425479926421]

# 1. this doesn't work
julia> gradient(() -> norm(X[1]), Params(X)).grads
IdDict{Any, Any} with 2 entries:
  [0.620201 0.855568; 0.842336 0.0977143] => nothing  # gradient wrt. X[1]
  :(Main.X)                               => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]

# 2. but this does
julia> gradient(() -> norm(X[1]), Params([X])).grads
IdDict{Any, Any} with 2 entries:
  :(Main.X)                                 => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]
  [[0.620201 0.855568; 0.842336 0.0977143]] => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]

# 3. as does this
julia> x₁ = X[1]
2×2 Matrix{Float64}:
 0.620201  0.855568
 0.842336  0.0977143

julia> gradient(() -> norm(x₁), Params(X)).grads
IdDict{Any, Any} with 2 entries:
  [0.620201 0.855568; 0.842336 0.0977143] => [0.45775 0.631467; 0.621701 0.0721198]
  :(Main.x₁)                              => [0.45775 0.631467; 0.621701 0.0721198]

# 4. and so does this (note the explicit instead of implicit parameters:
# we pass X directly and use it instead of Params(X)).
# This works with full Flux models too!
julia> gradient(x -> norm(x[1]), X)[1]
1-element Vector{Union{Nothing, Matrix{Float64}}}:
 [0.4577503206426886 0.6314672132598408; 0.6217013298363201 0.07211975463850043]

ToucheSir, Jul 03 '21

You need to wrap the parameter in question in a vector. In this case, the correct representation would be a vector of vectors of matrices.

DhairyaLGandhi, Jul 03 '21
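A minimal sketch of the vector-wrapping workaround suggested above, applied to the original `VecofMat` MWE. It assumes Flux and Zygote are both installed in the active environment and relies on the behaviour shown in example 2 of the thread: a `Params` collection that holds the parent `Vector` (rather than its matrix elements) does receive a gradient. The names `ps`, `implicit_W` and `structural_W` are illustrative only.

```julia
using Flux, Zygote

# The model from the original MWE: a wrapper around a Vector of matrices.
struct VecofMat{T}
    W::T
end
Flux.@functor VecofMat
(model::VecofMat)(x) = mapreduce(m -> m * x, +, model.W)

x = rand(2)
model = VecofMat([rand(2, 2) for i = 1:3])

# Workaround (assumption based on example 2): register the parent Vector
# itself as the implicit parameter, instead of the individual matrices
# that params(model) collects.
ps = Zygote.Params([model.W])
gs = gradient(() -> sum(model(x)), ps)

# Gradient w.r.t. model.W: a Vector holding one 2×2 matrix per element.
implicit_W = gs[model.W]

# Cross-check against the structural (explicit) gradient.
structural_W = gradient(m -> sum(m(x)), model)[1].W
@show all(implicit_W .≈ structural_W)
```

Since ∂ sum(m * x) / ∂m = ones(2) * x', every gradient matrix should have both rows equal to `x'`, which is what the structural output in the original report shows.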

That's what I showed with ex. 2 above. The question is whether Zygote should be able to figure out that `getindex` on the parent vector returns the child array in the params collection. I would say this should be consistent with other indexable datatypes, cf.:

julia> X = (rand(2, 2),);

julia> gradient(() -> norm(X[1]), Params(X)).grads
IdDict{Any, Any} with 1 entry:
  [0.411909 0.505953; 0.154458 0.240884] => [0.578184 0.710191; 0.216809 0.338122]

julia> X = (; foo=rand(2, 2));

julia> gradient(() -> norm(X[1]), Params(X)).grads
IdDict{Any, Any} with 1 entry:
  [0.738327 0.943921; 0.906215 0.620752] => [0.454179 0.580649; 0.557455 0.381853]

julia> gradient(() -> norm(X.foo), Params(X)).grads
IdDict{Any, Any} with 1 entry:
  [0.738327 0.943921; 0.906215 0.620752] => [0.454179 0.580649; 0.557455 0.381853]

Note that the Params are the same in all 3 cases (vector, tuple and named tuple). However, Zygote can figure out a gradient for the latter two but not for the vector. If we say this is a WONTFIX, we should provide a clear explanation of why this inconsistency exists. Otherwise, I think it's a bug.

ToucheSir, Jul 03 '21