
Add @art.rollout decorator to gather trajectories

Open arcticfly opened this issue 6 months ago • 6 comments

Proposal

  • Create an @art.rollout decorator that wraps a rollout function and automatically constructs a trajectory (potentially with multiple histories), similar to how @weave.op wraps an LLM-enabled function, records all function calls, and then reports a trace.
  • Allow rollout functions to access the current trajectory through a helper function such as get_current_trajectory().
  • Store completion IDs on messages so that a specific history can be accessed and manipulated with trajectory.get_history(completion_id).
    • Useful when adding tool messages after executing a tool.
  • Also create a gather_trajectory helper function that calls a rollout function decorated with @art.rollout and returns the generated trajectory.

It's worth studying @weave.op closely. We may even want to integrate with Weave or wrap its decorator, since they've already done the integration work needed to capture completions from many LLM clients.

Messy ideas are in the proposal doc.

Caveats

The @art.rollout decorator will need to determine automatically whether LLM completions belong to the same history or to separate histories.
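One plausible heuristic, which is an assumption and not something this proposal settles, is to treat a completion as continuing an existing history when its prompt begins with that history's messages, and to open a new history otherwise. If each assistant message is also stamped with its completion ID, `get_history(completion_id)` can return the history that a tool result should be appended to. `History`, `Trajectory`, `record_completion`, and the `completion_id` message key below are hypothetical names.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

Message = dict[str, Any]


def _strip_id(message: Message) -> Message:
    # Completion IDs are bookkeeping; ignore them when comparing messages.
    return {k: v for k, v in message.items() if k != "completion_id"}


@dataclass
class History:
    messages: list[Message] = field(default_factory=list)

    def is_prefix_of(self, prompt: list[Message]) -> bool:
        n = len(self.messages)
        return 0 < n <= len(prompt) and all(
            _strip_id(a) == _strip_id(b) for a, b in zip(self.messages, prompt)
        )


@dataclass
class Trajectory:
    histories: list[History] = field(default_factory=list)
    by_completion_id: dict[str, History] = field(default_factory=dict)

    def record_completion(
        self, prompt: list[Message], reply: Message, completion_id: str
    ) -> History:
        # Continue the longest history that the prompt extends; otherwise start a new one.
        candidates = [h for h in self.histories if h.is_prefix_of(prompt)]
        history = max(candidates, key=lambda h: len(h.messages), default=None)
        if history is None:
            history = History()
            self.histories.append(history)
        # Copy over prompt messages the history hasn't seen yet (e.g. tool results).
        history.messages.extend(prompt[len(history.messages):])
        history.messages.append({**reply, "completion_id": completion_id})
        self.by_completion_id[completion_id] = history
        return history

    def get_history(self, completion_id: str) -> History:
        return self.by_completion_id[completion_id]
```

Prefix matching is cheap and needs no cooperation from the rollout function. However, it cannot tell apart two independent conversations that happen to start with identical messages, which is one reason explicit completion IDs are still useful.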

Example

Our current rollout functions require the user to initialize and add messages to an art.Trajectory object, like so:

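```python
# Assumes `art`, `Scenario`, and an OpenAI-compatible async `client` are defined elsewhere.
async def get_summary(model: art.Model, scenario: Scenario) -> art.Trajectory:
    traj = art.Trajectory(
        messages_and_choices=[
            {
                "role": "system",
                "content": f"Summarize: {scenario.text}"
            },
        ]
    )

    # Send the trajectory's messages so far to the model.
    completion = await client.chat.completions.create(
        model=model.name,
        messages=traj.messages()
    )

    # Record the model's choice so it becomes part of the trajectory.
    traj.messages_and_choices.append(completion.choices[0])

    return traj
```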

However, this makes our rollout functions verbose (because they have to initialize and update the trajectory themselves) and difficult to use elsewhere in the codebase (because they return a trajectory rather than the value, here the summary, that the function was meant to produce).

If we decorate the function with @art.rollout and return the summary as a plain string, the code becomes much cleaner:

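```python
# `client` is an OpenAI-compatible async client defined elsewhere.
@art.rollout
async def get_summary(model: art.Model, scenario: Scenario) -> str:
    completion = await client.chat.completions.create(
        model=model.name,
        messages=[
            {
                "role": "system",
                "content": f"Summarize: {scenario.text}"
            },
        ]
    )

    # Return the processed value; the decorator records the trajectory.
    return completion.choices[0].message.content
```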

Used in the production flow:

```python
async def caller():
    summary = await get_summary(model, scenario)
    print(summary)
```

Used in the training flow:

```python
trajectory = await gather_trajectory(get_summary(model, scenario))
```

arcticfly, Aug 19 '25 22:08