torchtune
Updating generation to use custom_generate_next_token for initial token generation
Context
What is the purpose of this PR? Is it to
- [ ] add a new feature
- [x] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
I'm using `generate` with a `custom_generate_next_token` function defined for a model which outputs both the output of `TransformerDecoder` and an additional value head. `generate` fails when making the first call to `generate_next_token`, as it encounters an unexpected return signature for `model(...)`.
Let me know if there's a better way for me to do this.
Changelog
Updated the initial call to `generate_next_token` to use `custom_generate_next_token`.
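To illustrate the failure and the change, here is a minimal, dependency-free sketch. It is not the torchtune implementation: `toy_model_with_value_head`, `default_next_token`, `custom_next_token`, and `toy_generate` are hypothetical stand-ins, and plain Python lists replace tensors. It only shows why a next-token function that expects bare logits breaks on a `(logits, value)` return, and how routing the first call through the custom function avoids that.

```python
from typing import Callable, List, Optional, Sequence, Tuple

VOCAB_SIZE = 5


def argmax(logits: Sequence[float]) -> int:
    """Index of the largest logit (greedy decoding)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def toy_model_with_value_head(tokens: Sequence[int]) -> Tuple[List[float], float]:
    """Hypothetical model: returns (next-token logits, scalar value)."""
    next_id = (tokens[-1] + 1) % VOCAB_SIZE
    logits = [1.0 if i == next_id else 0.0 for i in range(VOCAB_SIZE)]
    value = float(len(tokens))
    return logits, value


def default_next_token(model: Callable, tokens: Sequence[int]) -> int:
    """Assumes model(tokens) returns logits only; fails on a tuple output."""
    return argmax(model(tokens))


def custom_next_token(model: Callable, tokens: Sequence[int]) -> int:
    """Unpacks the (logits, value) tuple and ignores the value head."""
    logits, _value = model(tokens)
    return argmax(logits)


def toy_generate(
    model: Callable,
    prompt: Sequence[int],
    max_generated_tokens: int,
    custom_generate_next_token: Optional[Callable] = None,
) -> List[int]:
    """Toy generation loop; every call, including the first, uses the custom fn."""
    next_token_fn = custom_generate_next_token or default_next_token
    tokens = list(prompt)
    if max_generated_tokens < 1:
        return tokens
    # Initial token: routed through next_token_fn rather than the default.
    tokens.append(next_token_fn(model, tokens))
    # Remaining tokens use the same function.
    for _ in range(max_generated_tokens - 1):
        tokens.append(next_token_fn(model, tokens))
    return tokens


if __name__ == "__main__":
    print(toy_generate(toy_model_with_value_head, [0], 3, custom_next_token))
```

In this sketch, calling `toy_generate` without the custom function raises a `TypeError` on the very first token, because the default function tries to take an argmax over the `(logits, value)` tuple.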
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
- [x] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- [ ] add unit tests for any new functionality
- [ ] update docstrings for any new or updated methods or classes
- [x] run unit tests via `pytest tests`
- [ ] run recipe tests via `pytest tests -m integration_test`
- [ ] manually run any new or modified recipes with sufficient proof of correctness
- [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/966
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:white_check_mark: No Failures
As of commit a605d9490ea1dccaa708c5828fe6fca44533c08e with merge base dc2b9911cf5a793a8b4f9cd4df321c5271745b4a:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks so much for the review. I think I kind of get the gist of the original design choice - I'll ping any qs on discord : )