Transformers.jl
Inference API
As mentioned in #108, we currently don't have an inference API like the `pipeline` from huggingface transformers. Right now you need to manually load the model and tokenizer, apply them to the input data, and convert the prediction result to the corresponding labels.
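For reference, a minimal sketch of that manual workflow (the checkpoint path is a placeholder, and module paths may differ between Transformers.jl versions):

```julia
using Transformers, Flux, BSON

# 1. load the model and tokenizer yourself
BSON.@load "ckpt/my_model.bson" bert_model wordpiece tokenizer
# 2. encode the raw text yourself (e.g. via BertTextEncoder, as in the code below)
# 3. run bert_model on the encoded input
# 4. map the output logits back to human-readable labels yourself
```

An inference API would wrap these steps behind a single call, like `pipeline("sentiment-analysis")` does in huggingface transformers.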
What's the way to save and load a model currently? I'm saving it like so:

```julia
BSON.@save bsonname bert_model wordpiece tokenizer
```

and loading it using `load_bert_pretrain(bsonname)`, but it throws `ERROR: UndefVarError: Transformers not defined` while loading the tokenizer. Moreover, the Flux docs suggest you should do `cpu(model)` before saving -- do you think that breaks anything?
Simply `BSON.@save` and `BSON.@load`. I guess the error is probably because you forgot to `using Transformers` before loading. And yes, it's better to do `cpu(model)` before saving.
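For example (a minimal sketch; the file name is a placeholder):

```julia
using Transformers  # must be loaded before BSON.@load, otherwise BSON
                    # can't resolve the stored Transformers types
using Flux, BSON

bert_model = cpu(bert_model)  # move to CPU before saving
BSON.@save "my_checkpoint.bson" bert_model wordpiece tokenizer

# ... later, in a fresh session with Transformers loaded ...
BSON.@load "my_checkpoint.bson" bert_model wordpiece tokenizer
```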
Weird. I have all the dependencies imported in the main module, and I'm including the loading script in the module. Anyway, importing them in the REPL solved the issue -- probably a dumb mistake on my part.
Right now, I'm doing this:
```julia
struct Pipeline
    bert_model
    wordpiece
    tokenizer
    bertenc

    function Pipeline(; ckpt::AbstractString="BERT_Twitter_Epochs_1")
        bert_model, wordpiece, tokenizer = load_bert_pretrain("ckpt/$ckpt.bson")
        bert_model = todevice(bert_model)
        bertenc = BertTextEncoder(tokenizer, wordpiece)
        Flux.testmode!(bert_model)  # disable dropout etc. for inference
        new(bert_model, wordpiece, tokenizer, bertenc)
    end
end

function (p::Pipeline)(query::AbstractString)
    # `preprocess` is my own helper that builds the model input
    data = todevice(preprocess([[query], ["0"]]))
    e = p.bert_model.embed(data.input)
    t = p.bert_model.transformers(e, data.mask)
    prediction = p.bert_model.classifier.clf(
        p.bert_model.classifier.pooler(
            t[:, 1, :]  # hidden state of the [CLS] token
        )
    )
    @info "Prediction: " prediction
end
```
I can do:

```julia
julia> p = Pipeline()

julia> p("this classifier sucks")
┌ Info: Prediction:
│   prediction =
│    2×1 Matrix{Float32}:
│     -0.06848035
└     -2.7152526
```
I have no idea how to interpret the results (should I take the absolute value to see which one-hot category is hot?), but is this the correct approach?
Several points:
- You don't need to use `load_bert_pretrain`; you can just use `BSON.@load`.
- `BertTextEncoder` contains both `tokenizer` and `wordpiece`, so you don't need to store all of them.
- You would need to do `Flux.onecold(prediction)` to turn the logits into the index of the label (see the sketch after this list).
- But the meaning of each label is missing here, so you might want to store the labels in your checkpoint file as well.
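A minimal sketch of those last two points (the label names here are hypothetical):

```julia
using Flux, BSON

labels = ["negative", "positive"]  # hypothetical; save alongside the model
BSON.@save "my_checkpoint.bson" bert_model wordpiece tokenizer labels

# after loading, map logits to labels:
Flux.onecold(prediction)           # index of the largest logit, e.g. [1]
Flux.onecold(prediction, labels)   # or the label itself, e.g. ["negative"]
```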
Any further thoughts on an Inference API?
@ashwani-rathee and I have been discussing a framework-agnostic API - in particular for inference - that might be relevant to an inference API for Transformers.jl: https://julialang.zulipchat.com/#narrow/stream/390029-image-processing/topic/DL.20based.20tools/near/383544112