[Bug]: optimizer state not saved

Open chelseagzr opened this issue 1 year ago • 2 comments

Describe the bug

Thank you for developing and maintaining this invaluable module!

We would like to save the state of the optimizer at the end of each epoch. The save_optimizer_state parameter of the fine_tune function seems to be designed for this purpose. However, the state of the optimizer is not saved even if we set save_optimizer_state=True.

Thank you!

To Reproduce

%pip install scipy==1.10.1 datasets transformers torch==2.0 flair==0.13.1 

import torch
import flair
from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# 1. get the corpus
corpus: Corpus = TREC_6()

# 2. what label do we want to predict?
label_type = 'question_class'

# 3. create the label dictionary
label_dict = corpus.make_label_dictionary(label_type=label_type)

# 4. initialize transformer document embeddings (many models are available)
document_embeddings = TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=True)

# 5. create the text classifier
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)

# 6. initialize trainer
trainer = ModelTrainer(classifier, corpus)

# 7. run training with fine-tuning
trainer.fine_tune('resources/taggers/question-classification-with-transformer',
                  learning_rate=5.0e-5,
                  mini_batch_size=4,
                  max_epochs=10,
                  save_optimizer_state=True,
                  save_model_each_k_epochs=1
)

checkpoint = torch.load('resources/taggers/question-classification-with-transformer/model_epoch_1.pt', map_location=flair.device)

Expected behavior

When save_optimizer_state=True, the saved checkpoint should contain the optimizer's state_dict.
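One way to check this on the object returned by torch.load in the reproduction is to list the top-level keys that look optimizer-related. This is only a rough sketch: the helper name optimizer_related_keys is hypothetical, and it assumes the checkpoint is a plain dict whose optimizer or scheduler entries carry "optimizer" or "scheduler" in the key name, which the issue does not confirm for Flair checkpoints.

```python
def optimizer_related_keys(checkpoint):
    """Return top-level keys whose name mentions an optimizer or scheduler.

    Assumption: `checkpoint` is the object returned by torch.load(...) and
    is a plain dict; any other type yields an empty list.
    """
    if not isinstance(checkpoint, dict):
        return []
    return [
        key
        for key in checkpoint
        if isinstance(key, str)
        and ("optimizer" in key.lower() or "scheduler" in key.lower())
    ]
```

Calling optimizer_related_keys(checkpoint) after the torch.load line and getting an empty list would be consistent with the behavior reported here, although an empty result could also just mean the state lives under a differently named key.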

Logs and Stack traces

No response

Screenshots

No response

Additional Context

No response

Environment

Versions:

Flair

0.13.1

Pytorch

2.0.0+cu117

Transformers

4.40.0

GPU

True

chelseagzr avatar Apr 19 '24 08:04 chelseagzr

Hi @chelseagzr

Thank you for reporting. Saving and loading the optimizer state is currently not possible, and that flag should have been removed when the trainer was reworked.

I discussed reintroducing this with @alanakbik, but changing it to store not only the optimizer state but also the lr-scheduler and plugin states. I am thinking of a state that can be loaded via the trainer (e.g. trainer = ModelTrainer.load_checkpoint("checkpoint.pt")) to restore the training states, while still allowing Classifier.load("checkpoint.pt") to load the model without the training states.
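Until something like that exists, the general idea of storing optimizer and lr-scheduler state next to the model weights can be sketched in plain PyTorch. This is not Flair's API and not the proposed ModelTrainer.load_checkpoint; the function names, the dictionary keys, and the toy nn.Linear model are all assumptions made for illustration, and plugin states are not covered.

```python
import os
import tempfile

import torch
from torch import nn


def save_training_state(path, model, optimizer, scheduler, epoch):
    """Write model, optimizer and scheduler state into one file (hypothetical layout)."""
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        },
        path,
    )


def load_training_state(path, model, optimizer, scheduler):
    """Restore all three states in place and return the saved epoch."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    scheduler.load_state_dict(state["scheduler_state_dict"])
    return state["epoch"]


def build():
    """Create a toy model with an optimizer and scheduler (stand-ins, not Flair objects)."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler


def roundtrip_demo():
    """Train one step, save, then restore into freshly built objects."""
    model, optimizer, scheduler = build()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # halves the learning rate after the first epoch

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "training_state.pt")
        save_training_state(path, model, optimizer, scheduler, epoch=1)

        new_model, new_optimizer, new_scheduler = build()
        epoch = load_training_state(path, new_model, new_optimizer, new_scheduler)

    restored_lr = new_optimizer.param_groups[0]["lr"]
    n_param_states = len(new_optimizer.state_dict()["state"])
    return epoch, restored_lr, n_param_states
```

Keeping the model weights under their own key means a loader that only needs the model can read that entry and ignore the rest, which is in the spirit of the split between trainer loading and Classifier.load described above.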

helpmefindaname avatar May 03 '24 15:05 helpmefindaname

Thank you for the timely response! It would be great if saving the state of a trainer could be enabled!

chelseagzr avatar May 06 '24 05:05 chelseagzr