Tool Calling Fine-Tuning fails because of validation logic in messages
Hi Team,
First, please correct me if my understanding is wrong or if I missed something.
However, I believe the validation logic in torchtune.data.messages here doesn't account for the tool-calling flow.
I have a dataset that follows this flow: system -> user -> assistant -> tool -> assistant -> user, and so on.
I've set up the correct role mappings for this dataset and also tried extending SFTDataset by creating a Transform that ensures the eot=False logic is followed for a) the assistant message before a tool call, and b) the tool message itself.
However, both fail validation.
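For concreteness, here's roughly what one sample looks like after my transform (a minimal sketch using torchtune's Message class; the content and tool names are made up):

```python
from torchtune.data import Message

# Illustrative sample only; content and tool names are invented.
messages = [
    Message(role="system", content="You can call tools.", eot=True),
    Message(role="user", content="What's the weather in Paris?", eot=True),
    # (a) assistant message carrying the tool call arguments: eot=False
    Message(role="assistant", content='get_weather(city="Paris")', eot=False),
    # (b) the tool response itself: eot=False
    Message(role="tool", content='{"temp_c": 18}', eot=False),
    # final assistant answer closes the turn
    Message(role="assistant", content="It is 18°C in Paris.", eot=True),
]
```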
Please let me know if I missed something super obvious.
Thanks!
Thanks for flagging this. This function is quite outdated and definitely needs to be updated to account for tool-calling logic. At a quick glance, it looks like the last_turn logic needs to account for potential tool calls, which involve a non-user role in between assistant messages. Your use case sounds like the standard tool-call flow, and we should update this utility to support it.
I had the same problem before and fixed it locally.
Just opened a PR for that support: https://github.com/pytorch/torchtune/pull/2407
@init27 does #2407 resolve this issue if it lands?
@pbontrager thanks, and thanks @musabgultekin for the quick PR! Actually, I need some clarification from the team:
Context: when we do a tool call, a single flow is:
(step 1) user (asks something requiring a tool call) -> (step 2) assistant (responds with tool call arguments) -> (step 3) tool (responds with the tool response) -> (step 4) assistant (answers the user query)
Now, my question is about the eot token, which denotes whether the turn of one role (assistant/user/system) is over.
For the above flow, do we set eot=False in both (step 2) and (step 3), or is it eot=False JUST in (step 3)?
From @musabgultekin's PR:
eot=False if message["role"] in ["tool", "ipython"] else True
I'm curious since you mentioned you faced this issue locally earlier and this fixed it. My understanding is that eot should ALSO be set to False when the assistant is returning the tool call, i.e. step 2 in the example above (before the tool call happens)?
cc: @RdoubleA
@init27
Step 2 (tool calls):
eot should be true when the assistant is calling a custom tool.
See:
While builtin-tool calls end with <|eom_id|>, notice the <|eot_id|> for zero-shot function calls.
https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_3/#-zero-shot-function-calling-
Step 3 (tool responses):
In the 3.3 format, there is no specific mention of whether we should add eot on tool responses (step 3). I set eot=False because the LLM doesn't need to stop there on the forward pass, as the generation prompt will be appended after the tool responses. BUT since it's not strictly mentioned in the docs, one could add eot on tool responses too. I don't think it would make a difference.
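To make that concrete, here's a rough sketch of how the Llama 3.x header format chooses the closing special token based on the eot flag (illustrative only, not torchtune's actual tokenizer code):

```python
# Rough illustration of the Llama 3.x message format: the eot flag decides
# whether a message ends with <|eot_id|> (stop) or <|eom_id|> (expect more).
def render(role: str, content: str, eot: bool) -> str:
    end = "<|eot_id|>" if eot else "<|eom_id|>"
    return f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}{end}"

# With eot=False on the tool response, generation can continue after it:
print(render("ipython", '{"temp_c": 18}', eot=False))
```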
@musabgultekin thanks for the reply. Actually, I'm referring to the Tokens section here:
<|eom_id|>: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
<|eot_id|>: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
- at the end of a direct interaction between the model and the user
- at the end of multiple interactions between the model and any available tools
This token signals to the executor that the model has finished generating a response.
It seems logically a turn is "over" after the assistant -> tool -> assistant interaction is over.
I'll also wait for the torchtune team to weigh in.
It seems logically a turn is "over" after the assistant -> tool -> assistant interaction is over.
This is our understanding as well, and it's also discussed in the Message documentation: https://github.com/pytorch/torchtune/blob/e6cba2532d51a53936c7646bd4cdaa6b2b57ed66/torchtune/data/_messages.py#L50
While builtin-tool calls end with <|eom_id|>, notice the <|eot_id|> for zero-shot function calls.
The example you linked is confusing; it does seem that the actual tool output is not in the list of messages, so eot=True makes sense there. But if you were training with tool returns and an assistant response afterward, it wouldn't make sense to have eot=True until the end. That said, if there are counter-examples in other codebases, we can reconsider this.
Just catching up on the conversation here. It seems to me that the eot flag needs to be made context-aware, i.e. the logic should be something like:
eot = True if (role == 'assistant' and next_message_role != 'tool') else False
Lmk if I'm missing the point here though. (And if you guys already have this figured out please ignore me 😛 )
@ebsmothers Actually, @RdoubleA and I agree it should be:
eot = False if (role == 'tool' or (role == 'assistant' and next_message_role == 'tool')) else True
It seems we need to clarify the model card examples a bit as well (we're looking into it)
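For example, applied over a whole conversation (a sketch; set_tool_call_eot is a hypothetical helper, not a torchtune API):

```python
from torchtune.data import Message

def set_tool_call_eot(messages: list[Message]) -> list[Message]:
    # Apply the rule above: eot is False for tool messages and for
    # assistant messages immediately followed by a tool message.
    for i, msg in enumerate(messages):
        next_role = messages[i + 1].role if i + 1 < len(messages) else None
        msg.eot = not (
            msg.role == "tool"
            or (msg.role == "assistant" and next_role == "tool")
        )
    return messages
```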
Has anybody tried to load an OpenAI / ShareGPT format dataset with tools? It simply doesn't work.
The current main branch has this validation code:

```python
if message.role in ["tool", "ipython"] and not last_message.ipython:
    raise ValueError(
        f"Tool or ipython message at index {i} must follow an ipython message"
    )
```

So the message preceding a tool message should have .role == 'assistant' and .ipython == True... How is this attribute set? It isn't! See for example https://github.com/pytorch/torchtune/blob/dda297784c6001b443b65179d11ea754a4627d2c/torchtune/data/_messages.py#L566C7-L566C23
Btw, I'd be grateful for one example in the tutorial of an OpenAI/ShareGPT dataset with tools that is shown to load correctly.
Great catch - sorry about that! We only have text and image support for these transforms OOTB. You can subclass and add the support yourself if you need to use a stable version of the library (it should be very straightforward), but I've also opened #2618 and #2617 so you can track my progress there. I'll get right on adding this, and it will be available in nightlies soon.
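If you want to unblock yourself in the meantime, the subclass could look something like this (a rough sketch; ToolAwareOpenAIToMessages is a made-up name, and I'm assuming the transform returns a dict with a "messages" key):

```python
from torchtune.data import Message, OpenAIToMessages

class ToolAwareOpenAIToMessages(OpenAIToMessages):
    # Sketch: after the standard conversion, mark assistant messages that
    # precede a tool/ipython response with ipython=True so that
    # validate_messages accepts tool-calling conversations.
    def __call__(self, sample):
        out = super().__call__(sample)
        messages: list[Message] = out["messages"]
        for i, msg in enumerate(messages[:-1]):
            if msg.role == "assistant" and messages[i + 1].role in ("tool", "ipython"):
                msg.ipython = True
        return out
```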