
The first token generation does not use custom_generate_next_token

RomDeffayet opened this issue 9 months ago • 6 comments

Hi,

in torchtune.generation.generate, the first token is generated using generate_next_token instead of custom_generate_next_token. See the exact line

RomDeffayet avatar May 19 '25 15:05 RomDeffayet
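The dispatch being reported can be sketched as follows. This is a hypothetical mock of the control flow, not torchtune's actual implementation; only the parameter names mirror `torchtune.generation.generate`, and the function bodies are placeholders:

```python
# Sketch of the reported behavior: the prefill (first) token always goes
# through the default generate_next_token, even when a custom function
# is supplied; only the decode steps use custom_generate_next_token.
# Hypothetical bodies -- names follow torchtune.generation.generate.

def generate_next_token(model, tokens):
    # Default (uncompiled) next-token sampler, used for the prefill step.
    return "default"

def generate(model, prompt, *, custom_generate_next_token=None):
    if custom_generate_next_token is None:
        custom_generate_next_token = generate_next_token

    calls = []
    # Prefill: the FIRST token comes from generate_next_token regardless
    # of whether a custom function was passed in.
    calls.append(generate_next_token(model, prompt))
    # Decode: every subsequent token uses the custom function.
    for _ in range(3):
        calls.append(custom_generate_next_token(model, prompt))
    return calls

calls = generate(None, [1, 2], custom_generate_next_token=lambda m, t: "custom")
# The first entry bypasses the custom function; the rest use it.
```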

You're right! I'll put up a PR right away to fix this

joecummings avatar May 19 '25 15:05 joecummings

Please see this comment for why this is the case https://github.com/pytorch/torchtune/pull/966#pullrequestreview-2062019488

salmanmohammadi avatar May 19 '25 16:05 salmanmohammadi

Ah, I see. In that case, @RomDeffayet, can I ask whether you are running into a bug somewhere WRT this behavior, or whether you just noticed that it was a little strange?

If the former, we can try to properly address this. If the latter, we can still clear up some confusion by renaming the parameter to something more accurate, like compiled_generate_next_token.

joecummings avatar May 19 '25 16:05 joecummings

Okay, in my case I am actually replacing the next-token generation function with a custom one. If this is not intended behavior, renaming it to compiled_generate_next_token definitely makes sense (also, why expose it to the user then?)

RomDeffayet avatar May 19 '25 17:05 RomDeffayet

> Okay, in my case I am actually replacing the next-token generation function with a custom one. If this is not intended behavior, renaming it to compiled_generate_next_token definitely makes sense (also, why expose it to the user then?)

Yeah, I think this is the correct way to go about it. The only reason I can think of for exposing it to the user is so that the user has control over how the function is compiled, rather than hardcoding the compile kwargs inside the generate function.

salmanmohammadi avatar May 19 '25 17:05 salmanmohammadi
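That control could look something like the following from the user's side. This is a self-contained sketch: `fake_compile` stands in for `torch.compile` so the example runs without torch, and the kwargs shown are illustrative, not torchtune's actual API:

```python
# Sketch of why exposing the parameter gives the user control over
# compilation. `fake_compile` is a stand-in for torch.compile here so
# the example stays self-contained; the kwargs are illustrative only.

def fake_compile(fn, **compile_kwargs):
    # Pretend-compiler: records the kwargs and returns a wrapper, the
    # way torch.compile returns an optimized callable.
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    wrapped.compile_kwargs = compile_kwargs
    return wrapped

def generate_next_token(model, tokens):
    # Toy "sampler" standing in for the real next-token function.
    return tokens[-1] + 1

# If generate() hardcoded the compile kwargs internally, the user could
# not choose e.g. mode or fullgraph. Exposing the parameter lets them
# compile however they like and pass the result in as
# custom_generate_next_token=my_next_token.
my_next_token = fake_compile(
    generate_next_token, mode="reduce-overhead", fullgraph=True
)
result = my_next_token(None, [1, 2, 3])
```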

Alright, so the quick fix would be to rename the parameter custom_generate_next_token to compiled_decode_next_token.

However, this intuition was built off the design of the gpt-fast code from over a year ago. In that time, I believe the compile team has made significant strides.

@IvanKobzarev Does it still matter that much that we compile the decode_one_token function and then pass that into our generation function? Or does compile now easily cache the compile plan? When you take a look at how we use and implement our generate function, is there any easy room for improvement that we're missing? It's unfortunate we no longer have the actively maintained gpt-fast project to signal what the best practices are here :/

joecummings avatar May 19 '25 20:05 joecummings