Adding a Deep Nearest Class Means Classifier model to Flair

Pull request opened by sheldon-roberts.

This PR adds a DeepNCMClassifier to flair.models. My reasons for adding this model are outlined in issue #3531.

This model requires a TrainerPlugin because it updates its prototypes in an after_training_batch hook. Please let me know if there is a cleaner way to handle this.

Example Script:

```python
from flair.data import Corpus
from flair.datasets import TREC_50
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import DeepNCMClassifier
from flair.trainers import ModelTrainer
from flair.trainers.plugins import DeepNCMPlugin

# load the TREC dataset
corpus: Corpus = TREC_50()

# make a transformer document embedding
document_embeddings = TransformerDocumentEmbeddings("roberta-base", fine_tune=True)

# create the classifier
classifier = DeepNCMClassifier(
    document_embeddings,
    label_dictionary=corpus.make_label_dictionary(label_type="class"),
    label_type="class",
    use_encoder=False,
    mean_update_method="condensation",
)

# initialize the trainer
trainer = ModelTrainer(classifier, corpus)

# train the model
trainer.fine_tune(
    "resources/taggers/deepncm_trec",
    plugins=[DeepNCMPlugin()],
)
```

sheldon-roberts, Aug 19, 2024
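Why the prototype update ends up in a plugin can be illustrated with a small, framework-free sketch. It assumes a simplified trainer that dispatches named events to registered callbacks: the model only buffers statistics while computing its loss, and a plugin applies the buffered update after each batch. MiniTrainer, PrototypeModel and PrototypeUpdatePlugin are hypothetical stand-ins, not Flair's ModelTrainer or TrainerPlugin API; only the hook name after_training_batch comes from the PR text, and the running-mean rule is just one possible update.

```python
from collections import defaultdict
from typing import Callable, Dict, List


class PrototypeModel:
    """Hypothetical model that only buffers batch statistics during forward_loss."""

    def __init__(self) -> None:
        self.prototype = 0.0
        self.num_updates = 0
        self._pending: List[float] = []

    def forward_loss(self, batch: List[float]) -> None:
        # Record what the update needs; the prototype itself is left unchanged here.
        self._pending.extend(batch)

    def update_prototypes(self) -> None:
        if not self._pending:
            return
        self.num_updates += 1
        batch_mean = sum(self._pending) / len(self._pending)
        # Running mean of the per-batch means (one of several possible update rules).
        self.prototype += (batch_mean - self.prototype) / self.num_updates
        self._pending.clear()


class MiniTrainer:
    """Hypothetical trainer that dispatches named events to registered callbacks."""

    def __init__(self, model: PrototypeModel) -> None:
        self.model = model
        self._hooks: Dict[str, List[Callable[..., None]]] = defaultdict(list)

    def register(self, event: str, callback: Callable[..., None]) -> None:
        self._hooks[event].append(callback)

    def dispatch(self, event: str, **kwargs) -> None:
        for callback in self._hooks[event]:
            callback(**kwargs)

    def train(self, batches: List[List[float]]) -> None:
        for batch in batches:
            self.model.forward_loss(batch)  # backward pass and optimizer step omitted
            self.dispatch("after_training_batch")


class PrototypeUpdatePlugin:
    """Hypothetical plugin: applies the buffered prototype update after every batch."""

    def attach_to(self, trainer: MiniTrainer) -> None:
        trainer.register("after_training_batch", lambda **_: trainer.model.update_prototypes())


model = PrototypeModel()
trainer = MiniTrainer(model)
PrototypeUpdatePlugin().attach_to(trainer)
trainer.train([[1.0, 3.0], [5.0, 7.0]])
print(model.prototype)  # 4.0
```

In this sketch the training loop itself knows nothing about prototypes; without the registered callback, the buffered statistics would never be applied.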

Hello @sheldon-roberts,

Thanks a lot for your contribution! This had been buried deep in the backlog of things to implement.

I also don't see how this could be implemented without a TrainerPlugin.

What do you think about implementing this as a decoder (like the PrototypicalDecoder), so that it can be used with the DefaultClassifier? Then it could be used with all model types (e.g. span and text classification).

Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?

plonerma, Aug 19, 2024
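For context on the distance-function suggestion: a nearest-class-means classifier keeps one prototype (mean embedding) per class and scores each input by its similarity to every prototype. The PyTorch sketch below shows scoring with a few interchangeable distance functions. The function name ncm_scores and the option names are assumptions for illustration; they are not necessarily the options offered by the PrototypicalDecoder or by the decoder that was eventually merged.

```python
import torch


def ncm_scores(embeddings: torch.Tensor, prototypes: torch.Tensor, distance: str = "euclidean") -> torch.Tensor:
    """Return a (batch, num_classes) score matrix; higher means closer to that class mean."""
    if distance == "euclidean":
        # Negative squared Euclidean distance, so the nearest mean gets the highest score.
        return -torch.cdist(embeddings, prototypes).pow(2)
    if distance == "cosine":
        # Broadcast (batch, 1, dim) against (1, num_classes, dim).
        return torch.nn.functional.cosine_similarity(
            embeddings.unsqueeze(1), prototypes.unsqueeze(0), dim=-1
        )
    if distance == "dot_product":
        return embeddings @ prototypes.T
    raise ValueError(f"unknown distance: {distance}")


prototypes = torch.tensor([[0.0, 0.0], [4.0, 4.0]])  # one mean per class
embeddings = torch.tensor([[0.5, -0.5], [3.0, 5.0]])
scores = ncm_scores(embeddings, prototypes)
predicted = scores.argmax(dim=-1)  # index of the nearest class mean
```

Because every option returns a score matrix of the same shape, the scores can feed the usual cross-entropy loss regardless of which distance is chosen.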

Hi @plonerma, thanks for taking a look!

I really like both of these ideas: implementing this as a decoder and supporting the different distance functions. I will look into making these changes soon.

sheldon-roberts, Aug 20, 2024

To avoid using a trainer plugin, could we add a no-op method like def after_training_epoch(self): pass to the base Model class, and call it right before or after self.dispatch("after_training_epoch", epoch=epoch) in the train_custom function?

I think this would work while this is a standalone classifier class, but it might not work once it is changed to a decoder.

MattGPT-ai, Aug 28, 2024
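The hook-based alternative proposed above can be sketched as follows, assuming a simplified training loop: the base class defines a no-op after_training_epoch method, the loop calls it next to its dispatch call, and only models that need it override the method. BaseModel, PlainModel, NCMStyleModel and this train_custom signature are hypothetical stand-ins, not Flair's classes or its actual train_custom function.

```python
from typing import Callable, List


class BaseModel:
    """Hypothetical base class with a no-op hook that subclasses may override."""

    def after_training_epoch(self) -> None:
        pass


class PlainModel(BaseModel):
    """Does not override the hook, so the extra call in the loop does nothing."""


class NCMStyleModel(BaseModel):
    """Hypothetical model that refreshes its class means once per epoch."""

    def __init__(self) -> None:
        self.refreshes = 0

    def after_training_epoch(self) -> None:
        self.refreshes += 1  # stand-in for recomputing the prototypes


def train_custom(model: BaseModel, epochs: int, dispatch: Callable[..., None]) -> None:
    for epoch in range(1, epochs + 1):
        # ... training batches would run here ...
        model.after_training_epoch()  # called directly on the model, no plugin required
        dispatch("after_training_epoch", epoch=epoch)  # plugins still receive the event


events: List[str] = []
model = NCMStyleModel()
train_custom(model, epochs=3, dispatch=lambda name, **kw: events.append(f"{name}:{kw['epoch']}"))
train_custom(PlainModel(), epochs=1, dispatch=lambda name, **kw: None)
```

In this sketch the hook lives on the model itself, which is why the comment expects trouble after the switch to a decoder: the model would then have to forward the call to its decoder explicitly.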

I am currently working on converting this class into a simpler decoder. I have gotten it to work, but it requires some changes to other classes: the label tensors have to be passed through the forward passes so they can go into the decoder call. Specifically, DefaultClassifier.forward_loss needs to call scores = self.decoder(data_point_tensor, label_tensor). In predict, this isn't necessary because the prototype updates don't need to be calculated there.

Would it make sense to always pass this in, and just have most decoders ignore the parameter? Another alternative would be for the class to set self.label_tensor before the call, so that it doesn't need to be an input parameter at all. I'm not sure whether anyone has a suggestion for how to design this. I will push the specific code soon, but I'm looking for opinions first.

MattGPT-ai, Nov 20, 2024
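The first design option from the comment above can be sketched in PyTorch, assuming every decoder's forward accepts an optional label_tensor: ordinary decoders ignore it, while a prototype-based decoder uses it during training to accumulate per-class sums that are applied later (for example from an after_training_batch plugin). LinearDecoder, NCMDecoder, update_prototypes and the cumulative-mean update are illustrative assumptions, not the merged Flair decoder's code, and the sketch does not reproduce the "condensation" update method used in the example script.

```python
from typing import Optional

import torch


class LinearDecoder(torch.nn.Module):
    """An ordinary decoder: accepts label_tensor for a uniform call signature, then ignores it."""

    def __init__(self, dim: int, num_classes: int) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.linear(x)


class NCMDecoder(torch.nn.Module):
    """Scores inputs by negative squared distance to per-class mean embeddings."""

    def __init__(self, dim: int, num_classes: int) -> None:
        super().__init__()
        # Buffers, not parameters: the means are updated by hand, not by the optimizer.
        self.register_buffer("prototypes", torch.zeros(num_classes, dim))
        self.register_buffer("counts", torch.zeros(num_classes))
        self.register_buffer("pending_sums", torch.zeros(num_classes, dim))
        self.register_buffer("pending_counts", torch.zeros(num_classes))

    def forward(self, x: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
        if label_tensor is not None and self.training:
            # Accumulate per-class sums of the detached embeddings for a later update.
            self.pending_sums.index_add_(0, label_tensor, x.detach())
            self.pending_counts.index_add_(0, label_tensor, torch.ones_like(label_tensor, dtype=x.dtype))
        return -torch.cdist(x, self.prototypes).pow(2)

    @torch.no_grad()
    def update_prototypes(self) -> None:
        # Cumulative mean over every example seen so far (one of several possible rules).
        seen = self.pending_counts > 0
        new_counts = self.counts + self.pending_counts
        self.prototypes[seen] = (
            self.prototypes[seen] * self.counts[seen].unsqueeze(1) + self.pending_sums[seen]
        ) / new_counts[seen].unsqueeze(1)
        self.counts.copy_(new_counts)
        self.pending_sums.zero_()
        self.pending_counts.zero_()


x = torch.tensor([[1.0, 1.0], [3.0, 3.0], [10.0, 0.0]])
labels = torch.tensor([0, 0, 1])

linear_scores = LinearDecoder(2, 2)(x, labels)  # same call shape, labels unused

decoder = NCMDecoder(dim=2, num_classes=2)
decoder(x, labels)           # training-time call: scores plus buffered statistics
decoder.update_prototypes()  # e.g. triggered from an after_training_batch hook
decoder.eval()
preds = decoder(torch.tensor([[2.5, 2.0], [9.0, 1.0]])).argmax(dim=-1)  # prediction: no labels
```

The second option from the comment, setting an attribute such as self.label_tensor before calling the decoder, would leave every other decoder's signature untouched, at the cost of hidden state that has to be reset between calls.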

This has been updated to be a decoder. Overall it is a lot less code and simpler, although it required some small changes to the DefaultClassifier class and still requires a plugin. I am definitely open to any suggestions on how to integrate this better.

MattGPT-ai, Nov 24, 2024

It looks like the tests pass except for a couple of MyPy checks that aren't directly related to the changes in this PR; I think they are just in files that this PR touches. Do you have any suggestions for fixing these typing problems?

MattGPT-ai, Nov 24, 2024

Would it be better to move this class into flair/nn/decoder.py now that it is a decoder?

MattGPT-ai, Nov 25, 2024

@plonerma Are you able to re-review this before the next release?

MattGPT-ai, Dec 17, 2024

I've moved this to decoder.py to be consistent with the other decoders.

MattGPT-ai, Dec 18, 2024

@sheldon-roberts and @MattGPT-ai : Thanks a lot for your collaborative effort!

I made a few minor changes and merged the current master branch into the PR. Now all checks pass.

plonerma, Jan 3, 2025

This looks good now, can we merge?

MattGPT-ai, Jan 3, 2025

Thanks a lot for adding this @MattGPT-ai and for reviewing @plonerma!

alanakbik, Jan 7, 2025