flair icon indicating copy to clipboard operation
flair copied to clipboard

[Question]: I got different results when I loaded a trained model to do the inference task on the same test data

Open xiulinyang opened this issue 1 month ago • 5 comments

Question

Hi,

I'm training a SequenceTagger for an NER task on my custom dataset with custom features. After training finished, I got a file named test.tsv, which contains the predictions for the test split. However, when I loaded the trained model (final-model.pt) and ran inference on the same test data, I got much lower results (0.86 vs. 0.64 accuracy).

Here is the prediction function I'm using. I checked the sents list, and all the labels are correctly added to each token. During training, I stacked all the features - should I do the same during prediction? The main issue seems to be that the model does not pick up the added labels: I have a specific label named __TARGET__ that signals which tokens the model should predict on, but the model seems to ignore it. Any suggestions would be greatly appreciated. Thanks!

def predict(self, in_path=None, in_format="conllu", out_format="flair", as_text=False, sent="s"):
    """Tag tab-separated input with ``self.model`` and format the predictions.

    The input is read one token per line, with sentences separated by blank
    lines. Column 0 is used as the token text, columns 1-6 are attached to
    each token as the labels ``upos``, ``deprel``, ``attach``, ``arg1``,
    ``arg2`` and ``arg3``, and the last column is kept as the gold tag.

    Args:
        in_path: Path of the file to read, or the input text itself when
            ``as_text`` is true.
        in_format: Only ``"sgml"`` is special-cased: a line containing the
            closing ``</sent>`` tag then also ends the current sentence.
        out_format: ``"conllu"``, ``"xg"`` or ``"tt"``; any other value
            yields ``token, prediction, gold tag, T/F, score`` rows.
        as_text: If true, ``in_path`` holds the data and the formatted
            output is returned instead of being written to disk.
        sent: Element name of the sentence tag used when ``in_format`` is
            ``"sgml"``.

    Returns:
        The formatted predictions joined by newlines when ``as_text`` is
        true; otherwise ``None`` (the output is written to a
        ``flair-<partition>-pred.<ext>`` file).
    """
    model = self.model
    model.eval()

    tagcol = -1  # gold tag is the last column
    # (label name, column index). The order labels are added in matters:
    # the prediction is later read back positionally as the 7th label.
    feature_columns = (
        ("upos", 1),
        ("deprel", 2),
        ("attach", 3),
        ("arg1", 4),
        ("arg2", 5),
        ("arg3", 6),
    )

    if as_text:
        data = in_path
    else:
        # Context manager so the file handle is closed deterministically.
        with io.open(in_path, encoding="utf8") as f:
            data = f.read()

    # The input is already tokenized: disable flair's tokenizer on newer
    # versions, and fall back to a plain whitespace split on older ones.
    if flair_version > 8:
        tokenizer = False
    else:
        tokenizer = lambda x: x.split(" ")

    closing_tag = "</" + sent + ">"
    sents = []
    true_tags = []  # gold tags for ALL tokens, across sentence boundaries
    words = []  # token texts of the sentence being collected
    features = []  # per-token feature values of the sentence being collected

    data = data.strip() + "\n"  # Ensure final new line for last sentence
    for line in data.split("\n"):
        if len(line.strip()) == 0 or (in_format == "sgml" and closing_tag in line):
            # Sentence boundary: flush the collected tokens, if any.
            if words:
                sentence = Sentence(" ".join(words), use_tokenizer=tokenizer)
                sents.append(sentence)
                for i, token in enumerate(sentence):
                    for (label_name, _), value in zip(feature_columns, features[i]):
                        token.add_label(label_name, value)
                words = []
                features = []
        elif "\t" in line:
            fields = line.split("\t")
            # NOTE(review): this looks like the CoNLL-U ID filter (ellipsis
            # "8.1" / multiword "1-2" rows), but column 0 is used as the token
            # text below, so any token containing "." or "-" is dropped from
            # the input - confirm this matches how the training data was built.
            if "." in fields[0] or "-" in fields[0]:
                continue
            words.append(fields[0])
            true_tags.append(fields[tagcol])
            features.append([fields[col] for _, col in feature_columns])

    # predict tags
    if flair_version > 8:
        model.predict(sents, force_token_predictions=True, return_probabilities_for_all_classes=True)
    else:
        model.predict(sents)

    output = []
    toknum = 0  # running token index into true_tags
    for sent_index, sentence in enumerate(sents):
        tid = 1
        if sent_index > 0 and out_format == "conllu":
            output.append("")  # blank line between sentences
        for tok in sentence.tokens:
            # NOTE(review): the prediction is assumed to be the label right
            # after the six feature labels added above - confirm that this
            # positional lookup holds for the flair version in use.
            labels = tok.labels
            if flair_version > 8:
                has_prediction = len(labels) > 6
                pred = labels[6].value if has_prediction else "O"
                score = labels[6].score if has_prediction else "1.0"
            else:
                pred = labels[6].value
                score = labels[6].score
            tok.clear_embeddings()  # Without this, there will be an OOM issue

            score = str(score)[:5]  # keep at most five characters of the score
            if out_format == "conllu":
                pred = pred if not pred == "O" else "_"
                fields = [str(tid), tok.text, "_", pred, pred, "_", "_", "_", "_", "_"]
                output.append("\t".join(fields))
                tid += 1
            elif out_format == "xg":
                output.append("\t".join([pred, tok.text, score]))
            elif out_format == "tt":
                output.append("\t".join([tok.text, pred]))
            else:
                true_tag = true_tags[toknum]
                corr = "T" if true_tag == pred else "F"
                output.append("\t".join([tok.text, pred, true_tag, corr, score]))
            toknum += 1

    if as_text:
        return "\n".join(output)
    else:
        ext = "gme.conllu" if out_format == "conllu" else "txt"
        partition = "test" if "test" in in_path else "dev"
        with io.open(script_dir + TRAIN_PATH + os.sep + "flair-" + partition + "-pred." + ext, 'w', encoding="utf8", newline="\n") as f:
            f.write("\n".join(output))

xiulinyang avatar Jun 30 '24 03:06 xiulinyang