Block with multiple inputs fails to compile
When passing multiple inputs to a block, the model fails to compile:
```elixir
Mix.install(
  [
    {:exla, ">= 0.0.0"},
    {:axon, path: "./axon", override: true},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

input1 = Axon.input("input1")
input2 = Axon.input("input2")
input3 = Axon.input("input3")
input4 = Axon.input("input4")

reuse = Axon.block(fn x1, x2 ->
  Axon.add(x1, x2)
end)

model_1 = reuse.([input1, input2])
model_2 = reuse.([input3, input4])

out = Axon.container({model_1, model_2})

template = %{
  "input1" => Nx.template({2, 8}, :f32),
  "input2" => Nx.template({2, 8}, :f32),
  "input3" => Nx.template({2, 8}, :f32),
  "input4" => Nx.template({2, 8}, :f32)
}

{init_fn, predict_fn} = Axon.compile(out, template)
```
The following error is thrown:
```
** (Axon.CompileError) exception found when compiling layer Axon.Layers.block/3 named block_0:

    ** (UndefinedFunctionError) function Axon.Layers.block/3 is undefined or private
        (axon 0.6.0) Axon.Layers.block(#Nx.Tensor<
          f32[2][8]
          Nx.Defn.Expr
          parameter a:0 f32[2][8]
        >, #Nx.Tensor<
          f32[2][8]
          Nx.Defn.Expr
          parameter a:1 f32[2][8]
        >, [mode: :inference, block_fun: #Function<0.123201197 in file:block_without_start.exs>, block_id: 5])

(pass debug: true to build/compile see where the layer was defined)

Compiling of the model was initiated at:

    (nx 0.6.2) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (nx 0.6.2) lib/nx/defn/evaluator.ex:83: Nx.Defn.Evaluator.precompile/3
    (nx 0.6.2) lib/nx/defn/evaluator.ex:61: Nx.Defn.Evaluator.__compile__/4
    (nx 0.6.2) lib/nx/defn/evaluator.ex:54: Nx.Defn.Evaluator.__jit__/5
    (nx 0.6.2) lib/nx/defn.ex:443: Nx.Defn.do_jit_apply/3
    (axon 0.6.0) lib/axon.ex:3645: Axon.compile/4
```
Upon further investigation, it seems that `Axon.Compiler.recur_model_funs` can never match an `%Axon.Node{op: :block}` with multiple parents (https://github.com/elixir-nx/axon/blob/main/lib/axon/compiler.ex#L464). Instead, the "generic" clause for built-in `Axon.Layers` nodes matches (https://github.com/elixir-nx/axon/blob/main/lib/axon/compiler.ex#L665), so the compiler erroneously builds a call to `Axon.Layers.block/n`, a function that does not exist. That is the `UndefinedFunctionError` in the trace above.
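To make the fall-through concrete, here is a minimal plain-Elixir sketch of the clause ordering. It assumes, as the linked line suggests, that the block clause destructures its parents as a single-element list (something like `parent: [parent]`). `DispatchSketch` and `route/1` are hypothetical names, and plain maps stand in for `%Axon.Node{}`; this is an illustration of the pattern-matching behavior, not Axon's actual compiler code.

```elixir
defmodule DispatchSketch do
  # Hypothetical stand-in for the clause ordering in Axon.Compiler.recur_model_funs.
  # Plain maps replace %Axon.Node{} so the sketch runs without Axon installed.

  # Assumed block clause: the pattern only matches a one-element parent list.
  def route(%{op: :block, parent: [_single_parent]}), do: :block_clause

  # Generic clause: any other node is treated as a built-in layer, so the
  # compiler would look for Axon.Layers.<op> taking one argument per parent
  # plus a keyword list of opts (layer parameters are ignored here).
  def route(%{op: op, parent: parents}) when is_atom(op) do
    {:generic_clause, {Axon.Layers, op, length(parents) + 1}}
  end
end

# A block with a single parent reaches the block clause.
:block_clause = DispatchSketch.route(%{op: :block, parent: [:input1]})

# A block with two parents falls through to the generic clause, which resolves
# to Axon.Layers.block/3, the undefined function reported in the error.
{:generic_clause, {Axon.Layers, :block, 3}} =
  DispatchSketch.route(%{op: :block, parent: [:input1, :input2]})
```

Because Elixir tries function clauses in order and `[_single_parent]` only matches a list of exactly one element, any block with two or more parents skips the block-specific clause entirely.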
I am still trying to think of a good way to handle these cases. For now, I think you should wrap all inputs in a container and then match on them inside the block.