
[BUG] RuntimeError when passing dialogue data to LLMEnv

Open albertbou92 opened this issue 9 months ago • 1 comment

Describe the bug

I see very cool advancements in the direction of LLM RL training in the repo, awesome work! :)

After playing a bit with LLMEnv, I got the following error when passing dialogue data to the env:

RuntimeError: modifying the batch size of a lazy representation of a tensordict is not permitted. Consider instantiating the tensordict first by calling td = td.to_tensordict() before resetting the batch size.

Dialogue data is a very common format when working with LLMs, since it lets the model see both system and user inputs, and it can easily be flattened into a single string with a chat template (see the sketch below). I am not sure whether the intention is to support dialogue data in the format I am passing to the env, but I think it would be convenient.
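
For reference, this is roughly how such a conversation can be flattened into a single string with a Hugging Face tokenizer's chat template; the model name below is only illustrative:

# Illustrative only: flatten a list of chat messages into a single prompt
# string using the tokenizer's chat template (any chat model would do).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the capital of France?"},
]

prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # a single formatted string containing the whole conversation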

I noticed that the bug goes away if I comment out this line, but I do not think that is the right fix: https://github.com/pytorch/rl/blob/main/torchrl/envs/transforms/llm.py#L534

To Reproduce

from datasets import Dataset
from torch.utils.data import DataLoader
from torchrl.envs.custom.llm import LLMEnv


def collate_fn(batch: list[dict]) -> dict[str, list]:
    return {k: [el[k] for el in batch] for k in batch[0]}

# Dummy data
sample = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the capital of France?"}
    ]
}


# Repeat the sample 1000 times
data = [sample] * 1000

# Create the dataset
dataset = Dataset.from_list(data)

# Create a PyTorch DataLoader
batch_size = 16
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)

env = LLMEnv.from_dataloader(
    dataloader=dataloader,
    str2str=True,
    batch_size=4,
    str_key='messages'
)

obs = env.reset()
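
A possible workaround until dialogue structures are supported natively (an untested sketch, reusing the chat-template flattening shown above) is to collapse each conversation into a single string inside the collate function, so that LLMEnv only ever receives plain strings:

# Untested workaround sketch: flatten each conversation to a prompt string
# inside the collate function; the tokenizer/model name is only illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

def collate_fn_flat(batch: list[dict]) -> dict[str, list]:
    return {
        "messages": [
            tokenizer.apply_chat_template(
                el["messages"], tokenize=False, add_generation_prompt=True
            )
            for el in batch
        ]
    }

# Then pass collate_fn=collate_fn_flat to the DataLoader above.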

Checklist

  • [ ] I have checked that there is no similar issue in the repo (required)
  • [ ] I have read the documentation (required)
  • [ ] I have provided a minimal working example to reproduce the bug (required)

albertbou92 commented Mar 25, 2025

Hello @albertbou92, glad you like it! I'm going to work on a version of this env that works well with these kinds of data structures, along the lines of this:

@tensorclass
class History:
    role: str
    content: str

Then we need an env that stacks histories together and can call the parser from the HF tokenizer.
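
For instance, stacking could look something like this (a hypothetical sketch, assuming @tensorclass keeps the string fields as non-tensor data and supports torch.stack):

# Hypothetical sketch: stack per-turn History entries into a batched History.
# Assumes the string fields are stored as non-tensor data by @tensorclass.
import torch
from tensordict import tensorclass

@tensorclass
class History:
    role: str
    content: str

turns = [
    History(role="system", content="You are a helpful assistant.", batch_size=[]),
    History(role="user", content="What's the capital of France?", batch_size=[]),
]

dialogue = torch.stack(turns)  # a History with batch_size [2]

The batched History could then be converted back into a list of role/content dicts and passed to the HF tokenizer's apply_chat_template to build the prompt.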

I already have a prototype working with MLGym; I'll ping you once I merge it into torchrl (should be pretty soon!)

vmoens commented Mar 26, 2025