Add option to accumulate train loss over tokens.
What does this PR do?
Adds an option to accumulate the training loss over the number of tokens in a batch rather than the number of samples.
What issue(s) does this change relate to?
Currently, losses in the trainer are accumulated in a way that weights each sample in a batch equally. However, for NLP use cases where batches contain padding tokens, it makes more sense to accumulate the loss in a way that instead weights every (non-padding) token equally.
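To make the difference concrete, here is a minimal sketch (not Composer's actual code; the function name and padding id are assumptions) contrasting the two normalizations, i.e. dividing an accumulated token-level loss sum by the number of samples versus by the number of non-padding tokens:

```python
import torch

PAD_TOKEN_ID = 0  # assumed padding id, for illustration only

def normalize_loss(loss_sum: torch.Tensor, input_ids: torch.Tensor, by: str = "samples"):
    """`loss_sum` is assumed to be the summed per-token loss for the batch."""
    if by == "samples":
        denom = input_ids.shape[0]                  # every sample weighted equally
    else:
        denom = (input_ids != PAD_TOKEN_ID).sum()   # every non-padding token weighted equally
    return loss_sum / denom

# Example: a batch of two sequences where the second is mostly padding.
input_ids = torch.tensor([[11, 12, 13, 14],
                          [11, PAD_TOKEN_ID, PAD_TOKEN_ID, PAD_TOKEN_ID]])
loss_sum = torch.tensor(10.0)
print(normalize_loss(loss_sum, input_ids, by="samples"))  # 5.0
print(normalize_loss(loss_sum, input_ids, by="tokens"))   # 2.0
```

With sample weighting, the heavily padded sequence contributes as much to the normalization as the full sequence; with token weighting, only its real tokens count.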
Before submitting
- [x] Have you read the contributor guidelines?
- [ ] Is this change a documentation change or typo fix? If so, skip the rest of this checklist.
- [ ] Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so.
- [x] Did you update any related docs and document your change?
- [ ] Did you update any related tests and add any new tests related to your change? (see testing)
- [ ] Did you run the tests locally to make sure they pass?
- [x] Did you run `pre-commit` on your change? (see the `pre-commit` section of prerequisites)
@dakinggg @irene-dea can you please look?
I agree we should have an option for this. I'm not sure whether it needs to be passed to Composer vs. checking whether it's an attribute/property on the dataloader.
@mvpatel2000 I think trainer arg is right for this...code looks fine at a glance, would want to test a bit more before merging.
You will own testing?
Can we add a unit test please?
Is there a good template I can base this test on? I'm not sure how to isolate the impact of this change from the rest of the trainer.
@aadyotb Here is a test that exercises the full NLP pipeline (https://github.com/mosaicml/composer/blob/01eec3ad1732074c52cf5da35ef5b8a24531279a/tests/test_full_nlp.py#L231-L338). I think to test this we probably want to (for just the training part of the code that I linked) construct a model that has deterministic loss (based on num padding tokens maybe?) and then test that the results are different in the expected way between sample weighting and token weighting.
So basically make a trainer with a dummy model and a dummy dataset, and then call it with sample weighting and token weighting (with microbatching on), and assert the losses are different in the expected way.
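The following is a distilled sketch of the assertion such a test is after (illustrative only, not the actual test added to the repo; the microbatch contents, losses, and padding id are made up): combine deterministic per-microbatch losses once with sample weights and once with non-padding-token weights, and check that the results differ when padding is uneven across microbatches.

```python
import torch

PAD_ID = 0  # assumed padding id

def combine(losses, weights):
    """Weighted average of per-microbatch losses, mirroring loss accumulation."""
    weights = torch.tensor(weights, dtype=torch.float)
    return (torch.tensor(losses) * weights / weights.sum()).sum()

def test_sample_vs_token_weighting_differ():
    mb1 = torch.tensor([[1, 2, 3, 4],
                        [1, 2, 3, 4]])                  # no padding: 8 tokens, 2 samples
    mb2 = torch.tensor([[1, 2, PAD_ID, PAD_ID],
                        [1, PAD_ID, PAD_ID, PAD_ID]])   # heavy padding: 3 tokens, 2 samples
    losses = [2.0, 4.0]  # deterministic stand-in for per-microbatch mean losses
    sample_weights = [mb.shape[0] for mb in (mb1, mb2)]
    token_weights = [(mb != PAD_ID).sum().item() for mb in (mb1, mb2)]
    assert not torch.isclose(combine(losses, sample_weights),
                             combine(losses, token_weights))
```

In the real test, the per-microbatch losses would come from a dummy model run through the Trainer with microbatching on, but the core check is the same: the two weighting schemes must disagree once padding is present.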
@dakinggg I've added a unit test that requires sample-based and token-based weighting to produce different outcomes when padding is present.
@aadyotb awesome, thank you!! Will take a look soon.
@dakinggg after some further thought, I made one additional change to the composer code. Currently, the total loss is just averaged across all ranks, since the trainer assumes that all ranks will have the same batch size. However, this assumption will not hold for token-based batching. For this reason, instead of dividing the microbatch size by the device-local total batch size, I divide the microbatch size by the total batch size averaged across all ranks. This way, different ranks can have different gradient contributions based on how many samples/tokens they process. Let me know if this seems reasonable to you.
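A hedged sketch of that scaling change (illustrative only; the function and variable names are not Composer's): normalize each microbatch's loss contribution by the batch size averaged over ranks rather than the rank-local batch size, so ranks that process more samples/tokens contribute proportionally more gradient.

```python
import torch
import torch.distributed as dist

def microbatch_loss_scale(microbatch_size: float, local_batch_size: float) -> float:
    """Scale factor applied to a microbatch's loss before backward."""
    batch_size = torch.tensor(float(local_batch_size))
    if dist.is_available() and dist.is_initialized():
        # Average the rank-local batch size (samples or non-padding tokens)
        # across all ranks.
        dist.all_reduce(batch_size, op=dist.ReduceOp.SUM)
        batch_size /= dist.get_world_size()
    # Previously (per the description): microbatch_size / local_batch_size.
    # Now: divide by the cross-rank average instead.
    return microbatch_size / batch_size.item()
```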
@aadyotb Just a heads up Daniel is out this week, and given the subtlety here, I would prefer he finish the review of this PR vs. bringing someone else to review. Please let me know if you need it earlier -- sorry for the delay!
Thanks @mvpatel2000. For the time being, I've implemented these changes by overriding the `Trainer` class in our local repo, so we will be okay for now. Happy to get further review once Daniel returns.
@dakinggg bumping this PR for review.
Hey @aadyotb taking a look now, mostly convincing myself that the math is correct :)
Ok cool, LGTM. I'm running a before-and-after loss curve just to double check (in the normal case, with an equal number of samples per device batch), and will post that graph here when done.
For posterity, I validated that finetuning behavior is unchanged before and after this PR (for a case with constant samples per device batch), and that it does change when the new flag is specified.
@aadyotb I think one more run of `pre-commit` should do it, and we should be good to merge.