
Updating generation to use custom_generate_next_token for initial token generation

Open SalmanMohammadi opened this issue 9 months ago • 2 comments

Context

What is the purpose of this PR? Is it to

  • [ ] add a new feature
  • [x] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

I'm using generate with a custom_generate_next_token function defined for a model that returns both the output of TransformerDecoder and the output of an additional value head. generate fails on its first call to generate_next_token, because it receives an unexpected return signature from model(...).
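The signature mismatch can be sketched in plain Python (no torch, so it runs anywhere). Everything below is hypothetical and simplified, not torchtune code: ToyValueHeadModel stands in for a decoder with a value head that returns a (logits, value) pair, default_next_token stands in for a helper that assumes model(...) returns logits only, and value_head_next_token is a custom replacement that unpacks the pair. Greedy argmax replaces sampling to keep the output deterministic.

```python
from typing import List, Tuple

Logits = List[List[float]]  # one row of vocabulary scores per position


class ToyValueHeadModel:
    """Hypothetical stand-in for a decoder with an extra value head."""

    vocab_size = 4

    def __call__(self, tokens: List[int]) -> Tuple[Logits, float]:
        # Toy logits: the row for token t puts all its mass on (t + 1) % vocab_size.
        logits = [
            [1.0 if v == (t + 1) % self.vocab_size else 0.0 for v in range(self.vocab_size)]
            for t in tokens
        ]
        value = float(len(tokens))  # stand-in for the value-head output
        return logits, value


def greedy(row: List[float]) -> int:
    # Index of the highest score (greedy decoding instead of sampling).
    return max(range(len(row)), key=row.__getitem__)


def default_next_token(model, tokens: List[int]) -> int:
    # Assumes model(...) returns logits only.
    logits = model(tokens)
    return greedy(logits[-1])


def value_head_next_token(model, tokens: List[int]) -> int:
    # Custom version: unpack (logits, value) and ignore the value here.
    logits, _value = model(tokens)
    return greedy(logits[-1])


model = ToyValueHeadModel()
print(value_head_next_token(model, [0, 1]))  # 2
try:
    default_next_token(model, [0, 1])
except TypeError as err:
    # logits[-1] is the float value, so greedy() cannot call len() on it
    print("default helper fails:", err)
```

The default helper does not fail at the model call itself; it fails one step later, when it treats the tuple as logits. That is why the error surfaces inside generate rather than in the model.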

Let me know if there's a better way for me to do this.

Changelog

Updated the initial call to generate_next_token, which produces the first token from the prompt, to use custom_generate_next_token instead.
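A simplified sketch of the control flow this change targets, again in plain Python. The generate function here is hypothetical and does not reproduce torchtune's actual signature; to keep it short, the toy model returns (next_token_id, value) directly instead of logits. The key point is that the next-token function is resolved once, so the first token (computed from the full prompt, which a KV-cached implementation would typically treat as the prefill step) goes through custom_generate_next_token just like every later token.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical toy model: returns (next_token_id, value) instead of a bare
# prediction, mimicking a decoder with an extra value head.
Model = Callable[[List[int]], Tuple[int, float]]
NextTokenFn = Callable[[Model, List[int]], int]


def toy_model(tokens: List[int]) -> Tuple[int, float]:
    return tokens[-1] + 1, float(len(tokens))


def default_next_token(model: Model, tokens: List[int]) -> int:
    # Stand-in for a built-in helper that assumes model(...) returns a bare
    # prediction; int() of a (token, value) tuple raises TypeError.
    return int(model(tokens))


def generate(
    model: Model,
    prompt: List[int],
    max_new_tokens: int,
    custom_generate_next_token: Optional[NextTokenFn] = None,
) -> List[int]:
    # Resolve the next-token function once so that every step, including the
    # first one, uses the custom function when one is supplied.
    next_token_fn = custom_generate_next_token or default_next_token
    tokens = list(prompt)
    if max_new_tokens <= 0:
        return tokens
    # First token, conditioned on the full prompt. Before the fix, this call
    # always went to the default helper.
    tokens.append(next_token_fn(model, tokens))
    # Remaining tokens use the same function.
    for _ in range(max_new_tokens - 1):
        tokens.append(next_token_fn(model, tokens))
    return tokens


calls: List[int] = []


def value_head_next_token(model: Model, tokens: List[int]) -> int:
    calls.append(len(tokens))  # record the context length of each call
    next_id, _value = model(tokens)
    return next_id


print(generate(toy_model, [0], 3, value_head_next_token))  # [0, 1, 2, 3]
print(calls)  # [1, 2, 3]
```

The first recorded call has a context length of 1, i.e. the one-token prompt, which confirms that the initial token was produced by the custom function rather than the default helper.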

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • [x] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • [ ] add unit tests for any new functionality
  • [ ] update docstrings for any new or updated methods or classes
  • [x] run unit tests via pytest tests
  • [ ] run recipe tests via pytest tests -m integration_test
  • [ ] manually run any new or modified recipes with sufficient proof of correctness
  • [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

SalmanMohammadi, May 12 '24 15:05

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/966

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a605d9490ea1dccaa708c5828fe6fca44533c08e with merge base dc2b9911cf5a793a8b4f9cd4df321c5271745b4a: 💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot], May 12 '24 15:05

Thanks so much for the review. I think I get the gist of the original design choice; I'll ping any questions on Discord :)

SalmanMohammadi, May 17 '24 18:05