
Add a guide for writing a simple recurrent network

Open. marcinkoziej opened this issue 11 months ago • 7 comments

Hi! Axon beginner here.

I am struggling to figure out how to write a very simple RNN. Basically, I want to rewrite this PyTorch example from a tutorial in Axon.

However, the Axon API makes it a bit convoluted to create networks that "scan" or "unroll" the input. After some digging I realized that I need to create something similar to lstm_cell and lstm, but these APIs are not well documented (what do the arguments of dynamic_unroll mean?). I am also not sure how to handle the parameters in that case so that the training mechanism (Axon.Loop.trainer with a standard optimizer and loss function) can do its job.

marcinkoziej commented on Jul 10 '23 16:07

Looking further into the Axon code, I noticed that Axon.Layers.lstm and friends are defined using an arcane macro. Is something like this always necessary to create a recurrent network with Axon, or is it just a means of avoiding duplicated code?

marcinkoziej commented on Jul 10 '23 17:07

Does this example help? https://github.com/elixir-nx/axon/blob/main/examples/generative/text_generator.exs

polvalente commented on Jul 10 '23 22:07

@polvalente thanks for the quick reply! I have seen this example. The problem is that it uses lstm as a black box that encapsulates and abstracts away the iteration over the input data. I want to write a simple RNN myself, in which I specify what the "cell" function looks like, so I tried to use dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) to build something similar. However, I cannot figure out how to use it properly: it has little documentation, neither its arguments nor its return values are described, and even its argument list seems somewhat arbitrary (why carry, input_kernel, recurrent_kernel, and bias rather than just parameters? Or, on the other hand, why a single bias instead of input_bias and recurrent_bias?).

For now, I think I should perhaps avoid Axon's *_unroll functions and just write a while loop (Nx.Defn.Kernel.while) myself in a custom layer?

In general, I am not sure whether Axon intends to provide a reusable building block (like Axon.Loop) for scanning/unrolling inputs, or whether these functions are tailor-made for the three recurrent networks implemented in the axon package. If the former, it would be great to see a guide on how to use that abstraction to implement a custom RNN model (not LSTM, not GRU, etc.).

marcinkoziej commented on Jul 11 '23 07:07

I can't speak much to the intention behind the design, but since these functions aren't marked with an explicit @doc false, I'd expect them to be part of the public interface.

Some things I concluded (pending any corrections by @seanmor5) that might help you:

As I understand it, dynamic_unroll and static_unroll are ways to apply cell_fn over the "sequence" axis of your input (axis 1, the first axis after the batch dimension), carrying the results forward to the next entry along that axis: effectively the equivalent of Enum.scan over those entries. The only difference is that dynamic_ and static_ refer to whether this scan is built into your computational graph fully unrolled (static) or as a while loop (dynamic).
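To make the scan analogy concrete, here is a plain-Elixir sketch (no Nx or Axon involved; the module UnrollSketch and its unroll/3 function are hypothetical names) of what an unroll computes for a single sequence: call the cell at each step, thread the carry into the next step, and collect the per-step outputs. The explicit recursion roughly mirrors the while-loop shape of a dynamic unroll; a static unroll would compute the same values with every step laid out in the graph. The {output, new_carry} tuple order here is just this sketch's convention and is not necessarily the order Axon's own cells use.

```elixir
defmodule UnrollSketch do
  # Hypothetical plain-Elixir stand-in for an unroll over one sequence.
  # cell_fn.(step_input, carry) must return {output, new_carry}.
  def unroll(cell_fn, sequence, initial_carry) do
    loop(cell_fn, sequence, initial_carry, [])
  end

  # No steps left: return the outputs in order plus the final carry.
  defp loop(_cell_fn, [], carry, acc), do: {Enum.reverse(acc), carry}

  # One step: apply the cell, then continue with the updated carry.
  defp loop(cell_fn, [x | rest], carry, acc) do
    {output, new_carry} = cell_fn.(x, carry)
    loop(cell_fn, rest, new_carry, [output | acc])
  end
end

# Toy cell: the carry is a running sum, and each step outputs that sum.
running_sum = fn x, h -> {h + x, h + x} end

{outputs, final_carry} = UnrollSketch.unroll(running_sum, [1, 2, 3, 4], 0)
# outputs == [1, 3, 6, 10], final_carry == 10

# When the output is just the carry, the outputs are exactly Enum.scan/2:
scanned = Enum.scan([1, 2, 3, 4], &+/2)
# scanned == [1, 3, 6, 10]
```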

So if your input has shape {batch, sequence, m, n, ...}, cell_fn is a function that takes a {batch, m, n, ...} tensor (one sequence step for the whole batch) together with a carry (whatever state you need to pass on to the next entry in the sequence), and returns a batch of outputs and a batch of carries.
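The same idea with a batch axis, again as a plain-Elixir sketch that uses nested lists in place of tensors (an assumption made so the example runs without Nx; Enum.zip_with/3 needs Elixir 1.12 or later): the input has shape {batch, sequence}, each step the cell sees is a {batch} slice taken along axis 1, and the carry holds one state per batch entry.

```elixir
# Two sequences of length 3: shape {batch = 2, sequence = 3}.
batch = [[1, 2, 3], [4, 5, 6]]

# Slice along the sequence axis (axis 1): one {batch}-shaped slice per step.
steps = batch |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
# steps == [[1, 4], [2, 5], [3, 6]]

# The cell works on a whole batch at once: a batch of inputs and a batch of
# carries in, {batch of outputs, batch of new carries} out.
batched_running_sum = fn xs, hs ->
  new_hs = Enum.zip_with(xs, hs, &+/2)
  {new_hs, new_hs}
end

initial_carry = [0, 0]
{outputs_per_step, final_carry} = Enum.map_reduce(steps, initial_carry, batched_running_sum)
# outputs_per_step == [[1, 4], [3, 9], [6, 15]], final_carry == [6, 15]
```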

You also need to receive input_kernel, recurrent_kernel, and bias, which are the trainable parameters and whose shapes depend on your actual cell_fn definition. For instance, in the code for lstm_cell below, input_kernel and recurrent_kernel (called hidden_kernel in that code) are weights that, together with bias, form linear transformations computed with the same dense function that backs an Axon.dense layer. Your cell_fn could just as well receive them as empty maps or constant values if you didn't want to apply any transformation at all to your input and carry values.

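    # Excerpt from the body of Axon.Layers.lstm_cell (a defn); dense/3 is
    # Axon.Layers.dense, and hidden_kernel plays the role of recurrent_kernel.
    {cell, hidden} = carry
    {wii, wif, wig, wio} = input_kernel
    {whi, whf, whg, who} = hidden_kernel

    {bi, bf, bg, bo} = bias

    # Input (i), forget (f), candidate (g), and output (o) gates; the recurrent
    # dense calls use a bias of 0, so only one bias per gate is applied.
    i = gate_fn.(dense(input, wii, bi) + dense(hidden, whi, 0))
    f = gate_fn.(dense(input, wif, bf) + dense(hidden, whf, 0))
    g = activation_fn.(dense(input, wig, bg) + dense(hidden, whg, 0))
    o = gate_fn.(dense(input, wio, bo) + dense(hidden, who, 0))

    # New cell state and new hidden state; the cell then returns the new
    # carry {new_c, new_h} along with new_h as this step's output.
    new_c = f * cell + i * g
    new_h = o * activation_fn.(new_c)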
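For the simple RNN asked about in this issue, the cell is smaller than lstm_cell: one input kernel, one recurrent kernel, and one bias, with h_new = tanh(x · input_kernel + h · recurrent_kernel + bias). The sketch below computes that in plain Elixir over lists (SimpleRNNSketch and its functions are hypothetical names; in Axon the cell would be written over Nx tensors, with the kernels and bias registered as trainable parameters, for example via Axon.param). Kernels use the {in_features, out_features} layout, so input_kernel is {input_size, hidden_size} and recurrent_kernel is {hidden_size, hidden_size}. Because the two pre-activations are added, separate input and recurrent biases (PyTorch's nn.RNN keeps both, b_ih and b_hh) would simply sum, so a single bias loses no expressiveness.

```elixir
defmodule SimpleRNNSketch do
  # x . w + b for a single example; w is a list of rows, {in_features, out_features}.
  def dense(x, w, b) do
    x
    |> Enum.zip(w)
    |> Enum.map(fn {xi, row} -> Enum.map(row, &(&1 * xi)) end)
    |> Enum.reduce(b, &add/2)
  end

  # Elman-style cell: returns {output, new_carry}; both are the new hidden state.
  def cell(x, h, input_kernel, recurrent_kernel, bias) do
    no_bias = List.duplicate(0.0, length(h))
    pre = add(dense(x, input_kernel, bias), dense(h, recurrent_kernel, no_bias))
    new_h = Enum.map(pre, &:math.tanh/1)
    {new_h, new_h}
  end

  # Unroll the cell over a sequence of input vectors, threading the hidden state.
  def run(sequence, h0, input_kernel, recurrent_kernel, bias) do
    Enum.map_reduce(sequence, h0, fn x, h ->
      cell(x, h, input_kernel, recurrent_kernel, bias)
    end)
  end

  # Element-wise vector addition.
  defp add(u, v), do: Enum.zip_with(u, v, &+/2)
end

# input_size = 1, hidden_size = 1, with hand-picked (not trained) weights.
input_kernel = [[0.5]]
recurrent_kernel = [[1.0]]
bias = [0.0]

{hidden_states, final_h} =
  SimpleRNNSketch.run([[1.0], [0.0]], [0.0], input_kernel, recurrent_kernel, bias)
# hidden_states is [[tanh(0.5)], [tanh(tanh(0.5))]]: the second step has zero
# input, so its state comes entirely from the recurrent connection.
```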

The output and carry play the same roles as the {mapped_value, acc} tuple returned by the function you pass to Enum.map_reduce (Enum.scan keeps only the accumulated values).
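A small plain-Elixir illustration of that correspondence, using a toy cell (hypothetical, chosen only for this example) whose carry differs from its per-step output, much as lstm_cell carries {cell, hidden} while emitting only the hidden state. The {output, carry} order is Enum.map_reduce's convention; don't assume it matches the tuple order Axon's cell functions return.

```elixir
# Toy cell: the carry is {count, sum}; the per-step output is the running mean.
running_mean = fn x, {count, sum} ->
  new_count = count + 1
  new_sum = sum + x
  # {per-step output, new carry}, the shape Enum.map_reduce expects
  {new_sum / new_count, {new_count, new_sum}}
end

{outputs, final_carry} = Enum.map_reduce([2, 4, 6], {0, 0}, running_mean)
# outputs == [2.0, 3.0, 4.0], final_carry == {3, 12}

# Enum.scan/3 only keeps the accumulator (here, the carry) at each step.
carries = Enum.scan([2, 4, 6], {0, 0}, fn x, {count, sum} -> {count + 1, sum + x} end)
# carries == [{1, 2}, {2, 6}, {3, 12}]
```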

polvalente commented on Jul 13 '23 01:07

Thanks for the clarifications.

I am trying to move forward without the *_unroll APIs, but now I have stumbled on another problem with RNNs and sequential data: input sequences of uneven length.

I am working on an example Livebook in which I port a PyTorch example to Axon. In the PyTorch example it was possible to work on variable-length data (last names with different numbers of letters), but Axon seems to prefer fixed-length input (the LSTM example uses a fixed sequence_size). I learned that I should pad the data with 0 to make all inputs the same length, but there should be some way to tell Axon to ignore this padding (something like the masking and padding described here for TensorFlow). Does Axon have a concept like this?
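Padding itself is plain list manipulation done before the data becomes a tensor. A minimal sketch, where pad_batch is a hypothetical helper and reserving 0 as the pad token is an assumption (whatever token you pick must not also encode a real letter):

```elixir
# Hypothetical helper: right-pad every sequence with pad_token up to the length
# of the longest one, so the batch can become a {batch, sequence} tensor.
pad_batch = fn sequences, pad_token ->
  max_len = sequences |> Enum.map(&length/1) |> Enum.max()
  Enum.map(sequences, fn seq -> seq ++ List.duplicate(pad_token, max_len - length(seq)) end)
end

# Last names encoded as letter indices starting at 1, with 0 reserved for padding.
names = [[3, 1, 2], [5, 4], [7, 8, 9, 6]]
padded = pad_batch.(names, 0)
# padded == [[3, 1, 2, 0], [5, 4, 0, 0], [7, 8, 9, 6]]
```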

marcinkoziej commented on Jul 19 '23 13:07

There is a new API, Axon.mask, that does this; you can pass the resulting mask to Axon.lstm and other RNN layers. Something like this should work:

    input = Axon.input("seq")
    # pad token is 0
    mask = Axon.mask(input, 0)
    embed = Axon.embedding(input, ...)
    {seq, state} = Axon.lstm(embed, 32, mask: mask)
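To see why the mask matters, here is a plain-Elixir sketch of one common masking strategy for an unrolled RNN, similar in spirit to the TensorFlow masking mentioned in the question: at padded positions the cell is skipped and the carry passes through unchanged, so a padded sequence ends in the same state as the unpadded one. MaskedUnrollSketch and masked_unroll/4 are hypothetical names, and whether Axon.lstm treats masked steps exactly this way internally is an assumption, not something stated in this thread.

```elixir
defmodule MaskedUnrollSketch do
  # Hypothetical sketch: unroll cell_fn over a padded sequence, treating
  # positions equal to pad_token as masked. At a masked step the cell is not
  # applied; the previous carry is passed through and emitted as the output.
  def masked_unroll(cell_fn, sequence, initial_carry, pad_token) do
    Enum.map_reduce(sequence, initial_carry, fn x, h ->
      if x == pad_token do
        {h, h}
      else
        cell_fn.(x, h)
      end
    end)
  end
end

# Toy cell where padding would change the state: append the token as a digit.
append_digit = fn x, h -> {h * 10 + x, h * 10 + x} end

padded = [3, 1, 2, 0, 0]

{_, unmasked_state} = Enum.map_reduce(padded, 0, append_digit)
{_, masked_state} = MaskedUnrollSketch.masked_unroll(append_digit, padded, 0, 0)
{_, unpadded_state} = Enum.map_reduce([3, 1, 2], 0, append_digit)
# unmasked_state == 31200, masked_state == 312, unpadded_state == 312
```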

seanmor5 commented on Jul 19 '23 13:07

Thanks! I saw it was committed just a few days ago. When will it be released?

I would like to reiterate that an example of a custom RNN using all of these features (unrolling, masking, how to implement a "cell", whether other layers can be called from an RNN "cell") would be great to see in the Axon guides!

marcinkoziej commented on Jul 19 '23 13:07