Avalon.jl

How to freeze parameters

Open lbotecur opened this issue 3 years ago • 3 comments

Hello,

I would like to know how to freeze parameters in a model, that is, how to train only a subset of the parameters.

Thank you.

lbotecur avatar Mar 17 '21 08:03 lbotecur

The update!() function takes an optional argument ignore, a set of field paths that should not be updated. A field path is a tuple of symbols representing the path to a specific parameter. For example, if your model looks like this:

mutable struct Foo
    x
    y
end

mutable struct Bar
    foo::Foo
    z
end

m = Bar(Foo(1, 2), 3)

And you want to ignore Foo's x and Bar's z, use it like this:

ignore = Set([
    (:foo, :x),
    (:z,)
])

update!(m, gm, ignore=ignore)
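
To illustrate how a set of full field paths can drive freezing, here is a minimal self-contained sketch. It does not use Avalon: toy_update! is a hypothetical helper, the gradient values in gm are made up, and Avalon's actual update!() may work differently. It only shows the idea of comparing each leaf's full path against the ignore set.

```julia
mutable struct Foo
    x
    y
end

mutable struct Bar
    foo::Foo
    z
end

# Hypothetical helper, not part of Avalon: a plain gradient step that walks
# the model's fields and skips every leaf whose full path is in `ignore`.
function toy_update!(m, gm, lr; ignore=Set(), prefix=())
    for name in fieldnames(typeof(m))
        path = (prefix..., name)
        val = getfield(m, name)
        if val isa Number
            # Only full paths to leaves are compared against `ignore`
            path in ignore && continue
            setfield!(m, name, val - lr * getfield(gm, name))
        else
            # Nested struct: recurse, extending the path by this field name
            toy_update!(val, getfield(gm, name), lr; ignore=ignore, prefix=path)
        end
    end
    return m
end

m = Bar(Foo(1.0, 2.0), 3.0)
gm = (foo=(x=0.1, y=0.2), z=0.3)   # made-up gradients with the same layout as m
ignore = Set([(:foo, :x), (:z,)])

toy_update!(m, gm, 1.0; ignore=ignore)
```

In this sketch the call leaves m.foo.x and m.z untouched and changes only m.foo.y, because paths are matched only at the leaves.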

Note that this is a low-level and unstable API. I'm currently working on such small things, including this very specific task - freezing the parameters - but I have several use cases and no specific design yet. I'd be grateful if you could describe your use case so that I can make the API more convenient.

dfdx avatar Mar 17 '21 21:03 dfdx

Thank you for the answer. This solution is great. My use case is exactly the one you have shown: use a pretrained model (Foo) as part of a new model (Bar) and train the new model with Foo's parameters frozen. After that, fine-tune the whole model with Foo's parameters unfrozen.

I don't know if it is possible to pass Yota only the parameters for which gradients should be calculated, similar to how JAX does it.

Thanks.

lbotecur avatar Mar 18 '21 14:03 lbotecur

Great, pretraining is a very important use case for Avalon, so we will definitely have a more concise syntax for freezing parameters, but the exact API will arrive later, perhaps shortly after the high-level training API.

Please note that the ignore list expects full field paths, so using just (:foo,) in the example above won't have any effect. To recursively collect the list of field paths, you can use the following:

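function collect_fields(obj)
    # Paths are built as vectors of symbols and converted to tuples at the end
    paths = []
    for p in propertynames(obj)
        # Recurse into the field; leaves such as plain numbers have no properties
        subpaths = collect_fields(getproperty(obj, p))
        if !isempty(subpaths)
            # Prefix every nested path with the current field name
            for subpath in subpaths
                path = [p; subpath...]
                push!(paths, path)
            end
        else
            # Leaf field: the path is just its own name
            push!(paths, [p])
        end
    end
    return [(path...,) for path in paths]
end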
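
As a usage sketch, the collected paths can be filtered to freeze everything under one sub-model. The snippet below repeats collect_fields in a compact form so that it runs on its own. The Foo/Bar structs are the ones from the earlier example, and filtering on the first path element is just one possible way to pick the paths, not an Avalon API.

```julia
mutable struct Foo
    x
    y
end

mutable struct Bar
    foo::Foo
    z
end

# Compact restatement of collect_fields, included only to keep this snippet standalone
function collect_fields(obj)
    paths = Tuple[]
    for p in propertynames(obj)
        subpaths = collect_fields(getproperty(obj, p))
        if isempty(subpaths)
            push!(paths, (p,))                                  # leaf field
        else
            append!(paths, [(p, sub...) for sub in subpaths])   # nested fields
        end
    end
    return paths
end

m = Bar(Foo(1, 2), 3)
all_paths = collect_fields(m)

# Keep only the paths that start with :foo, i.e. every parameter inside m.foo
ignore = Set(filter(path -> first(path) == :foo, all_paths))
```

The resulting set holds the full paths (:foo, :x) and (:foo, :y), which is the form the ignore argument expects.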

> I don't know if it is possible to pass Yota only the parameters for which gradients should be calculated, similar to how JAX does it.

I'm not sure I understood you correctly, but if you are looking for semantics like:

f(x) = ...
gf = grad(f)
gf(x)

Unfortunately, it's not possible out of the box because, without concrete arguments, Yota doesn't really know which method of f() to trace. Yet it should be possible to make a simple wrapper, something like:

grad_fn(f) = (args...) -> grad(f, args...)

Is that what you were asking about?
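
To show the wrapper pattern without depending on a specific Yota version, here is a self-contained sketch. toy_grad is a stand-in written for this example (central finite differences over scalar arguments). It is not Yota's grad, and its return format is an assumption made for the illustration. The point is only that the wrapper must accept varargs and forward them.

```julia
# Stand-in for a `grad(f, args...)`-style function; NOT Yota's implementation.
# Returns the value of f and a tuple of central finite-difference derivatives,
# one per scalar argument.
function toy_grad(f, args...; h=1e-6)
    val = f(args...)
    derivs = ntuple(length(args)) do i
        up = ntuple(j -> j == i ? args[j] + h : args[j], length(args))
        down = ntuple(j -> j == i ? args[j] - h : args[j], length(args))
        (f(up...) - f(down...)) / (2h)
    end
    return val, derivs
end

# The wrapper: takes a function now, receives the concrete arguments later
grad_fn(f) = (args...) -> toy_grad(f, args...)

loss(x, y) = x^2 + 3y
gloss = grad_fn(loss)
val, g = gloss(2.0, 1.0)   # g approximates (d loss/dx, d loss/dy) at (2.0, 1.0)
```

With Yota, toy_grad would be replaced by grad. The shape of the returned gradients depends on the Yota version, so check it before indexing into the result.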

dfdx avatar Mar 18 '21 22:03 dfdx