
Tied weights using Flux layers

dfenn opened this issue 3 years ago · 9 comments

I'm trying to build an autoencoder that uses both conv and dense layers, and I'd like to tie the weights. #488 demonstrates how to do this for dense layers by not using the Flux Dense type and instead using the encoder's weights directly.

Is there a way to accomplish something similar while still using Flux-defined layer types, such as Conv? I've tried manually setting the decoder parameters in the loss function; something like this:


using Flux  # the snippet assumes Flux is loaded

mutable struct AE_tied
    encoder
    decoder

    weights_encoder
    weights_decoder
end

AE_tied(encoder, decoder) = AE_tied(encoder, decoder, params(encoder), params(decoder))

function (a::AE_tied)(x)
    x = a.encoder(x)
    # copy the encoder's weights into the decoder before the forward pass
    a.weights_decoder[1] .= a.weights_encoder[1]
    a.decoder(x)
end

encoder = Conv((3,3), 1=>2, relu, pad=SamePad())
decoder = ConvTranspose((3,3), 2=>1, relu, pad=SamePad())

model = AE_tied(encoder, decoder)
model = cpu(model)

ps = Flux.params(model.encoder)
opt = ADAM(0.1)

function loss(x) 
    y = model(x)
    sum((y .- x) .^2) / length(x)
end

train_data = cpu(rand(5, 5, 1, 2))

for epoch in 1:1
    local trainLoss
    gs = Flux.gradient(ps) do
        trainLoss = loss(train_data)
        return trainLoss
    end
    Flux.Optimise.update!(opt, ps, gs)
    @show trainLoss
end

Running this gives ERROR: LoadError: Mutating arrays is not supported. It's the line a.weights_decoder[1] .= a.weights_encoder[1] that's the issue.
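The restriction isn't specific to this model: any in-place write inside the function being differentiated triggers the same error. A minimal reproduction, for illustration:

```julia
using Zygote

# Mutating `x` in place inside the differentiated function
# triggers Zygote's "Mutating arrays is not supported" error.
f(x) = (x .= 2 .* x; sum(x))

Zygote.gradient(f, [1.0, 2.0])  # ERROR: Mutating arrays is not supported
```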

Am I going about this the wrong way, or is what I'm trying to do not supported? Thanks in advance for any help.

dfenn avatar May 07 '21 17:05 dfenn

Indeed mutating isn't supported by Zygote, which is used to calculate the gradients. It is supported in some other Julia AD packages which you might be able to use.

However, I don't believe the above snippet actually does tie the weights properly. E.g. the gradients of tied weights should be the same, but this won't be the case if you tie them by manually tweaking them to be equal.

With this in mind, my preferred solution would be to initialise the weights of the decoder to be a @view on the weights of the encoder.

I haven't actually checked to see whether this plays nicely with Flux, but maybe it's something to try.
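One way to sketch that idea (an untested illustration, not from the thread): ConvTranspose has a lower-level constructor that accepts a weight array directly, and for Conv((3,3), 1=>2) versus ConvTranspose((3,3), 2=>1) the weight arrays happen to have the same (3, 3, 1, 2) shape, so the encoder's array can simply be reused instead of viewed.

```julia
using Flux

encoder = Conv((3, 3), 1 => 2, relu, pad=SamePad())

# Pass the *same* array (not a copy) as the decoder's weight;
# the bias has length equal to the decoder's output channels (1).
decoder = ConvTranspose(encoder.weight, zeros(Float32, 1), relu, pad=SamePad())

encoder.weight === decoder.weight  # one array, genuinely tied
```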

atiyo avatar May 09 '21 21:05 atiyo

Thanks for your response. I was able to get it working using @views for the convolutional layers. However, the same approach isn't working for dense layers, where the weight matrix must be transposed:

encoder = Dense(5, 2)
@views decoder = Dense(transpose(encoder.weight), rand(5))

This gives the error

ERROR: LoadError: TypeError: in typeassert, expected Tuple{Transpose{Float32, Matrix{Float32}}, Transpose{Float32, Matrix{Float32}}, Vector{Float64}}, got a value of type Tuple{Matrix{Float32}, Matrix{Float32}, Vector{Float64}}
Stacktrace:
 [1] apply!(o::ADAM, x::Transpose{Float32, Matrix{Float32}}, Δ::Matrix{Float64})
   @ Flux.Optimise ~/.julia/packages/Flux/6BByF/src/optimise/optimisers.jl:175
 [2] update!(opt::ADAM, x::Transpose{Float32, Matrix{Float32}}, x̄::Matrix{Float64})
   @ Flux.Optimise ~/.julia/packages/Flux/6BByF/src/optimise/train.jl:23
 [3] update!(opt::ADAM, xs::Params, gs::Zygote.Grads)
   @ Flux.Optimise ~/.julia/packages/Flux/6BByF/src/optimise/train.jl:29

It looks like the optimiser state is typed to match the Transpose parameter, so it complains when the gradient arrives as a plain Matrix. I've tried using PermutedDimsArray instead, with similar results.

It's not clear to me how to address this. Any ideas?
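One conceivable workaround (an editor's sketch with a hypothetical TiedDecoder type, not something proposed in this thread): keep the parent Matrix as the only stored parameter and apply the transpose inside the forward pass, so the optimiser state is keyed on a plain Matrix rather than a Transpose wrapper.

```julia
using Flux

# Hypothetical layer: stores the encoder's weight matrix itself and
# transposes it on the fly, avoiding Transpose-typed parameters.
struct TiedDecoder{M<:AbstractMatrix, B}
    W::M   # the very same array as the encoder's weight
    b::B
end
Flux.@functor TiedDecoder

(d::TiedDecoder)(x) = transpose(d.W) * x .+ d.b

encoder = Dense(5, 2)
decoder = TiedDecoder(encoder.weight, zeros(Float32, 5))
```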

dfenn avatar May 11 '21 05:05 dfenn

We should probably change that line in the Adam code to use Adapt.jl to get the correct type instead of hard-typing the return of get!.

darsnack avatar May 11 '21 13:05 darsnack

Probably better to incorporate directly in optimisers.jl

As long as we pass in the correct references we should be good. I don't think it needs to be addressed in the optimisers otherwise.

DhairyaLGandhi avatar May 11 '21 13:05 DhairyaLGandhi

I don't think we need the fix in Optimisers.jl because the state is initialized separately (and correctly). This appears to only be a bug for IdDict optimizers.

Agreed that we only need the references to be correct.

darsnack avatar May 11 '21 13:05 darsnack

possibly related to FluxML/Zygote.jl#991 and #1613

We should probably change that line in the Adam code to use Adapt.jl to get the correct type instead of hard-typing the return of get!.

Even if we use something like #1613 to adapt the types, that still wouldn't be entirely correct, because we would be taking two steps of Adam with separate gradients instead of a single step with the accumulated gradient.

CarloLucibello avatar Jun 10 '21 14:06 CarloLucibello

taking 2 steps of adam with separate gradients instead of a single step with the accumulated one

Yeah, with ADAM this will certainly be wrong. Referencing https://github.com/FluxML/Zygote.jl/issues/991#issuecomment-864411649, it's not taking two steps per se that's wrong: it's that the momentum terms will be incorrect, which makes two steps inequivalent to a single accumulated one. For simpler optimizers like Descent, this will be correct (assuming the gradients are correct, which they are for explicit params).

darsnack avatar Jun 23 '21 00:06 darsnack

Hello,

I wanted to follow up on this issue. Has it been resolved in the latest version of Flux.jl?

mleprovost avatar Oct 31 '23 19:10 mleprovost

On latest Flux, using new-style training with setup, something like dec = Dense(transpose(encoder.weight)) should just work. It will see through the transpose and notice that the same array appears twice.

(With old-style IdDict optimisers, I'm not sure.)
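A sketch of that new-style approach, assuming a recent Flux with explicit-state training (untested here):

```julia
using Flux

enc = Dense(5 => 2, relu)
dec = Dense(transpose(enc.weight))   # same array, seen through the transpose
model = Chain(enc, dec)

# setup walks the model and detects that the weight array is shared,
# so one optimiser state is created for the tied parameter.
opt_state = Flux.setup(Adam(0.01), model)

x = rand(Float32, 5, 16)
grads = Flux.gradient(m -> Flux.mse(m(x), x), model)
Flux.update!(opt_state, model, grads[1])
```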

mcabbott avatar Oct 31 '23 19:10 mcabbott