Create mechanism for easy model composition
For now, we'll only consider how this should work in the model creation and execution API, but it will touch the training API as well.
Consider the models in a basic GAN:
generator =
  Axon.input({nil, 100})
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})
discriminator =
  Axon.input({nil, 1, 28, 28})
  |> Axon.flatten() # flatten image input before the dense layers
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
In order to train, what you'd want to do is something like:
combined = compose(discriminator, generator) # represents D(G(input))

step_d = Axon.Training.step(discriminator, :binary_cross_entropy, Axon.Optimizers.sgd(0.005))
step_g = Axon.Training.step(combined, :binary_cross_entropy, Axon.Optimizers.adam(0.01))
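You'd then alternate step_d and step_g on valid / fake images. Roughly, the loop could look like this (a loose sketch: the step-state shapes, the ones/zeros label batches, and the g_params/1 helper are assumptions for illustration, not the current Axon.Training API):

# d_state / g_state hold each model's parameters and optimizer state.
{d_state, g_state} =
  Enum.reduce(batches, {d_state, g_state}, fn {real, latent}, {d_state, g_state} ->
    fake = Axon.predict(generator, g_params(g_state), latent)

    # Train D: real images labeled 1, generated images labeled 0.
    d_state = step_d.(d_state, {real, ones})
    d_state = step_d.(d_state, {fake, zeros})

    # Train G through D(G(x)): label fakes 1 so G learns to fool D.
    g_state = step_g.(g_state, {latent, ones})

    {d_state, g_state}
  end)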
Unfortunately, we don't currently support model composition in this sense. You can define generator and discriminator as functions without an input block, but then there's no clean way to determine which parameters belong to which model. Ideally, you'd be able to compose models so that parameters stay grouped by model when you initialize, predict, train, etc.:
combined = compose(discriminator, generator)

{d_params, g_params} = combined_params = Axon.init(combined)

Axon.predict(combined, combined_params, inputs)

{{d_params, g_params}, _} =
  combined
  |> Axon.Training.step(:binary_cross_entropy, Axon.Optimizers.adam(0.01))
  |> Axon.Training.train(inputs, targets)
Whatever the implementation is, it will involve adding some metadata to parameters that expresses their ownership by a given model. From an API perspective, one option is to introduce Axon.compose for composing Axon structs into a single model while preserving parameter information, although I'm not sure I love that right now.
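For a sense of what that grouping could look like, here's a hypothetical shape (Axon.compose/2 and the nested parameter map keyed by owning model are assumptions, not existing API):

combined = Axon.compose(discriminator, generator)

# init could return parameters keyed by the model that owns them, e.g.:
%{
  "generator" => g_params,
  "discriminator" => d_params
} = Axon.init(combined)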
I've been experimenting a bit, and after starting #81 I believe I have a solution to this issue: introduce Axon.function. The idea is that Axon.function takes a block of layers with inputs and returns an anonymous function whose arity matches the number of inputs in the block. So the GAN would look like:
generator =
  Axon.input({nil, 100})
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})
  |> Axon.function()
discriminator =
  Axon.input({nil, 1, 28, 28})
  |> Axon.flatten() # flatten image input before the dense layers
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
  |> Axon.function()

joint = discriminator.(generator.(Axon.input({nil, 100})))
And generator and discriminator are still separate objects. The biggest question then becomes how execution and compilation should behave when they encounter an Axon.function.
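One possible answer, purely as a sketch (none of these internals exist): Axon.function could capture the built graph once, and every call would splice those same nodes, with their original parameter names, onto whatever input it's applied to:

# Hypothetical internals; replace_input/2 is a stand-in for whatever
# graph-rewriting helper this would actually need.
def function(%Axon{} = model) do
  fn %Axon{} = input ->
    # Graft the caller's graph in place of the model's input node,
    # reusing the original (already named) parameters on every call.
    replace_input(model, input)
  end
end

Compilation would then see the same parameter names no matter how many times the returned function is applied.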
Also note the reason we can't just do:
generator = fn x ->
  x
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})
  |> Axon.function()
end

discriminator = fn x ->
  x
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
  |> Axon.function()
end
g = generator.(Axon.input({nil, 100}))
d = discriminator.(Axon.input({nil, 784}))
joint = discriminator.(generator.(Axon.input({nil, 100})))
is because of how Axon's compiler works. Each call to generator or discriminator above yields a brand-new model with fresh, uniquely named parameters, rather than the same model on each call, which is what Axon.function would give you.
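Concretely (the layer names below are representative of the compiler's unique-naming scheme, not actual output):

g1 = generator.(Axon.input({nil, 100}))
g2 = generator.(Axon.input({nil, 100}))

# g1's layers get names like "dense_0", "dense_1", "dense_2", while g2's
# get "dense_3", "dense_4", "dense_5", so initializing g1 and g2 produces
# two entirely disjoint parameter sets instead of one shared set.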
This is possible with blocks now.
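For reference, a sketch of the block-based approach (check the current Axon.block docs for exact signatures and options; the input names and shapes below are assumptions):

generator =
  Axon.block(fn x ->
    x
    |> Axon.dense(128, activation: :tanh)
    |> Axon.dense(784, activation: :tanh)
  end)

# Every call to the block reuses the same underlying parameters:
g1 = generator.(Axon.input("noise", shape: {nil, 100}))
g2 = generator.(Axon.input("noise", shape: {nil, 100}))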