Block with multiple inputs fails to compile

Open: mntns opened this issue on Nov 01 '23

When passing multiple inputs to a block, the model fails to compile:

Mix.install(
  [
    {:exla, ">= 0.0.0"},
    {:axon, path: "./axon", overwrite: true},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

input1 = Axon.input("input1")
input2 = Axon.input("input2")
input3 = Axon.input("input3")
input4 = Axon.input("input4")

reuse = Axon.block(fn x1, x2 ->
  Axon.add(x1, x2)
end)
model_1 = reuse.([input1, input2])
model_2 = reuse.([input3, input4])

out = Axon.container({model_1, model_2})

template = %{
  "input1" => Nx.template({2, 8}, :f32),
  "input2" => Nx.template({2, 8}, :f32),
  "input3" => Nx.template({2, 8}, :f32),
  "input4" => Nx.template({2, 8}, :f32)
}

{init_fn, predict_fn} = Axon.compile(out, template)

The following error is raised:

** (Axon.CompileError) exception found when compiling layer Axon.Layers.block/3 named block_0:

    ** (UndefinedFunctionError) function Axon.Layers.block/3 is undefined or private
        (axon 0.6.0) Axon.Layers.block(#Nx.Tensor<
      f32[2][8]

      Nx.Defn.Expr
      parameter a:0   f32[2][8]
    >, #Nx.Tensor<
      f32[2][8]

      Nx.Defn.Expr
      parameter a:1   f32[2][8]
    >, [mode: :inference, block_fun: #Function<0.123201197 in file:block_without_start.exs>, block_id: 5])


(pass debug: true to build/compile see where the layer was defined)


Compiling of the model was initiated at:

    (nx 0.6.2) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (nx 0.6.2) lib/nx/defn/evaluator.ex:83: Nx.Defn.Evaluator.precompile/3
    (nx 0.6.2) lib/nx/defn/evaluator.ex:61: Nx.Defn.Evaluator.__compile__/4
    (nx 0.6.2) lib/nx/defn/evaluator.ex:54: Nx.Defn.Evaluator.__jit__/5
    (nx 0.6.2) lib/nx/defn.ex:443: Nx.Defn.do_jit_apply/3
    (axon 0.6.0) lib/axon.ex:3645: Axon.compile/4

mntns, Nov 01 '23 18:11

Upon further investigation, it seems that Axon.Compiler.recur_model_funs never matches an %Axon.Node{op: :block} that has multiple parents (https://github.com/elixir-nx/axon/blob/main/lib/axon/compiler.ex#L464). Instead, the "generic" clause for built-in Axon.Layers nodes matches (https://github.com/elixir-nx/axon/blob/main/lib/axon/compiler.ex#L665), so the compiler erroneously builds a call to Axon.Layers.block/n, a function that does not exist.

mntns, Nov 01 '23 19:11
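The dispatch problem described above can be illustrated with a toy model in plain Elixir. This is a hypothetical sketch, not Axon's compiler: ToyCompiler, ToyLayers, and the map-based node shape are invented for illustration, and the single-element list pattern on parent is an assumption about how a block-specific clause can end up rejecting nodes with more than one parent.

```elixir
# Hypothetical toy model of clause dispatch; NOT Axon's actual compiler.
# ToyLayers, ToyCompiler and the %{op: ..., parent: ...} node shape are
# invented for illustration only.
defmodule ToyLayers do
  # "Built-in" layers live here. There is deliberately no block/3.
  def add(a, b, _opts), do: a + b
end

defmodule ToyCompiler do
  # Special-case clause for blocks. The single-element list pattern
  # means it only matches a :block node with exactly ONE parent.
  def build(%{op: :block, parent: [_parent]}), do: {:special, :block}

  # Generic clause for built-in layers: looks up ToyLayers.<op>/<arity>,
  # where arity = number of parents + 1 (the trailing opts argument).
  def build(%{op: op, parent: parents}) when is_atom(op) do
    arity = length(parents) + 1

    if function_exported?(ToyLayers, op, arity) do
      {:generic, {ToyLayers, op, arity}}
    else
      {:error, {:undefined, ToyLayers, op, arity}}
    end
  end
end

# Single-parent block: handled by the special clause.
IO.inspect(ToyCompiler.build(%{op: :block, parent: [:x1]}))
# prints {:special, :block}

# Ordinary built-in layer: resolved by the generic clause.
IO.inspect(ToyCompiler.build(%{op: :add, parent: [:a, :b]}))
# prints {:generic, {ToyLayers, :add, 3}}

# Two-parent block: falls through to the generic clause, which looks
# for ToyLayers.block/3 and finds nothing.
IO.inspect(ToyCompiler.build(%{op: :block, parent: [:x1, :x2]}))
# prints {:error, {:undefined, ToyLayers, :block, 3}}
```

In the toy, a two-parent block skips the special clause and lands in the generic one, producing the same shape of failure as the Axon.Layers.block/3 UndefinedFunctionError reported in the issue.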

I am still trying to think of a good way to handle these cases. For now, I think you should wrap all inputs in a container and then match on them.

seanmor5, May 10 '24 19:05