
Add zero-shot classification

Open · seanmor5 opened this issue 2 years ago · 1 comment

I will add tests tomorrow when I get the multi-prompt case down, but right now it seems to be working fine:

[Screenshot from 2022-12-15, 6:45:05 PM]

Btw, is there a reason we hid the documentation for TokenClassification?

seanmor5 · Dec 16 '22 02:12

Awesome! Ideally we should follow the same format as image/text classification, so %{predictions: [%{label: ..., score: ...}, ...]}. For now I wouldn't include the prompt and label in the output, since the user already has the input, and for multiple inputs they can just zip the inputs with the outputs.
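To illustrate the "just zip" suggestion, here is a small Python sketch. It assumes each output is shaped like the proposed map (mirrored as a dict with a `predictions` list of `label`/`score` entries) and that outputs come back in the same order as the inputs. `pair_with_inputs` is a hypothetical helper, not part of Bumblebee, and the scores in the usage lines are made up.

```python
from typing import Any

# One prediction, mirroring %{label: ..., score: ...} from the proposed format.
Prediction = dict[str, Any]


def pair_with_inputs(
    texts: list[str], outputs: list[dict[str, list[Prediction]]]
) -> list[tuple[str, str, float]]:
    """Zip each input text with its top-scoring prediction.

    Assumes outputs[i] belongs to texts[i] and every predictions list is non-empty.
    """
    if len(texts) != len(outputs):
        raise ValueError("expected exactly one output per input")
    paired = []
    for text, output in zip(texts, outputs):
        best = max(output["predictions"], key=lambda p: p["score"])
        paired.append((text, best["label"], best["score"]))
    return paired


# Usage with made-up scores (not real model output).
texts = ["one day I will see the world"]
outputs = [{"predictions": [{"label": "traveling", "score": 0.9},
                            {"label": "cooking", "score": 0.1}]}]
print(pair_with_inputs(texts, outputs))
```

Keeping the input out of the output means the caller decides how to pair things up; zipping is cheap and avoids duplicating data the user already holds.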

> Btw, is there a reason we hid documentation for TokenClassification?

We use defdelegate in the Bumblebee.Text module, and that's where all the docs are :)

jonatanklosko · Dec 16 '22 10:12

@seanmor5 I fixed the post-processing to apply softmax per batch on the entailment label logits. I also adjusted the test assertions to match the output of this:

```python
from transformers import pipeline

# candidate_labels is given once here and reused for every call to p.
p = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", candidate_labels=["cooking", "traveling", "dancing"])

print(p("one day I will see the world"))
```
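A rough Python sketch of the post-processing described above, modeled on what the transformers zero-shot pipeline does in its default single-label mode: each candidate label becomes an NLI hypothesis, the model scores every (premise, hypothesis) pair, and a softmax is taken over the entailment logits of that batch of pairs. The function names, the `This example is {}.` template (the transformers default), and entailment being class index 2 (assumed from the facebook/bart-large-mnli label order contradiction/neutral/entailment) are assumptions for illustration; this is not Bumblebee's actual code, and the logits in the usage lines are invented.

```python
import math


def build_pairs(text, labels, template="This example is {}."):
    """Turn one input and its candidate labels into NLI (premise, hypothesis) pairs."""
    return [(text, template.format(label)) for label in labels]


def zero_shot_predictions(pair_logits, labels, entailment_id=2):
    """Softmax the entailment logits across all candidate labels of one input.

    pair_logits[i] holds the NLI logits for the pair built from labels[i];
    entailment_id is assumed to be 2 (contradiction, neutral, entailment).
    """
    if len(pair_logits) != len(labels):
        raise ValueError("expected one row of logits per candidate label")
    entailment = [row[entailment_id] for row in pair_logits]
    # Subtract the max before exponentiating for numerical stability.
    top = max(entailment)
    exps = [math.exp(x - top) for x in entailment]
    total = sum(exps)
    predictions = [{"label": label, "score": e / total} for label, e in zip(labels, exps)]
    predictions.sort(key=lambda p: p["score"], reverse=True)
    return {"predictions": predictions}


# Usage with invented logits (not real model output).
labels = ["cooking", "traveling", "dancing"]
print(build_pairs("one day I will see the world", labels))
logits = [[0.0, 0.0, 1.0], [0.0, 0.0, 3.0], [0.0, 0.0, 0.5]]
print(zero_shot_predictions(logits, labels))
```

Only the entailment column feeds the softmax, so the scores sum to 1 across the candidate labels of a single input; the contradiction and neutral logits are ignored in this mode.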

jonatanklosko · Dec 22 '22 16:12