
Passing f32 data into LSTM with Axon.Loop trainer+run causes a while-loop shape mismatch error

polvalente opened this issue 1 year ago • 8 comments

input = Axon.input("input_series", shape: put_elem(Nx.shape(time_x), 0, nil))

model =
  input
  |> Axon.lstm(128, activation: :relu)
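  # Axon.lstm returns {output_sequence, hidden_state}; elem(0) keeps the sequence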
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(64, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(32, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.dense(1)

model
|> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam())
|> Axon.Loop.run(Stream.zip(Nx.to_batched(time_x, 50), Nx.to_batched(Nx.new_axis(time_y, 2), 50)))

polvalente • Apr 17 '23 07:04

Btw, looking at this, it's not advisable to use dropout after an LSTM layer. See https://arxiv.org/pdf/1512.05287.pdf

This is still a bug, though.
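For reference, a minimal restructuring sketch with the inter-layer dropout simply removed (note: the paper's recommendation is variational dropout on the recurrent connections, which plain Axon.dropout does not implement):

model =
  input
  |> Axon.lstm(128, activation: :relu)
  |> elem(0)
  |> Axon.lstm(64, activation: :relu)
  |> elem(0)
  |> Axon.lstm(32, activation: :relu)
  |> elem(0)
  |> Axon.dense(1)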

seanmor5 • Apr 18 '23 00:04

I was just copying a Kaggle solution to practice RNNs :)

Thanks for the advice!

polvalente • Apr 18 '23 00:04

@polvalente Looking into this, but I am not getting the error. Could you please tell me what the shapes of time_x and time_y are? Thanks!

krstopro • Jun 15 '23 09:06

@krstopro I pivoted away from this approach, but I believe that x and y were just rank-1 tensors.

The error appears for floating-point inputs, but not for integer inputs, IIRC.

polvalente • Jun 15 '23 11:06

@polvalente time_y is for sure not rank 1, as a third dimension is added with Nx.new_axis(time_y, 2).
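For concreteness, Nx.new_axis/2 inserts a length-1 axis at the given position, so axis 2 is only valid on a tensor of rank 2 or higher:

y = Nx.iota({12, 6})
Nx.shape(Nx.new_axis(y, 2))
# => {12, 6, 1}
# on a rank-1 tensor, axis 2 is out of bounds and Nx.new_axis raises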

The following code seems to be working for me (x and y are f32).

key = Nx.Random.key(12)
{x, key} = Nx.Random.normal(key, shape: {12, 6, 3})
{y, _key} = Nx.Random.normal(key, shape: {12, 6})

input = Axon.input("input_series", shape: put_elem(Nx.shape(x), 0, nil))

model =
  input
  |> Axon.lstm(128, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(64, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(32, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.dense(1)

model
|> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam())
|> Axon.Loop.run(Stream.zip(Nx.to_batched(x, 6), Nx.to_batched(Nx.new_axis(y, 2), 6)))

So everything seems legit, but I might need to further inspect the output shapes of the LSTMs. I am not sure whether the LSTM returns just the last hidden state or all of the hidden states throughout time. Also, I don't know whether the first dimension should be batch or time (e.g. in PyTorch the default layout is time-first; see https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html).
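A quick way to check (a sketch assuming the Axon 0.5-era build/init API):

{x, _key} = Nx.Random.normal(Nx.Random.key(0), shape: {4, 6, 3})

seq =
  Axon.input("input_series", shape: {nil, 6, 3})
  |> Axon.lstm(8)
  |> elem(0)

{init_fn, predict_fn} = Axon.build(seq)
params = init_fn.(Nx.template({4, 6, 3}, :f32), %{})
Nx.shape(predict_fn.(params, x))
# => {4, 6, 8}, i.e. {batch, time, units}: the full hidden-state sequence,
# batch-first (unlike PyTorch's time-first default)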

krstopro • Jun 15 '23 11:06

@seanmor5 did you fix this bug at the time?

@krstopro the problem I ran into was related to some of the inner random key states being upcast from u32 to f32
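A minimal sketch of why that cast is fatal:

key = Nx.Random.key(12)
Nx.type(key)
# => {:u, 32}
bad_key = Nx.as_type(key, :f32)
# Nx.Random expects a u32 key, so a key that has been cast to f32 is no
# longer usable, and shapes/types stop matching further down the line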

polvalente • Jun 15 '23 11:06

I am also seeing an issue which I think has to do with my model changing to :f32 (but from :f64 in my case). See: https://elixirforum.com/t/getting-batches-to-work-with-axon/57482/3
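In case it helps, one workaround sketch (assuming the Axon.MixedPrecision API) is to pin the model to an f64 policy so parameters aren't downcast:

policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 64},
    compute: {:f, 64},
    output: {:f, 64}
  )

model = Axon.MixedPrecision.apply_policy(model, policy)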

danieljaouen • Aug 08 '23 06:08

Facing the same issue.

Managed to solve it by casting each column with Nx.as_type(..., :f32) (but the original problem still remains):

...
  def df_to_tensor(df) do
    df
    |> Explorer.DataFrame.names()
    |> Enum.map(fn name ->
      df[name]
      |> Explorer.Series.to_tensor()
      |> Nx.new_axis(-1)
      |> Nx.as_type(:f32)
    end)
    |> Nx.concatenate(axis: 1)
  end
...
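Example usage (df here is a hypothetical Explorer.DataFrame):

df = Explorer.DataFrame.new(a: [1, 2, 3], b: [4.0, 5.0, 6.0])
df |> df_to_tensor() |> Nx.type()
# => {:f, 32}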

dc0d • Nov 05 '23 01:11

This issue should be fixed with the new Axon.ModelState changes: dropout keys and other model state are no longer considered part of the training parameters, so they shouldn't accidentally get cast anywhere.
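A rough sketch of what that looks like (the build/init calls and the separation described are assumed from the ModelState changes, not verified here):

{init_fn, _predict_fn} = Axon.build(model, mode: :train)
model_state = init_fn.(Nx.template({50, 6, 3}, :f32), Axon.ModelState.empty())
# trainable parameters are now tracked separately from other model state
# (dropout PRNG keys included), so optimizer updates and precision-policy
# casts no longer touch the keys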

seanmor5 • May 14 '24 12:05