torchtune
Transformer classifier
Context/Changelog
See https://github.com/pytorch/torchtune/issues/812
This PR adds a TransformerClassifier layer which extends the TransformerDecoder functionality to classification tasks. Exemplar component and model builders have been implemented for the base mistral model.
Test plan
Testing this was tricky as there is currently no reference implementation for base Mistral to test against. I performed some numerical testing against the HuggingFace MistralModel and AutoModelForSequenceClassification with mistralai/Mistral-7B-v0.1, using the same parameters found in test_llama3.py. However, neither the classification model outputs nor the base-model outputs (torchtune.models.mistral.mistral vs. MistralModel) matched the HF models. I've left in a sort-of dummy test which just asserts that the output shapes are correct. I can probably test the sequence pooling and classification independently once we agree on how they integrate into the codebase.
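For context, a minimal sketch of the kind of shape-only check described above, assuming the mistral_classifier component builder added in this PR; the import path and keyword arguments are illustrative guesses rather than the exact builder signature, and only the batch and class dimensions are asserted since where the last-token pooling happens is still under discussion below.

import torch
from torchtune.models.mistral import mistral_classifier  # import path assumed

def test_classifier_output_shape():
    # small, arbitrary config for a fast shape-only check
    batch_size, seq_len, num_classes, vocab_size = 2, 16, 6, 100
    model = mistral_classifier(  # keyword arguments are assumptions for illustration
        num_classes=num_classes,
        vocab_size=vocab_size,
        num_layers=2,
        num_heads=4,
        num_kv_heads=4,
        embed_dim=64,
        intermediate_dim=128,
        max_seq_len=seq_len,
    )
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
    logits = model(tokens)
    # only assert the batch and class dimensions; the sequence dimension
    # depends on whether/where last-token pooling is applied
    assert logits.shape[0] == batch_size
    assert logits.shape[-1] == num_classes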
Questions/Next steps
This is part of a broader plan to implement RLHF in torchtune. The TransformerClassifier can hopefully be used with any sequence model we have. @kartikayk - we could implement a recipe for training a Mistral reward model using this. I can start implementing a PPO recipe using this reward model, too.
Thanks @SalmanMohammadi for this PR! This looks mostly reasonable but I have a few high level suggestions.
The new functionality you're adding is mainly using the output layer of TransformerDecoder to project to the number of classes, and then extracting the last predicted token. Why not just make a separate classifier head class that consists of this linear projection and the post-processing step in its forward? Then you can compose these together directly in the TransformerDecoder constructor in mistral_classifier:
TransformerDecoder(
tok_embeddings=tok_embeddings,
layer=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=head_dim,
norm=norm,
output=ClassificationHead(embed_dim, num_classes),
)
This still keeps the flexibility of swapping out the classification head with some other head, and you keep all the added logic contained in a new class. Your testing will be simpler too because you only need to test the head and not the entire transformer classifier.
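For illustration, a minimal sketch of what such a head could look like, assuming the ClassificationHead(embed_dim, num_classes) signature used above; this is not an existing torchtune module, and it only covers the linear projection, since TransformerDecoder passes just the hidden states to its output layer - which is exactly the issue raised in the follow-up below.

import torch
from torch import nn

class ClassificationHead(nn.Module):
    """Hypothetical head: project hidden states to per-token class logits."""

    def __init__(self, embed_dim: int, num_classes: int):
        super().__init__()
        self.proj = nn.Linear(embed_dim, num_classes)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [batch_size, seq_len, embed_dim] -> [batch_size, seq_len, num_classes]
        return self.proj(h)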
Also, please make sure you've run the linters with pre-commit run --all-files :)
Thanks @RdoubleA :) I've fixed the linting.
I tried setting it up this way, but in the forward for TransformerClassifier I also use the input token ids to grab the last non-padding token in each sequence of the batch:
padding_mask = tokens == 0  # assumes 0 is the padding token id
if padding_mask.any():
    sequence_lengths = (
        (padding_mask.logical_not().cumsum(-1) == 1).sum(-1).to(logits.device)
    )
else:
    # no padding anywhere in the batch: just index the last position
    sequence_lengths = -1
I thought I'd have to modify the function call in TransformerDecoder's forward:
h = self.norm(h)
output = self.output(h).float() # pass input tokens in here
return output
to achieve this, so I wrapped around TransformerDecoder instead.
I agree your suggestion (and @ebsmothers's on Discord) is cleaner. Do you have any thoughts on how I'd still be able to use the input tokens in the output callable?
I tried setting it up this way, but in the forward for TransformerClassifier I also use the input token ids to grab the last non-padding token in each sequence of the batch:
@SalmanMohammadi ah, I wasn't aware of this when I made my suggestion on Discord. In that case it is trickier, because you are actually changing the signature of the output layer. Then at a high level I think it makes sense to add a separate module to handle taking the last token. I'll take a closer look at the exact implementation now.
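As a rough sketch of that separate-module idea, assuming right padding and a padding id of 0 (the function name and details below are illustrative, not the implementation that ended up in the PR):

import torch

def pool_last_non_padding_token(
    logits: torch.Tensor, tokens: torch.Tensor, padding_idx: int = 0
) -> torch.Tensor:
    """Gather the logits at the last non-padding position of each sequence."""
    # number of non-padding tokens per sequence -> index of the last one
    last_idx = (tokens != padding_idx).sum(dim=-1).clamp(min=1) - 1
    batch_idx = torch.arange(logits.size(0), device=logits.device)
    # logits: [batch_size, seq_len, num_classes] -> [batch_size, num_classes]
    return logits[batch_idx, last_idx]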
Note, the type hints are wrong for the component and model builders (the docs are correct). I think you're right that my pre-commits still aren't working right. I'll fix those and address any comments in a bit.
I've updated my PR @ebsmothers with the changes we discussed :)
OK the changes look good, but the unit test is failing on CI. Does it pass for you locally? I can also help a bit here with debugging if you need
Yeah it's passing locally. It looks like it's failing the second test case. Are there more detailed logs? I sometimes fail some of the tests on my mac because the precision tolerance isn't the same on my machine for some reason (i.e. the numbers look right, but the errors are just slightly above the tolerance).
OK I think I cracked the case here. The lack of detailed logs can often be an indication that the runner actually just crashed (e.g. due to running out of memory or something like that). It crashed on the second test case, which is the largest one. You can either reduce the batch size or even just remove that test case, since it's not actually testing fundamentally different logic than the other cases. I tested on my fork with a batch size of 4 and confirmed that the CI succeeds (but tbh I'd just scrap the test case for the reason I mentioned).
Thanks so much for your help debugging :) I'll keep that in mind for the future!
Good catch!
Thank you for your patience!!