
Bidirectional Wrapper Implementation (Feature Request)

Open fantypants opened this issue 1 year ago • 6 comments

Hello,

@seanmor5 I had put up a bidirectional example in the Nx Slack channel. I'm currently planning out the implementation and looking at the use cases for it as follows:

TF/Keras uses the Bidirectional wrapper layer, and the reverse-direction support for LSTMs etc. is contained within the LSTM layer itself via flags/opts (go_backwards, in the LSTM's case).

My plan is to implement a simple bidirectional layer; subsequently, the applicable layers would have to be refactored to include the new feature.

What are everyone's thoughts on how this should be implemented? Similar to Keras, using opts within the layers plus a Bidirectional layer itself, or do we create bidirectional LSTM layers instead and use the LSTM layer as it is? A rough sketch of both shapes is below.
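
To make the comparison concrete, here is a purely hypothetical sketch of what each shape could look like (neither a go_backwards option nor Axon.bidirectional exists in Axon today; the names are only illustrative):

model = Axon.input({nil, 16, 8}, "sequence")

# Option 1 (Keras-style): a hypothetical go_backwards-like option on the
# recurrent layer itself, toggled per direction.
model |> Axon.lstm(64, go_backwards: true)

# Option 2: a dedicated bidirectional wrapper/combinator that takes the layer
# as a function and runs it over the input in both directions.
model |> Axon.bidirectional(&Axon.lstm(&1, 64))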

A PR will be up soon for it; I'm assuming it'll need some good review!

fantypants avatar Jul 19 '22 11:07 fantypants

@fantypants My personal opinion is that this would be more similar to a combinator like in trax, which would match the Keras wrapper impl:

model =
  Axon.input("sequence")
  |> Axon.embedding(32, 64)
  |> Axon.bidirectional(&Axon.lstm/3, merge: &Axon.concatenate/2)

Or something similar to that. Basically the interface is:

def bidirectional(%Axon{} = input, layer_fun, opts \\ []) do
  opts = Keyword.validate!(opts, [:name, merge: &Axon.concatenate/2, axis: 1])
  # ... build the forward and backward branches from layer_fun and merge them
end

There would have to be a contract for layer_fun and merge_fun accepting and returning certain arguments, but at a high level this is what I'd envision.
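
For example, the contract might look roughly like this (just a sketch; the exact shapes would depend on what the recurrent layers return):

# layer_fun: takes an %Axon{} node and returns an %Axon{} node (or a container of nodes)
layer_fun = fn input -> Axon.lstm(input, 64) end

# merge_fun: combines the forward and backward results into a single node
merge_fun = &Axon.concatenate(&1, &2, axis: 1)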

seanmor5 avatar Jul 19 '22 22:07 seanmor5

I really like that approach. I hadn't heard of trax before, but in this context it sounds like the correct solution, especially considering the issue I was having: with two different implementations, the naming conventions would've been quite confusing and exhausting (sticking to my approach would've meant more work for every change like this).

I'll get some things together and start on the solution.

fantypants avatar Jul 20 '22 10:07 fantypants

@seanmor5 I've been working on the implementation, but I'm getting stuck at the deep merge with the following error:

** (Protocol.UndefinedError) protocol Nx.Container not implemented for #Axon<
  inputs: ["inputs"]
> of type Axon (a struct), check the docs for Nx.Container for more information. This protocol is implemented for the following type(s): Any, Axon.None, Axon.StatefulOutput, Map, Tuple

The function I'm using is:

@doc type: :bidirectional
def bidirectional(%Axon{} = input, layer_fn, opts \\ []) do
  opts = Keyword.validate!(opts, [:name, merge: &Axon.concatenate/2])

  forward_input = input
  backward_input = Axon.nx(input, &Nx.reverse(&1), op_name: :reverse)

  forward_result = layer_fn.(forward_input)
  backward_result = layer_fn.(backward_input)

  Axon.Shared.deep_merge(forward_result, backward_result, opts[:merge])
end

However, Axon.Shared.deep_merge throws the above error. I'm presuming it's a straightforward error, i.e. Nx.Container isn't implemented for the Axon struct; the Nx documentation says it can work with any type that implements Nx.Container, so I'm wondering if my Nx/EXLA dependencies are incorrect?

I've used the following (from the PR comments you added; I've tried both configs):

    # {:exla, "~> 0.2", [github: "elixir-nx/nx", sparse: "exla"]},
    # {:nx, "~> 0.2", [github: "elixir-nx/nx", sparse: "nx", override: true]},
    {:exla, "~> 0.3.0-dev", github: "elixir-nx/nx", sparse: "exla"},
    {:nx, "~> 0.3.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}

Using apply/2 works; however, I don't think it's the same operation as deep_merge, and deep_merge sounds like the correct method for this (considering it's reducing the output layer, a deep traversal is probably much more accurate than a simple concat/merge fn applied to the inputs):

apply(merge_fn, [forward, backward])
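
For clarity, that apply/2 is just a top-level dynamic call, so it's equivalent to invoking the merge fn directly and never traverses nested containers the way deep_merge would:

# these two lines do the same thing, merging only the top-level results
apply(merge_fn, [forward, backward])
merge_fn.(forward, backward)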

Update: after running some basic tests:

inp1 = Axon.input({nil, 1}, "input_0")
inp2 = Axon.input({nil, 2}, "input_1")

Axon.container(%{a: inp1, b: inp2})
|> Nx.Container.reduce([], fn x, _acc -> x end)

it seems that even using a plain Axon.container and running it through Nx.Container.reduce reproduces the issue. Are we missing the containers option somewhere? Similar to:

@derive {Nx.Container,
         containers: [:field_name, :other_field]}
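
For reference, here's what that derive looks like on a hypothetical custom struct (the field names are just placeholders from the snippet above):

defmodule MyStruct do
  # :containers lists the fields Nx.Container should traverse;
  # :keep lists metadata fields that are carried along untouched.
  @derive {Nx.Container, containers: [:field_name, :other_field], keep: [:name]}
  defstruct [:field_name, :other_field, :name]
end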

fantypants avatar Jul 29 '22 09:07 fantypants

@fantypants You'll need to bring this implementation of deep_merge into axon.ex as a private helper:

  defp deep_merge(left, right, fun) do
    case Nx.Container.traverse(left, leaves(right), &recur_merge(&1, &2, fun)) do
      {merged, []} ->
        merged

      {_merged, _leftover} ->
        raise ArgumentError,
              "unable to merge arguments with incompatible" <>
                " structure"
    end
  end

  defp leaves(container) do
    container
    |> Nx.Container.reduce([], fn x, acc -> [x | acc] end)
    |> Enum.reverse()
  end

  defp recur_merge(left, [right | right_leaves], fun) do
    case {left, right} do
      {%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
        {fun.(left, right), right_leaves}

      {%Axon{} = left, %Axon{} = right} ->
        {fun.(left, right), right_leaves}

      {left, right} ->
        {deep_merge(left, right, fun), right_leaves}
    end
  end
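
With that helper in place, the bidirectional/3 sketch from earlier in the thread can call it directly, e.g. (still just a sketch, not a final implementation):

def bidirectional(%Axon{} = input, layer_fn, opts \\ []) do
  opts = Keyword.validate!(opts, [:name, merge: &Axon.concatenate/2])

  forward = layer_fn.(input)
  backward = layer_fn.(Axon.nx(input, &Nx.reverse/1, op_name: :reverse))

  deep_merge(forward, backward, opts[:merge])
end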

seanmor5 avatar Jul 29 '22 16:07 seanmor5

@seanmor5 That did the trick; results are now coming in! I have to do some branch shuffling with the other tutorial PR branch first, then I'll push this PR.

What's the purpose of the Axon.Layers functions vs the Axon functions, i.e. Axon.lstm vs Axon.Layers.lstm? Is this something I should include in this PR now?

fantypants avatar Jul 31 '22 09:07 fantypants

@fantypants I have been thinking about this and it's actually more difficult than what I outlined. The implementation depends on an implementation of #169 first, so I might have you hold off before implementing something :)

seanmor5 avatar Aug 13 '22 21:08 seanmor5

Closing as tracked in #119 :)

seanmor5 avatar Sep 07 '22 00:09 seanmor5