axon_onnx
Cannot export BERT
When I pull down BERT or RoBERTa models from Hugging Face, I can't export them today with 0.4: the op in the Axon.Node (%Axon.Node{op: :container}) doesn't match any of the available function heads.
Here is the Axon.Node that fails to pattern match:
%Axon.Node{
  id: 999,
  name: #Function<66.122028880/2 in Axon.name/2>,
  mode: :both,
  parent: [%{attentions: 996, hidden_states: 997, logits: 998}],
  parameters: [],
  args: [:layer],
  op: :container,
  policy: p=f32 c=f32 o=f32,
  hooks: [],
  opts: [],
  op_name: :container,
  stacktrace: [
    {Axon, :layer, 3, [file: 'lib/axon.ex', line: 338]},
    {Bumblebee, :load_model, 2, [file: 'lib/bumblebee.ex', line: 460]},
    {MyCode.MyModule, :export_bert, 0, [file: 'lib/my_module/export.ex', line: 7]},
    ...
  ]
}
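For context, the parent field above shows that the :container node only groups three other nodes (attentions, hidden_states, logits) by node ID; it does no computation of its own. A minimal sketch of how an exporter might handle such a node, assuming each leaf should become a separately named ONNX graph output, is to flatten that structure into {output_name, node_id} pairs. The ContainerOutputs module and flatten/2 function are hypothetical illustrations, not part of axon_onnx's API.

```elixir
defmodule ContainerOutputs do
  # Hypothetical helper (not part of axon_onnx): flattens the parent
  # structure of a :container node into {output_name, node_id} pairs,
  # so each leaf could be emitted as its own named graph output.

  def flatten(container, prefix \\ nil)

  # Maps: visit keys in a stable (string-sorted) order, extending the name.
  def flatten(%{} = map, prefix) do
    map
    |> Enum.sort_by(fn {key, _value} -> to_string(key) end)
    |> Enum.flat_map(fn {key, value} -> flatten(value, join(prefix, key)) end)
  end

  # Tuples: use the element index as the name segment.
  def flatten(tuple, prefix) when is_tuple(tuple) do
    tuple
    |> Tuple.to_list()
    |> Enum.with_index()
    |> Enum.flat_map(fn {value, index} -> flatten(value, join(prefix, index)) end)
  end

  # Leaves are node IDs (integers), as in the dump above.
  def flatten(node_id, prefix) when is_integer(node_id), do: [{prefix, node_id}]

  defp join(nil, key), do: to_string(key)
  defp join(prefix, key), do: "#{prefix}.#{key}"
end
```

Applied to the parent above, ContainerOutputs.flatten(%{attentions: 996, hidden_states: 997, logits: 998}) returns [{"attentions", 996}, {"hidden_states", 997}, {"logits", 998}].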
And here is the code I'm using to pull in BERT and export it:
def export do
  # Load the BERT spec with a sequence-classification head and 8 labels
  {:ok, spec} = Bumblebee.load_spec({:hf, "bert-base-cased"}, architecture: :for_sequence_classification)
  spec = Bumblebee.configure(spec, num_labels: 8)
  {:ok, bert} = Bumblebee.load_model({:hf, "bert-base-cased"}, spec: spec)
  %{model: the_model, params: params} = bert

  sequence_length = 50
  batch_size = 16

  # Token IDs index an embedding table, so integer tensor types are used here
  input_template = %{
    "input_ids" => Nx.template({batch_size, sequence_length}, :s64),
    "attention_mask" => Nx.template({batch_size, sequence_length}, :s64),
    "token_type_ids" => Nx.template({batch_size, sequence_length}, :s64)
  }

  # Fails on the model's final :container node
  AxonOnnx.export(the_model, input_template, params, path: "tuned.onnx")
end
Further detail: I've cut this down to something easy to reproduce, but my full example and motivation are as follows:
- I pull down a RoBERTa or BERT model from Hugging Face
- I then fine-tune this model with domain-specific data
- I evaluate the tuned model, and it works as expected with high accuracy
- I then want to export it so I can examine it further with tools like WeightWatcher