[Bug]: optimizer state not saved
Describe the bug
Thank you for developing and maintaining this invaluable module!
We would like to save the state of the optimizer at the end of each epoch.
The save_optimizer_state parameter of the fine_tune function seems to be designed for this purpose.
However, the optimizer state is not saved even when we set save_optimizer_state=True.
Thank you!
To Reproduce
%pip install scipy==1.10.1 datasets transformers torch==2.0 flair==0.13.1
import torch
import flair
from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
# 1. get the corpus
corpus: Corpus = TREC_6()
# 2. what label do we want to predict?
label_type = 'question_class'
# 3. create the label dictionary
label_dict = corpus.make_label_dictionary(label_type=label_type)
# 4. initialize transformer document embeddings (many models are available)
document_embeddings = TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=True)
# 5. create the text classifier
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)
# 6. initialize trainer
trainer = ModelTrainer(classifier, corpus)
# 7. run training with fine-tuning
trainer.fine_tune('resources/taggers/question-classification-with-transformer',
                  learning_rate=5.0e-5,
                  mini_batch_size=4,
                  max_epochs=10,
                  save_optimizer_state=True,
                  save_model_each_k_epochs=1
                  )
# 8. load the checkpoint saved after epoch 1 (written because save_model_each_k_epochs=1)
checkpoint = torch.load('resources/taggers/question-classification-with-transformer/model_epoch_1.pt', map_location=flair.device)
Expected behavior
When save_optimizer_state=True, the saved checkpoint should contain the optimizer's state_dict.
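For illustration, a minimal check of the loaded checkpoint could look like the sketch below; the key name 'optimizer_state_dict' is an assumption made for this example, since the layout of the saved dictionary is not specified here:

import torch
import flair

# Load the per-epoch checkpoint written by save_model_each_k_epochs=1.
checkpoint = torch.load(
    'resources/taggers/question-classification-with-transformer/model_epoch_1.pt',
    map_location=flair.device
)
# Show the top-level keys of the saved dictionary.
print(sorted(checkpoint.keys()))
# Expected (but currently missing): an entry holding the optimizer state.
# The key name 'optimizer_state_dict' is assumed here purely for illustration.
assert 'optimizer_state_dict' in checkpoint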
Logs and Stack traces
No response
Screenshots
No response
Additional Context
No response
Environment
Versions:
Flair: 0.13.1
Pytorch: 2.0.0+cu117
Transformers: 4.40.0
GPU: True
Hi @chelseagzr
thank you for reporting this. Currently, saving/loading the optimizer state is not possible; that flag should have been removed when the trainer was reworked.
I discussed with @alanakbik reintroducing this, but changing it so that not only the optimizer state but also the lr-scheduler and plugin states are stored.
I am thinking of a checkpoint that can be loaded via the trainer (e.g. trainer = ModelTrainer.load_checkpoint("checkpoint.pt")) to restore the training states, while still allowing Classifier.load("checkpoint.pt") to load the model without them.
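Purely as a sketch of that proposed usage (ModelTrainer.load_checkpoint does not exist in flair 0.13.1, and the final name and signature may differ; TextClassifier from the reproduction above stands in for the generic Classifier), the two loading paths might look like this:

from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# Proposed path (not yet implemented): restore model, optimizer, lr-scheduler
# and plugin states from a single checkpoint in order to resume training.
# trainer = ModelTrainer.load_checkpoint("checkpoint.pt")
# trainer.fine_tune("resources/taggers/question-classification-with-transformer", max_epochs=10)

# Existing path: load only the model weights for inference, without any training states.
classifier = TextClassifier.load("checkpoint.pt")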
Thank you for the timely response! It would be great if saving the trainer state could be enabled!