
Transformer classifier

Open · SalmanMohammadi opened this pull request 10 months ago · 10 comments

Context/Changelog

See https://github.com/pytorch/torchtune/issues/812

This PR adds a TransformerClassifier layer which extends TransformerDecoder functionality to classification tasks. Exemplar component and model builders have been implemented for the base Mistral model.

Test plan

Testing this was tricky, as there is currently no reference implementation for base Mistral to test against. I performed some numerical testing against the HuggingFace MistralModel and AutoModelForSequenceClassification with mistralai/Mistral-7B-v0.1, using the same parameters found in test_llama3.py. However, neither comparison matched: the classification model's outputs differed from the HF classification model, and torchtune.models.mistral.mistral's outputs differed from the base MistralModel. I've left in a sort-of dummy test which just asserts that the output shapes are correct. I can probably test the sequence pooling and classification independently once we agree on how they integrate into the codebase.
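Roughly, the kind of parity check I mean is sketched below (the config values are illustrative rather than the exact test_llama3.py parameters, and the weight conversion into the torchtune model is omitted):

    import torch
    from transformers import MistralConfig, MistralForSequenceClassification

    # Small toy config, loosely mirroring the scale of the test_llama3.py setup.
    config = MistralConfig(
        vocab_size=500,
        hidden_size=64,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        max_position_embeddings=128,
        pad_token_id=0,
        num_labels=6,
    )
    hf_model = MistralForSequenceClassification(config).eval()

    # Random non-padding token ids.
    tokens = torch.randint(1, 500, (2, 16))
    with torch.no_grad():
        hf_logits = hf_model(tokens).logits  # [batch_size, num_labels]

    # A torchtune classifier built with matching sizes would then have the HF
    # weights copied across, and its outputs compared against hf_logits,
    # e.g. with torch.testing.assert_close(...).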

Questions/Next steps

This is part of a broader plan to implement RLHF in torchtune. The TransformerClassifier can hopefully be used with any sequence model we have. @kartikayk - we could implement a recipe for training a Mistral reward model using this. I can start implementing a PPO recipe using this reward model, too.

SalmanMohammadi · Apr 22 '24 20:04

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/840

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit 13063d1fbfb26eaf5b5fa4d118394af6cd3464f0 with merge base fde0dc403fc7a471fe402f18f7f1f5f7e9646164: :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · Apr 22 '24 20:04

Thanks @SalmanMohammadi for this PR! This looks mostly reasonable but I have a few high level suggestions.

The new functionality you're adding mainly uses the output layer of TransformerDecoder to project to the number of classes, and then extracts the last predicted token. Why not just make a separate classifier head class that consists of this linear projection and the post-processing step in its forward? Then you can compose these together directly in the TransformerDecoder constructor in mistral_classifier:

TransformerDecoder(
    tok_embeddings=tok_embeddings,
    layer=layer,
    num_layers=num_layers,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    head_dim=head_dim,
    norm=norm,
    output=ClassificationHead(embed_dim, num_classes),
)

This still keeps the flexibility of swapping out the classification head with some other head, and you keep all the added logic contained in a new class. Your testing will be simpler too because you only need to test the head and not the entire transformer classifier.
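Something roughly like this sketch (the projection details and the last-position pooling here are just illustrative):

    import torch
    from torch import nn

    class ClassificationHead(nn.Module):
        """Projects final hidden states to class logits and keeps the last position."""

        def __init__(self, embed_dim: int, num_classes: int):
            super().__init__()
            self.proj = nn.Linear(embed_dim, num_classes, bias=False)

        def forward(self, h: torch.Tensor) -> torch.Tensor:
            # h: [batch_size, seq_len, embed_dim]
            logits = self.proj(h)  # [batch_size, seq_len, num_classes]
            # Keep the logits at the final position of each sequence.
            return logits[:, -1, :]  # [batch_size, num_classes]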

Also, please make sure you've run the linters with pre-commit run --all-files :)

RdoubleA · Apr 22 '24 20:04

Thanks @RdoubleA :) I've fixed the linting.

I tried setting it up this way, but in the forward for TransformerClassifier I also use the input token ids to grab the last non-padding token in each sequence of the batch:

        padding_mask = tokens == 0  # assumes 0 is the padding token id
        if padding_mask.any():
            sequence_lengths = (
                (padding_mask.logical_not().cumsum(-1) == 1).sum(-1).to(logits.device)
            )
        else:
            # no padding anywhere in the batch: index the last position instead
            sequence_lengths = -1
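
These lengths are then used to index into the logits at that position, roughly along these lines (illustrative, not the exact code in the PR):

        # Gather the logits at the selected position in each sequence.
        batch_idx = torch.arange(logits.shape[0], device=logits.device)
        pooled_logits = logits[batch_idx, sequence_lengths]  # [batch_size, num_classes]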

I thought I'd have to modify the function call in TransformerDecoder's forward:

        h = self.norm(h)

        output = self.output(h).float() # pass input tokens in here
        return output

to achieve this, so I wrapped around TransformerDecoder instead.

I agree your suggestion (and @ebsmothers' on Discord) is cleaner. Do you have any thoughts on how I'd still be able to use the input tokens in the output callable?

SalmanMohammadi · Apr 22 '24 21:04

> I tried setting it up this way, but in the forward for TransformerClassifier I also use the input token ids to grab the last non-padding token in each sequence of the batch:

@SalmanMohammadi ah, I wasn't aware of this when I made my suggestion on Discord. In that case it is trickier because you are actually changing the signature of the output layer. At a high level, then, I think it makes sense to add a separate module to handle taking the last token. I'll take a closer look at the exact implementation now.
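For example, such a module could take both the projected logits and the input tokens, along the lines of this rough sketch (the name is hypothetical, and it assumes right padding with a padding id of 0):

    import torch
    from torch import nn

    class LastNonPaddingPool(nn.Module):
        """Selects the logits at the last non-padding position of each sequence."""

        def __init__(self, pad_id: int = 0):
            super().__init__()
            self.pad_id = pad_id

        def forward(self, logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
            # logits: [batch_size, seq_len, num_classes]; tokens: [batch_size, seq_len]
            # Index of the last non-padding token, assuming right padding.
            last_idx = (tokens != self.pad_id).sum(dim=-1).to(logits.device) - 1
            batch_idx = torch.arange(logits.shape[0], device=logits.device)
            return logits[batch_idx, last_idx]  # [batch_size, num_classes]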

ebsmothers · Apr 23 '24 16:04

Note: the type hints are wrong for the component and model builders (the docs are correct). I think you're right that my pre-commit hooks still aren't running properly. I'll fix those and address any comments in a bit.

SalmanMohammadi · Apr 23 '24 17:04

I've updated my PR @ebsmothers with the changes we discussed :)

SalmanMohammadi · Apr 24 '24 16:04

OK, the changes look good, but the unit test is failing on CI. Does it pass for you locally? I can also help a bit here with debugging if you need it.

ebsmothers · Apr 29 '24 18:04

Yeah, it's passing locally. It looks like it's failing the second test case. Are there more detailed logs? I sometimes fail some of the tests on my Mac because the precision tolerance isn't the same on my machine for some reason (i.e. the numbers look right, but the errors are just slightly above the tolerance).

SalmanMohammadi · Apr 29 '24 21:04

OK, I think I cracked the case here. The lack of detailed logs can often be an indication that the runner actually just crashed (which can happen due to running out of memory or something like that). It crashed on the 2nd test case, which is the largest one. You can either reduce the batch size or even just remove that test case, since it's not actually testing fundamentally different logic than the other cases. I tested on my fork with a batch size of 4 and confirmed that the CI succeeds (but tbh I'd just scrap the test case for the reason I mentioned).

ebsmothers · Apr 29 '24 23:04

Thanks so much for your help debugging :) I'll keep that in mind for the future!

SalmanMohammadi · Apr 30 '24 09:04

Good catch!

SalmanMohammadi · Apr 30 '24 16:04

Thank you for your patience!!

SalmanMohammadi · Apr 30 '24 21:04