Transformers.jl

Inference API

Open chengchingwen opened this issue 2 years ago • 5 comments

As mentioned in #108, we currently don't have an inference API like the `pipeline` from Hugging Face Transformers. Right now you need to manually load the model and tokenizer, apply them to the input data, and convert the predictions into the corresponding labels.
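For illustration, the three manual steps could be hidden behind a single callable, roughly in the spirit of the `pipeline` mentioned above. This is a minimal sketch, not the Transformers.jl API: `ClassificationPipeline`, `toy_encode`, `toy_model`, and `TOY_W` are hypothetical names, and the toy encoder and model are stand-ins for a real text encoder and a loaded BERT classifier.

```julia
# Hypothetical sketch: wrap "encode -> run model -> map to label" in one callable.
# These names are illustrative only; they are not part of Transformers.jl.
struct ClassificationPipeline{E,M}
    encode::E               # text -> model input (stand-in for a text encoder)
    model::M                # model input -> vector of logits (stand-in for a classifier)
    labels::Vector{String}  # label names, index-aligned with the logits
end

function (p::ClassificationPipeline)(text::AbstractString)
    logits = p.model(p.encode(text))
    # Highest-scoring class, i.e. what Flux.onecold computes for a single vector.
    return p.labels[argmax(logits)]
end

# Toy stand-ins so the sketch runs without any model weights:
# the "encoder" flags two marker words, the "model" is a fixed 2×2 linear map.
toy_encode(text) = Float32[occursin("good", lowercase(text)), occursin("bad", lowercase(text))]
const TOY_W = Float32[1 -1; -1 1]
toy_model(x) = TOY_W * x

pipe = ClassificationPipeline(toy_encode, toy_model, ["positive", "negative"])
pipe("This is good")   # returns "positive"
```

In a real version, `encode` would wrap the text encoder, `model` the loaded classifier, and `labels` would be stored alongside the checkpoint, so callers never touch the intermediate logits.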

chengchingwen · Aug 06 '22 11:08

What's the current way to save and load a model? I'm saving it like this:

```julia
BSON.@save bsonname bert_model wordpiece tokenizer
```

I'm loading it with `load_bert_pretrain(bsonname)`, but it throws `ERROR: UndefVarError: Transformers not defined` while loading the tokenizer. Also, the Flux docs suggest calling `cpu(model)` before saving; do you think that breaks anything?

Broever101 · Aug 06 '22 14:08

Simply use `BSON.@save` and `BSON.@load`. The error is probably because you forgot to run `using Transformers` before loading. And yes, it's better to call `cpu(model)` before saving.

chengchingwen · Aug 06 '22 15:08

Weird. I have all the dependencies imported in the main module, and I'm including the loading script in that module. Anyway, importing them in the REPL solved the issue; probably a dumb mistake on my part.

Right now, I'm doing this:

```julia
struct Pipeline 
    bert_model
    wordpiece
    tokenizer
    bertenc
    function Pipeline(; ckpt::AbstractString="BERT_Twitter_Epochs_1")
        bert_model, wordpiece, tokenizer = load_bert_pretrain("ckpt/$ckpt.bson")
        bert_model = todevice(bert_model)

        bertenc = BertTextEncoder(tokenizer, wordpiece)
        Flux.testmode!(bert_model)
        new(bert_model, wordpiece, tokenizer, bertenc)
    end
end

function (p::Pipeline)(query::AbstractString)
    data = todevice(preprocess([[query], ["0"]]))
    e = p.bert_model.embed(data.input)
    t = p.bert_model.transformers(e, data.mask)

    prediction = p.bert_model.classifier.clf(
        p.bert_model.classifier.pooler(
            t[:,1,:]
        )
    )

    @info "Prediction: " prediction
end
```

With this, I can run:

```julia-repl
julia> p = Pipeline()

julia> p("this classifier sucks")
┌ Info: Prediction:
│   prediction =
│    2×1 Matrix{Float32}:
│     -0.06848035
└     -2.7152526
```

I have no idea how to interpret the results (should I take the absolute value to find out which one-hot category is hot?). Is this the correct approach?

Broever101 · Aug 06 '22 16:08

Several points:

  1. You don't need `load_bert_pretrain`; you can just use `BSON.@load`.
  2. `BertTextEncoder` already contains both the tokenizer and the wordpiece, so you don't need to store them separately.
  3. You need `Flux.onecold(prediction)` to turn the logits into the index of the label.
  4. However, the meaning of each label is missing here, so you may want to store the label names in your checkpoint file as well.
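To make points 3 and 4 concrete, here is a small sketch using the logits printed earlier. It uses only Base Julia: `argmax` over each column gives the same index `Flux.onecold` returns, and a softmax turns the logits into probabilities. The label names `"0"` and `"1"` are an assumption (loosely based on the placeholder `["0"]` passed to `preprocess`); in practice they would come from wherever they are stored with the checkpoint. `softmax_cols` is a hypothetical helper written for this example.

```julia
# Logits from the REPL output above: 2 classes × 1 sample.
prediction = reshape(Float32[-0.06848035, -2.7152526], 2, 1)

# Hypothetical label names, index-aligned with the rows of `prediction`.
labels = ["0", "1"]

# Column-wise softmax: converts each sample's logits into probabilities.
function softmax_cols(x::AbstractMatrix)
    z = exp.(x .- maximum(x; dims=1))  # subtract the max for numerical stability
    return z ./ sum(z; dims=1)
end

# Index of the largest logit in each column, the same index Flux.onecold returns.
idx = [argmax(col) for col in eachcol(prediction)]
probs = softmax_cols(prediction)
predicted_labels = labels[idx]   # ["0"]
```

Because softmax preserves ordering, the largest logit (about -0.068, class 1) is also the most probable class, at roughly 93%. Taking the absolute value would instead pick class 2, which is the least likely one. With Flux loaded, `Flux.onecold(prediction, labels)` should map straight to the label names in one call.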

chengchingwen · Aug 06 '22 16:08

Any further thoughts on an Inference API?

@ashwani-rathee and I have been discussing a framework-agnostic API, in particular for inference, that might be relevant to an inference API for Transformers.jl: https://julialang.zulipchat.com/#narrow/stream/390029-image-processing/topic/DL.20based.20tools/near/383544112

stemann · Aug 11 '23 17:08