Adding a Deep Nearest Class Means Classifier model to Flair
This PR adds a DeepNCMClassifier to flair.models.
My reasons for adding this model are outlined in issue #3531.
This model requires a TrainerPlugin because it performs the prototype updates in an after_training_batch hook. Please let me know if there is a cleaner way to handle this.
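For readers unfamiliar with DeepNCM, here is a minimal NumPy sketch of the idea, assuming a nearest-class-mean head scored by negative squared Euclidean distance. The class and method names (NearestClassMeanSketch, update_prototypes) and the two update rules ("online" running mean and "decay" exponential moving average) are hypothetical illustrations, not the code in this PR. The point it shows is that forward() only accumulates per-class sums, while the prototypes change in a separate update_prototypes() call, which is the kind of step an after_training_batch hook would trigger.

```python
from typing import Optional

import numpy as np


class NearestClassMeanSketch:
    """Hypothetical nearest-class-mean head (not the PR's DeepNCMClassifier)."""

    def __init__(self, num_classes: int, dim: int, method: str = "online", alpha: float = 0.9) -> None:
        if method not in ("online", "decay"):
            raise ValueError(f"unknown update method: {method}")
        self.method = method
        self.alpha = alpha  # only used by "decay"
        self.prototypes = np.zeros((num_classes, dim))
        self.counts = np.zeros(num_classes)  # examples absorbed per class ("online")
        self._sums = np.zeros((num_classes, dim))  # pending per-class sums
        self._pending = np.zeros(num_classes)  # pending per-class counts

    def forward(self, embeddings: np.ndarray, labels: Optional[np.ndarray] = None) -> np.ndarray:
        if labels is not None:
            # Training: remember this batch's embeddings per class,
            # but do not touch the prototypes yet.
            np.add.at(self._sums, labels, embeddings)
            np.add.at(self._pending, labels, 1)
        # Score = negative squared Euclidean distance to each prototype.
        diffs = embeddings[:, None, :] - self.prototypes[None, :, :]
        return -(diffs ** 2).sum(axis=-1)

    def update_prototypes(self) -> None:
        """Apply pending updates; meant to run after the optimizer step."""
        seen = self._pending > 0
        sums, pending = self._sums[seen], self._pending[seen]
        if self.method == "online":
            # Running mean over every example seen so far.
            total = self.counts[seen] + pending
            self.prototypes[seen] = (self.counts[seen][:, None] * self.prototypes[seen] + sums) / total[:, None]
            self.counts[seen] = total
        else:
            # Exponential moving average toward the current batch mean.
            batch_means = sums / pending[:, None]
            self.prototypes[seen] = self.alpha * self.prototypes[seen] + (1 - self.alpha) * batch_means
        self._sums[:] = 0.0
        self._pending[:] = 0.0
```

In a training loop, update_prototypes() would be called once per batch after the optimizer step. One plausible reason for doing it there rather than inside forward() is that the class means are not learned by gradient descent, so they are changed outside the backward pass.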
Example Script:
from flair.data import Corpus
from flair.datasets import TREC_50
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import DeepNCMClassifier
from flair.trainers import ModelTrainer
from flair.trainers.plugins import DeepNCMPlugin
# load the TREC dataset
corpus: Corpus = TREC_50()
# make a transformer document embedding
document_embeddings = TransformerDocumentEmbeddings("roberta-base", fine_tune=True)
# create the classifier
classifier = DeepNCMClassifier(
    document_embeddings,
    label_dictionary=corpus.make_label_dictionary(label_type="class"),
    label_type="class",
    use_encoder=False,
    mean_update_method="condensation",
)
# initialize the trainer
trainer = ModelTrainer(classifier, corpus)
# train the model
trainer.fine_tune(
    "resources/taggers/deepncm_trec",
    plugins=[DeepNCMPlugin()],
)
Hello @sheldon-roberts,
Thanks a lot for your contribution! This had been buried deep in the backlog of things to implement.
I also don't see a way to implement this without a TrainerPlugin.
What do you think about implementing this as a decoder (such as the PrototypicalDecoder), so that it can be used with the default classifier? Then it could be used with all model types (e.g., span and text classification).
Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?
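To illustrate what supporting several distance functions could look like, here is a small NumPy sketch that scores embeddings against class prototypes. The function name prototype_scores and the option strings are assumptions for illustration and are not necessarily the names PrototypicalDecoder accepts.

```python
import numpy as np


def prototype_scores(embeddings: np.ndarray, prototypes: np.ndarray, distance: str = "euclidean") -> np.ndarray:
    """Return a (batch, num_classes) score matrix where higher means closer.

    The option names are illustrative, not PrototypicalDecoder's exact API.
    """
    if distance == "euclidean":
        # Negated Euclidean distance, so the nearest prototype gets the top score.
        diffs = embeddings[:, None, :] - prototypes[None, :, :]
        return -np.sqrt((diffs ** 2).sum(axis=-1))
    if distance == "cosine":
        # Cosine similarity: compares directions only, ignoring vector norms.
        e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        p = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
        return e @ p.T
    if distance == "dot_product":
        # Raw dot product: grows with prototype norm as well as alignment.
        return embeddings @ prototypes.T
    raise ValueError(f"unknown distance function: {distance}")
```

The choice is not cosmetic: dot-product scores grow with the prototype's norm, so they can prefer a large prototype that is far away in Euclidean terms, while cosine similarity discards the norm and keeps only direction.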
Hi @plonerma, thanks for taking a look!
I really like both of these ideas! I will look into making these changes soon.
To avoid using a trainer plugin, could we add a hook method such as def after_training_epoch(self): pass to the base Model class, and call it right before or after self.dispatch("after_training_epoch", epoch=epoch) in the train_custom function?
I think this would work while DeepNCM is a standalone model class, but it might not work once it is converted to a decoder.
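A minimal sketch of this proposal, using hypothetical stand-in classes (Model, PrototypeModel, Trainer) rather than Flair's real ones: the base class gets a no-op after_training_epoch method, and the training loop calls it right before the existing plugin dispatch.

```python
from typing import Callable, Dict, List


class Model:
    """Hypothetical stand-in for Flair's base Model class."""

    def after_training_epoch(self, epoch: int) -> None:
        # Default hook does nothing, so existing models are unaffected.
        pass


class PrototypeModel(Model):
    """Hypothetical model that needs end-of-epoch work (e.g. a prototype update)."""

    def __init__(self) -> None:
        self.updated_epochs: List[int] = []

    def after_training_epoch(self, epoch: int) -> None:
        # A real model would update its prototypes here; we just record the call.
        self.updated_epochs.append(epoch)


class Trainer:
    """Hypothetical trainer with a plugin event bus similar in spirit to dispatch()."""

    def __init__(self, model: Model) -> None:
        self.model = model
        self._handlers: Dict[str, List[Callable[..., None]]] = {}

    def register(self, event: str, handler: Callable[..., None]) -> None:
        self._handlers.setdefault(event, []).append(handler)

    def dispatch(self, event: str, **kwargs) -> None:
        for handler in self._handlers.get(event, []):
            handler(**kwargs)

    def train_custom(self, max_epochs: int) -> None:
        for epoch in range(1, max_epochs + 1):
            # ... training batches for this epoch would run here ...
            self.model.after_training_epoch(epoch)  # proposed direct model hook
            self.dispatch("after_training_epoch", epoch=epoch)  # existing plugin event
```

If the prototype logic lives in a decoder instead, the classifier's override would have to forward this call to its decoder, which is the limitation mentioned above.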
I am currently working on converting this class into a simpler decoder. I have it working, but it requires changes to other classes: the label tensors have to be passed through the forward pass so they can reach the decoder call. Specifically, in DefaultClassifier.forward_loss, the call has to become scores = self.decoder(data_point_tensor, label_tensor). In predict this isn't necessary, because no prototype updates need to be calculated.
Would it make sense to always pass this in, and have most decoders simply ignore the parameter? Another alternative would be for the classifier to set self.label_tensor before the call, so that it doesn't need to be an input parameter at all. I'm not sure if anyone else has a suggestion for how to design this. I will push the specific code soon, but I'm looking for opinions first.
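The first option (always passing labels and letting most decoders ignore them) might look like the sketch below. Everything here is a hypothetical NumPy stand-in: LinearDecoder, LabelAwareDecoder, forward_loss_scores and predict_scores are illustrative names, not Flair's API, and the prototype update itself is reduced to queuing labels.

```python
from typing import List, Optional

import numpy as np


class LinearDecoder:
    """Hypothetical ordinary decoder: accepts label_tensor but ignores it."""

    def __init__(self, weights: np.ndarray) -> None:
        self.weights = weights

    def __call__(self, data_point_tensor: np.ndarray, label_tensor: Optional[np.ndarray] = None) -> np.ndarray:
        return data_point_tensor @ self.weights


class LabelAwareDecoder:
    """Hypothetical prototype decoder that needs labels only during training."""

    def __init__(self, prototypes: np.ndarray) -> None:
        self.prototypes = prototypes
        self.pending_labels: List[int] = []  # stand-in for queued prototype updates

    def __call__(self, data_point_tensor: np.ndarray, label_tensor: Optional[np.ndarray] = None) -> np.ndarray:
        if label_tensor is not None:
            # Training path: queue labels for a later prototype update.
            self.pending_labels.extend(int(label) for label in label_tensor)
        # Score = negative squared Euclidean distance to each prototype.
        diffs = data_point_tensor[:, None, :] - self.prototypes[None, :, :]
        return -(diffs ** 2).sum(axis=-1)


def forward_loss_scores(decoder, data_point_tensor: np.ndarray, label_tensor: np.ndarray) -> np.ndarray:
    # Option 1: the classifier always passes labels during training.
    return decoder(data_point_tensor, label_tensor)


def predict_scores(decoder, data_point_tensor: np.ndarray) -> np.ndarray:
    # At prediction time no prototype updates are needed, so labels are omitted.
    return decoder(data_point_tensor)
```

The second option (setting self.label_tensor on the decoder before calling it) avoids changing the call signature, but it relies on hidden state that has to be set and cleared at the right moments; the explicit optional argument keeps the dependency visible at the call site.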
This has been updated to be a decoder. It's a lot less code and simpler overall, although it required some small changes to the DefaultClassifier class and still requires a plugin. I'm definitely open to suggestions on how to integrate this better.
Tests are passing except for a couple of MyPy checks that aren't directly related to the changes in this PR; I think they are just in files that this PR touches. Do you have any suggestions for fixing these typing problems?
Would it be better to move this class into flair/nn/decoder.py now that it is a decoder?
@plonerma Are you able to re-review this before the next release?
I've moved this to decoder.py to be more consistent with the other decoders.
@sheldon-roberts and @MattGPT-ai : Thanks a lot for your collaborative effort!
I made a few minor changes and merged the current master branch into the PR. Now all checks pass.
This looks good now, can we merge?
Thanks a lot for adding this @MattGPT-ai and for reviewing @plonerma!