flair icon indicating copy to clipboard operation
flair copied to clipboard

[Bug]: Custom trained model not correctly loaded when using FlairEmbeddings

Open petermartens1992 opened this issue 3 months ago • 2 comments

Describe the bug

The F-scores reported at the end of training a SequenceTagger model with FlairEmbeddings are much higher than the F-scores obtained when the saved model is loaded from file and evaluated. The predictions of the loaded model are also far worse than the F-scores reported during training would suggest.

The issue doesn't appear when initializing a new model and only loading the state from file:

sequence_tagger.load_state_dict(SequenceTagger.load(model_path).state_dict())

Also, when the FlairEmbeddings are removed, the F-scores from training and from evaluating the loaded model match again.

Here is the code to reproduce the issue (using flair 0.13.1):

import sys
import logging
import os

from typing import List
from pathlib import Path

from flair.datasets import UD_ENGLISH
from flair.data import Corpus
from flair.models import SequenceTagger
from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, FlairEmbeddings, CharacterEmbeddings, OneHotEmbeddings, TransformerWordEmbeddings
from flair.datasets import ColumnCorpus
from flair.trainers import ModelTrainer

# Configure logging and create the module-level logger used by every
# logger.info(...) call in this script. Without an explicit level the root
# logger defaults to WARNING and INFO records would be dropped.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

language = "nl"
model_type = "upos"

# All settings live under the "train" key. The dict is created here because
# nothing earlier in the script defines `configs`.
configs = {
    "train": {
        "language": language,
        "model_type": model_type,
        # Read below to build the output directory name; the key was missing
        # from the original settings. Placeholder value - TODO confirm the
        # intended source name.
        "model_source": "ud",
        "tag_format": "BIOES",
        "rnn_hidden_size": 256,
        "rnn_layers": 1,
        "use_crf": True,
        "word_dropout": 0.2,
        "dropout": 0.2,
        "learning_rate": 0.1,
        "mini_batch_size": 32,
        "mini_batch_chunk_size": 32,
        "train_initial_hidden_state": True,
        "embeddings_storage_ratio": 1.0,
        "fine_tune_flair_embeddings": False,
    },
}

output_model_name = configs["train"]["language"] + "-flair-" + configs["train"]["model_source"] + "-example"

# Expand "~" explicitly: a leading tilde in a plain string (or pathlib.Path)
# is not resolved to the home directory automatically, so the model would
# otherwise be written to a literal "~" directory under the working directory.
output_path = os.path.expanduser(
    "~/models/flair/" + configs["train"]["model_type"] + "/" + output_model_name
)
model_path = output_path + "/best-model.pt"
logger.info("model path: %s", model_path)

# Training pipeline: pick the label type, load a down-sampled corpus, build
# the label dictionary, stack the token embeddings, create the tagger and
# train it for five epochs into `output_path`.
train_cfg = configs["train"]
language_code = train_cfg["language"]
fine_tune_flair = train_cfg["fine_tune_flair_embeddings"]

# Step 1 - the label type to predict is the configured model type.
label_type = train_cfg["model_type"]

# Step 2 - load the corpus, keeping 10% of it.
# NOTE(review): the corpus is UD_ENGLISH while the embeddings below are
# selected with the configured language code - confirm this is intended.
corpus = UD_ENGLISH().downsample(0.1)
print(corpus)

# Step 3 - derive the label dictionary from the corpus.
label_dict = corpus.make_label_dictionary(label_type=label_type)
print(label_dict)

# Step 4 - word, character and forward/backward Flair embeddings, stacked in
# that order.
embedding_types: List[TokenEmbeddings] = [
    WordEmbeddings(f"{language_code}-wiki"),
    CharacterEmbeddings(),
    *(
        FlairEmbeddings(f"{language_code}-{direction}", fine_tune=fine_tune_flair)
        for direction in ("forward", "backward")
    ),
]
embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)

# Step 5 - build the sequence tagger.
# NOTE(review): hidden_size here and learning_rate below are literals even
# though the settings dict also defines "rnn_hidden_size" and "learning_rate".
tagger_kwargs = {
    "hidden_size": 256,
    "embeddings": embeddings,
    "tag_dictionary": label_dict,
    "tag_type": label_type,
}
sequence_tagger = SequenceTagger(**tagger_kwargs)

# Step 6 - create the trainer for this tagger and corpus.
trainer = ModelTrainer(sequence_tagger, corpus)

# Step 7 - run training; batch sizes come from the settings dict.
trainer.train(
    output_path,
    learning_rate=0.1,
    mini_batch_size=train_cfg["mini_batch_size"],
    mini_batch_chunk_size=train_cfg["mini_batch_chunk_size"],
    max_epochs=5,
)

def _evaluate_and_log(tagger):
    """Evaluate *tagger* on the corpus test split and log its results.

    Uses the module-level ``corpus``, ``label_type``, ``configs`` and
    ``logger``. The same evaluation arguments are used for every call so the
    two loading strategies compared below are measured identically.

    Returns the result object produced by ``tagger.evaluate``.
    """
    result = tagger.evaluate(
        data_points=corpus.test,
        gold_label_type=label_type,
        mini_batch_size=configs["train"]["mini_batch_size"],
        mini_batch_chunk_size=configs["train"]["mini_batch_chunk_size"],
        return_loss=False,
    )

    logger.info("detailed results:")
    logger.info(result.detailed_results)
    # The remaining views are logged as strings, each under its own heading.
    for heading, value in (
        ("main score:", result.main_score),
        ("classification report:", result.classification_report),
        ("scores:", result.scores),
    ):
        logger.info(heading)
        logger.info(str(value))
    return result


# Variant A: keep the tagger instance built for training and only copy the
# saved weights into it.
logger.info("evaluating model via load_state_dict...")
sequence_tagger.load_state_dict(SequenceTagger.load(model_path).state_dict())
logger.info(f'Model: "{sequence_tagger}"')
result = _evaluate_and_log(sequence_tagger)

# Variant B: replace the tagger with the instance returned by the loader.
logger.info("evaluating model via load...")
sequence_tagger = SequenceTagger.load(model_path)  # bad results when loading the model in the default way if the model contains flair embeddings
result = _evaluate_and_log(sequence_tagger)

To Reproduce

.

Expected behavior

.

Logs and Stack traces

No response

Screenshots

No response

Additional Context

No response

Environment

.

petermartens1992 avatar May 19 '24 21:05 petermartens1992