Mistral testing
Context
What is the purpose of this PR? Is it to
- [ ] add a new feature
- [ ] fix a bug
- [x] update tests and/or documentation
- [ ] other (please add here)
Please link to any issues this PR addresses.
https://github.com/pytorch/torchtune/issues/848
Changelog
I've started adding scripts to verify the implementation of `mistral`. I'm using the reference implementation from the official repo. There's another implementation in the repo which uses `xformers` for the attention mechanism, but it's not straightforward to replicate; I ran into lots of issues when I initially tried.
So far, I've added a script to compare the attention implementation. I've verified that the attention implementation produces consistent outputs using `python -m tests.torchtune.models.mistral.scripts.compare_attention`. I'll be keeping the reference implementation in `tests/torchtune/models/mistral/scripts/mistral_reference.py`.
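For context, each comparison script follows roughly the same pattern: fix a seed, copy the torchtune weights into the reference module, run both on the same input, and check that the outputs agree. A minimal sketch of that pattern (the shapes, tolerances, and module handling here are illustrative placeholders, not the exact ones in the PR):

```python
import torch


def compare_outputs(
    torchtune_module: torch.nn.Module,
    reference_module: torch.nn.Module,
    mapped_state_dict: dict,
) -> None:
    """Sketch of the comparison pattern used in the scripts (illustrative only)."""
    # Load the (re-keyed) torchtune weights into the reference module so both
    # implementations start from identical parameters.
    reference_module.load_state_dict(mapped_state_dict)

    # Fix the seed so the random input is reproducible across runs.
    torch.manual_seed(16)
    x = torch.randn(2, 16, 64)  # (batch, seq_len, embed_dim) -- toy sizes

    with torch.no_grad():
        out_torchtune = torchtune_module(x)
        out_reference = reference_module(x)

    # Report the maximum absolute difference and assert the outputs match.
    print("max abs diff:", (out_torchtune - out_reference).abs().max().item())
    assert torch.allclose(out_torchtune, out_reference, atol=1e-5, rtol=1e-5)
```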
Next steps
I'm generally following this process - the plan is to continue copying and testing the components of the mistral implementation, then test the models as a whole by mapping `torchtune.models.mistral` state dicts into the reference implementation. Finally, I'll add unit tests to integrate into CI.
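As a rough illustration of what that state-dict mapping step might look like (every key pattern below is an assumption made for the sketch; the real parameter names live in the two implementations):

```python
import re

import torch

# Hypothetical torchtune -> reference key patterns, for illustration only.
_KEY_MAP = {
    r"^tok_embeddings\.weight$": "tok_embeddings.weight",
    r"^norm\.scale$": "norm.weight",
    r"^output\.weight$": "output.weight",
    r"^layers\.(\d+)\.attn\.q_proj\.weight$": r"layers.\1.attention.wq.weight",
    r"^layers\.(\d+)\.attn\.k_proj\.weight$": r"layers.\1.attention.wk.weight",
    r"^layers\.(\d+)\.attn\.v_proj\.weight$": r"layers.\1.attention.wv.weight",
    r"^layers\.(\d+)\.attn\.output_proj\.weight$": r"layers.\1.attention.wo.weight",
}


def map_torchtune_to_reference(
    torchtune_sd: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Re-key a torchtune Mistral state dict so it loads into the reference model."""
    mapped = {}
    for name, tensor in torchtune_sd.items():
        for pattern, replacement in _KEY_MAP.items():
            if re.match(pattern, name):
                name = re.sub(pattern, replacement, name)
                break
        mapped[name] = tensor
    return mapped
```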
Good to make sure I'm not too far off the mark : )
I've updated scripts for the rest of the mistral components. I need to write the comparison involving mapping state dicts, update the unit test, and (potentially) add LoRA comparisons.
Okay, all seems good. We now have a unit test for the base `mistral` model using the copied implementation from the mistral repo.
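The test follows the usual fixed-seed pattern; a minimal sketch of its shape (the builder import path, kwargs, and toy sizes below are my guesses for illustration, not the values in the actual test):

```python
import torch

from torchtune.models.mistral import mistral  # builder path assumed for the sketch


def test_mistral_forward():
    torch.manual_seed(16)
    vocab_size, bsz, seq_len = 100, 2, 32

    # Tiny model so the test runs quickly on CPU; kwargs are assumptions.
    model = mistral(
        vocab_size=vocab_size,
        num_layers=2,
        num_heads=4,
        num_kv_heads=2,
        embed_dim=64,
        intermediate_dim=128,
        max_seq_len=seq_len,
    )

    tokens = torch.randint(0, vocab_size, (bsz, seq_len))
    logits = model(tokens)

    # Shape check; the real test also compares output values against numbers
    # generated from the copied reference implementation.
    assert logits.shape == (bsz, seq_len, vocab_size)
```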
For the unfortunate reviewer seeing my +1160 line PR (I hope you read this first!):
I'm hoping the `mistral/scripts/compare_{component}.py` scripts for the individual components weren't unnecessary - I'm realising that we already test each component in `llama2/scripts/compare_{component}.py`. The only component I'm testing that wasn't compared in `llama2/scripts/` is `mistral_mlp`. Maybe it's good that they've been verified with two implementations? If it's not useful I can take them out - `mistral/scripts/compare_mistral.py` is the main one.
Thanks for all this extensive testing!
> I'm hoping the `mistral/scripts/compare_{component}.py` scripts for the individual components weren't unnecessary - I'm realising that we already test each component in `llama2/scripts/compare_{component}.py`. The only component I'm testing that wasn't compared in `llama2/scripts/` is `mistral_mlp`.
I think we wanna find the right balance of rigorous testing and maintenance here. So while I don't want your work to be in vain, I wonder if we should just add those comparison scripts that differ nontrivially from the Llama2 ones, and for the other components point to the Llama2 ones. In this case that would mean keeping `compare_mistral` and `compare_feedforward` (since you mentioned it's not tested under llama2). Then you can add a README to `tests/torchtune/models/mistral/scripts` (similar to this one in the llama2 scripts directory) and state that components X, Y, and Z are identical to the llama2 ones and that their comparison scripts can be found in that directory. (If you want, you can even move the MLP comparison under llama2 so that everything is colocated, but tbh I have no strong preference here.)
Not in vain at all - I learnt lots! I've made the updates and added a README.
Thanks again for your review @ebsmothers :)