german-sentiment-lib

Feature Request: Return confidence of prediction

Open Max-Jesch opened this issue 3 years ago • 2 comments

Pretty awesome project. Thanks a lot for sharing.

One feature that would be really valuable to me is some sort of "confidence" score for the predictions. Do you think that would be tricky to do? I'd be happy to help if you think it makes sense.

Max-Jesch, Oct 28 '21 16:10

I am glad you like it. You can certainly do this. Just modify this line to get the logit values:

https://github.com/oliverguhr/german-sentiment-lib/blob/master/germansentiment/sentimentmodel.py#L32

If you decide to change the code I would appreciate a pull request :)
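
To illustrate how raw logits can be turned into a confidence score, here is a minimal, library-independent sketch using a softmax. The function names `softmax` and `softmax_confidence` and the example logits are hypothetical and are not part of german-sentiment-lib.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Convert raw logits into probabilities that sum to 1."""
    # Subtracting the maximum keeps exp() numerically stable
    # without changing the resulting probabilities.
    highest = max(logits)
    exps = [math.exp(x - highest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_confidence(logits: List[float]) -> Tuple[int, float]:
    """Return (index of the most likely class, its probability)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


# Made-up logits for a three-class model (assumption, not real model output).
example_logits = [2.0, 0.5, -1.0]
best_index, confidence = softmax_confidence(example_logits)
print(best_index, round(confidence, 3))
```

The probability of the top class is one common choice for a "confidence" value, although it is not necessarily well calibrated.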

oliverguhr, Nov 01 '21 20:11

I needed the same thing, so I subclassed SentimentModel and added a predict_sentiment_proba method that returns the class probabilities:

from typing import Dict, List, Tuple

import torch
from germansentiment import SentimentModel


# Named differently from the base class so the import is not shadowed.
class SentimentModelWithProba(SentimentModel):
    def predict_sentiment_proba(self, texts: List[str]) -> Tuple[List[List[float]], Dict[int, str]]:
        """Return per-class probabilities for each text, plus the id-to-label mapping."""
        texts = [self.clean_text(text) for text in texts]
        # add_special_tokens=True adds [CLS], [SEP], <s>, ... tokens in the right way for each model.
        # truncation=True limits the number of tokens to the model's maximum (512).
        encoded = self.tokenizer.batch_encode_plus(texts, padding=True, add_special_tokens=True, truncation=True, return_tensors="pt")
        encoded = encoded.to(self.device)
        with torch.no_grad():
            outputs = self.model(**encoded)

        # outputs[0] holds the logits with shape (batch, num_labels);
        # softmax over dim=1 turns each row into probabilities.
        probabilities = torch.softmax(outputs[0], dim=1)
        return probabilities.tolist(), self.model.config.id2label
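
The method above returns the probabilities and the label mapping separately. As a rough sketch, the two could be combined into one dictionary per text as shown below. The helper name `label_probabilities` and the example values are hypothetical; the real mapping comes from `model.config.id2label`, and its order is only assumed here.

```python
from typing import Dict, List


def label_probabilities(probs: List[List[float]], id2label: Dict[int, str]) -> List[Dict[str, float]]:
    """Turn each row of probabilities into a {label: probability} dict."""
    return [
        {id2label[i]: p for i, p in enumerate(row)}
        for row in probs
    ]


# Hypothetical values for illustration only.
example_id2label = {0: "positive", 1: "negative", 2: "neutral"}
example_probs = [[0.97, 0.02, 0.01]]

labeled = label_probabilities(example_probs, example_id2label)
print(labeled)
```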

bpfrd, Jul 11 '22 07:07

As of version 1.1.0, the API supports this directly:

from germansentiment import SentimentModel

model = SentimentModel()

classes, probabilities = model.predict_sentiment(["das ist super"], output_probabilities=True)
print(classes, probabilities)
# Output:
# ['positive'] [[['positive', 0.9761366844177246], ['negative', 0.023540444672107697], ['neutral', 0.00032294404809363186]]]
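
Given probabilities in the shape printed above (one list of [label, probability] pairs per input text), one possible use is to accept a prediction only when its confidence clears a threshold. This is a minimal sketch; the helper `confident_labels` and the threshold value of 0.9 are assumptions for illustration and are not part of the library.

```python
from typing import List, Optional, Sequence


def confident_labels(probabilities: Sequence[Sequence[Sequence]], threshold: float = 0.9) -> List[Optional[str]]:
    """Return the top label per text, or None if its probability is below threshold."""
    results: List[Optional[str]] = []
    for per_text in probabilities:
        # Each entry is a [label, probability] pair; pick the most probable one.
        label, score = max(per_text, key=lambda pair: pair[1])
        results.append(label if score >= threshold else None)
    return results


# Probabilities copied from the output shown in this thread.
probabilities = [[
    ["positive", 0.9761366844177246],
    ["negative", 0.023540444672107697],
    ["neutral", 0.00032294404809363186],
]]
print(confident_labels(probabilities))
```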

oliverguhr, Oct 10 '22 10:10