Compiled model will always result in `%Axon.None{}`
I'm trying to load a pre-trained Sentence Transformers model with Bumblebee so that I can use it with Axon, using a custom layer for mean pooling instead of a post-processing step. When I try to print the model's execution flow as a table, it raises the following error:
```
** (ArgumentError) the compiled model will always result in %Axon.None{}. This most likely means you specified optional output and did not handle the case when it is missing
    (axon 0.6.0) lib/axon/compiler.ex:107: anonymous fn/6 in Axon.Compiler.build/2
    (nx 0.6.2) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (axon 0.6.0) lib/axon/defn.ex:14: Axon.Defn.__jit__/5
    (nx 0.6.2) lib/nx/defn.ex:443: Nx.Defn.do_jit_apply/3
    (axon 0.6.0) lib/axon.ex:3364: Axon.get_output_shape/3
    (axon 0.6.0) lib/axon/display.ex:182: Axon.Display.do_axon_to_rows/6
    (axon 0.6.0) lib/axon/display.ex:86: Axon.Display.axon_to_rows/6
    (axon 0.6.0) lib/axon/display.ex:152: anonymous fn/4 in Axon.Display.do_axon_to_rows/6
```
Here is the corresponding code snippet.
```elixir
Mix.install(
  [
    {:bumblebee, "~> 0.4.2"},
    {:exla, ">= 0.0.0"},
    {:axon, git: "https://github.com/elixir-nx/axon.git", override: true},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

model_repo = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
{:ok, model_info} = Bumblebee.load_model({:hf, model_repo})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_repo})

batch_size = 2
sequence_length = 128
output_dims = 768

template = %{
  "attention_mask" => Nx.template({batch_size, sequence_length}, :u32),
  "input_ids" => Nx.template({batch_size, sequence_length}, :u32),
  "token_type_ids" => Nx.template({batch_size, sequence_length}, :u32)
}

%{model: base_model, params: params, spec: _spec} = model_info

string_inputs = ["test 1", "test 2"]
inputs = Bumblebee.apply_tokenizer(tokenizer, string_inputs, length: sequence_length)

defmodule CustomLayer do
  import Nx.Defn

  # Masked mean pooling: average the token embeddings over the sequence axis,
  # ignoring positions where the attention mask is 0 (padding).
  defn mean_pooling(model_output, attention_mask, _opts \\ []) do
    input_mask_expanded = Nx.new_axis(attention_mask, -1)

    model_output.hidden_state
    |> Nx.multiply(input_mask_expanded)
    |> Nx.sum(axes: [1])
    |> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
  end
end

attention_mask = Axon.input("attention_mask")

complete_model =
  Axon.layer(&CustomLayer.mean_pooling/3, [base_model, attention_mask], op_name: :mean_pooling)

complete_model |> Axon.get_inputs() |> IO.inspect()

{_init_fn, predict_fn} = Axon.build(complete_model)
predict_fn.(params, inputs) |> IO.inspect()

# This returns the correct output shape
complete_model |> Axon.get_output_shape(template) |> IO.inspect()

# This fails with "(ArgumentError) the compiled model will always result in %Axon.None{}"
complete_model |> Axon.Display.as_table(template) |> IO.puts()
```
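For reference, the `mean_pooling/3` layer computes a masked average over the sequence axis. Writing $h_{b,t}$ for the hidden state of token $t$ in batch item $b$, $m_{b,t} \in \{0, 1\}$ for the attention mask and $T$ for the sequence length (notation introduced here only for illustration), `Nx.new_axis(attention_mask, -1)` lets the mask broadcast across the hidden dimension, and the two `Nx.sum(..., axes: [1])` calls reduce over $t$, giving one vector of size `output_dims` per batch item:

$$
\bar{h}_b = \frac{\sum_{t=1}^{T} m_{b,t}\, h_{b,t}}{\sum_{t=1}^{T} m_{b,t}}
$$

A small worked instance with made-up values, where the third token is padding:

$$
h_{b,\cdot} = ([1,2],\,[3,4],\,[5,6]),\quad m_{b,\cdot} = (1,1,0)
\;\Rightarrow\;
\bar{h}_b = \frac{1\cdot[1,2] + 1\cdot[3,4] + 0\cdot[5,6]}{1+1+0} = [2,\,3]
$$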
As far as I can tell, I haven't specified any optional output. Relatedly, unless I'm missing something, shouldn't the message say "optional inputs" instead, since it is missing optional inputs that nullify downstream nodes when no value is given?
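To illustrate the nullification behavior mentioned above, here is a minimal sketch on a toy model, not on the Bumblebee model and not a diagnosis of the error. It assumes Axon 0.6's `optional: true` option on `Axon.input/2` and `Axon.optional/1`; the input names `"x"` and `"bias"` and the op names are hypothetical. When the optional input is omitted, the layer that consumes it directly is nullified, and because that layer is the model's only output, prediction fails with the same `ArgumentError`. Wrapping the input in `Axon.optional/1` instead passes `%Axon.None{}` into the layer so it can fall back.

```elixir
Mix.install([{:axon, "~> 0.6.0"}])

# Hypothetical toy model: "x" is required, "bias" may be omitted at runtime.
x = Axon.input("x", shape: {nil, 3})
bias = Axon.input("bias", shape: {nil, 3}, optional: true)

# Only "x" is provided; "bias" is missing on purpose.
input = %{"x" => Nx.tensor([[0.0, 1.0, 2.0]])}

# 1. Unhandled: the missing "bias" nullifies this layer, so the whole
#    model evaluates to %Axon.None{} and prediction raises ArgumentError.
unhandled =
  Axon.layer(fn x_value, bias_value, _opts -> Nx.add(x_value, bias_value) end, [x, bias],
    op_name: :add_bias
  )

{_init_fn, unhandled_predict} = Axon.build(unhandled)

message =
  try do
    unhandled_predict.(%{}, input)
    "no error"
  rescue
    e in ArgumentError -> Exception.message(e)
  end

IO.puts(message)

# 2. Handled: Axon.optional/1 hands %Axon.None{} to the layer instead of
#    nullifying it, so the layer can return "x" unchanged.
handled =
  Axon.layer(
    fn x_value, bias_value, _opts ->
      case bias_value do
        %Axon.None{} -> x_value
        present -> Nx.add(x_value, present)
      end
    end,
    [x, Axon.optional(bias)],
    op_name: :add_optional_bias
  )

{_init_fn, handled_predict} = Axon.build(handled)
result = handled_predict.(%{}, input)
IO.inspect(result)
```

Neither toy model has trainable parameters, which is why an empty map is passed as params here.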