Improve compiler error messages
Right now, compiler stacktraces for large models are very difficult to interpret. It's almost impossible to tell where your issue is, because the compiler offers no straightforward information in the stacktrace. This needs to improve.
@seanmor5 maybe you should do it like we do for macros, which is to wrap each layer "expansion" in a try/catch and then manipulate the stacktrace to include the layer: https://github.com/elixir-lang/elixir/blob/main/lib/elixir/src/elixir_dispatch.erl#L219-L231
If you give me a small NN with an error, I can PR this in. :)
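To make the idea concrete, here is a minimal sketch of the wrap-and-reraise approach. It is not Axon's actual implementation: the module `LayerTraceSketch`, the function `with_layer/3`, and the shape of the synthetic stacktrace entry are all hypothetical, chosen only to illustrate the technique.

```elixir
defmodule LayerTraceSketch do
  # Hypothetical helper, not part of Axon. Runs `fun` and, if it raises,
  # re-raises the same exception with one extra stacktrace entry on top
  # that names the layer being built.
  def with_layer(name, op, fun) when is_atom(op) and is_function(fun, 0) do
    try do
      fun.()
    rescue
      exception ->
        # Stacktrace entries have the shape {module, function, arity, location}.
        # The "file" here is a made-up label carrying the layer name.
        location = [file: String.to_charlist("layer " <> inspect(name))]
        entry = {__MODULE__, op, 0, location}
        reraise exception, [entry | __STACKTRACE__]
    end
  end
end
```

A compiler could then wrap each layer's build step, for example `LayerTraceSketch.with_layer("dense_1", :dense, fn -> build_layer.() end)`, where `build_layer` stands in for whatever the compiler actually calls. The original exception and its frames are preserved; only the first frame is new.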
Here's an example:
input = Axon.input("input")
x1 = Axon.dense(input, 32)
x2 = Axon.dense(input, 64)
model = Axon.add(x1, x2)
{init_fn, predict_fn} = Axon.build(model)
init_fn.(Nx.template({1, 16}, :f32), %{})
And the stacktrace:
** (ArgumentError) cannot broadcast tensor of dimensions {1, 32} to {1, 64}
(nx 0.3.0) lib/nx/shape.ex:335: Nx.Shape.binary_broadcast/4
(nx 0.3.0) lib/nx.ex:3448: Nx.element_wise_bin_op/4
(elixir 1.13.0) lib/enum.ex:2396: Enum."-reduce/3-lists^foldl/2-0-"/3
(axon 0.2.0) lib/axon/compiler.ex:537: Axon.Compiler.layer_predict_fun/14
(axon 0.2.0) lib/axon/compiler.ex:614: Axon.Compiler.layer_init_fun/9
(axon 0.2.0) lib/axon/compiler.ex:57: anonymous fn/3 in Axon.Compiler.build/2
(stdlib 3.17.1) timer.erl:166: :timer.tc/1
(axon 0.2.0) lib/axon/compiler.ex:56: anonymous fn/5 in Axon.Compiler.build/2
You can imagine with larger networks that it just ends up being alternating layer_predict_fun and layer_init_fun with no indication of what's going on. We can also do a better job catching these errors earlier if a shape is provided up front :)
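As a rough illustration of catching this earlier, a shape check could run when the layer is defined rather than when the model is compiled. The sketch below is plain Elixir and does not use Axon or Nx: `ShapeCheckSketch` and `check_broadcast!/3` are hypothetical names, and the rule it applies (right-aligned dimensions must be equal or one of them must be 1) is the usual broadcasting rule, assumed here to match what the element-wise op expects.

```elixir
defmodule ShapeCheckSketch do
  # Hypothetical up-front check, not part of Axon. Raises an error that
  # names the layer if the two input shapes cannot be broadcast together.
  def check_broadcast!(layer_name, left, right)
      when is_tuple(left) and is_tuple(right) do
    # Compare dimensions starting from the trailing axis.
    left_dims = left |> Tuple.to_list() |> Enum.reverse()
    right_dims = right |> Tuple.to_list() |> Enum.reverse()

    compatible? =
      left_dims
      |> Enum.zip(right_dims)
      |> Enum.all?(fn {a, b} -> a == b or a == 1 or b == 1 end)

    if compatible? do
      :ok
    else
      raise ArgumentError,
            "layer #{inspect(layer_name)}: cannot broadcast input shapes " <>
              "#{inspect(left)} and #{inspect(right)}"
    end
  end
end
```

With the example above, checking `{1, 32}` against `{1, 64}` for the add layer would raise at definition time with the layer name in the message, instead of surfacing deep inside the compiler.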
This was resolved earlier :)