
support for multigpu PPO

Open parthsarthi03 opened this issue 10 months ago • 3 comments

Hi, thank you for the awesome work! I wanted to know if there are plans to support a multi-GPU full PPO recipe; for my use case, training on a single GPU is very slow. I also need to generate tokens in the 2000-3000 range, so the forward batch size has to be very low as well.

parthsarthi03 avatar Apr 17 '25 11:04 parthsarthi03

Hi @parthsarthi03 thanks for creating the issue. We are not currently working on a distributed PPO recipe. But have you seen our GRPO recipe? This will definitely work with distributed (even up to multiple nodes). It may be pretty doable to take the linear combination of this and the single-device recipe to get what you're looking for.

ebsmothers avatar Apr 17 '25 22:04 ebsmothers

@ebsmothers Gotcha, thanks for the pointers! I had a quick question while implementing it:

In the GRPO recipe we call

opt_state_dict = training.get_full_optimizer_state_dict(
    self._model,
    self._optimizer,
    self._is_rank_zero,
    device=self._device,
)

because only self._model’s parameters are in the optimizer.

In the PPO recipe, however, the optimizer is built from

chain(self._policy_model.parameters(), self._value_model.parameters())

so the optimizer state spans two separate FSDP-sharded modules.

Is the right approach to wrap those two sub-modules in a dummy nn.Module/ModuleDict and pass that wrapper to training.get_full_optimizer_state_dict, so that every parameter seen by the optimizer is reachable from model.state_dict()? Something like:

self._actor_critic = nn.ModuleDict({
    "policy": self._policy_model,
    "value":  self._value_model,
})
opt_state_dict = training.get_full_optimizer_state_dict(
    self._actor_critic,
    self._optimizer,
    self._is_rank_zero,
    device=self._device,
)
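For what it's worth, here's a minimal single-device sketch (no FSDP, toy `nn.Linear` stand-ins for the real models) checking the premise: wrapping both sub-modules in an `nn.ModuleDict` makes every parameter the optimizer sees reachable from the wrapper's `parameters()`/`state_dict()`:

```python
from itertools import chain

import torch
import torch.nn as nn

# Toy stand-ins for the real policy and value models
policy_model = nn.Linear(4, 4)
value_model = nn.Linear(4, 1)

# Optimizer built the same way as in the PPO recipe: over the chained params
optimizer = torch.optim.AdamW(
    chain(policy_model.parameters(), value_model.parameters()), lr=1e-5
)

# The proposed dummy wrapper
actor_critic = nn.ModuleDict({"policy": policy_model, "value": value_model})

# Every parameter in the optimizer's param groups is also a parameter
# of the wrapper, so nothing the optimizer tracks is unreachable from it.
optim_params = {id(p) for group in optimizer.param_groups for p in group["params"]}
wrapper_params = {id(p) for p in actor_critic.parameters()}
print(optim_params <= wrapper_params)  # True
```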

parthsarthi03 avatar Apr 25 '25 03:04 parthsarthi03

Unfortunately, it looks like right now `get_optimizer_state_dict` expects at most one `nn.Module` for the model param. As such, the simplest thing to do (albeit not the most efficient) would be to get the optimizer state dict separately for each module.
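A rough, untested sketch of what that per-module approach could look like. The commented-out calls mirror the call shape from the GRPO recipe above; whether `training.get_full_optimizer_state_dict` tolerates an optimizer whose params are a superset of the module's would need checking. The placeholder dicts below stand in for the two returned full state dicts so the merge step is concrete:

```python
# In the real recipe, the two full state dicts would come from two calls
# per the GRPO-style signature, one per FSDP-sharded module, e.g.:
#   policy_opt_sd = training.get_full_optimizer_state_dict(
#       self._policy_model, self._optimizer, self._is_rank_zero, device=self._device
#   )
#   value_opt_sd = training.get_full_optimizer_state_dict(
#       self._value_model, self._optimizer, self._is_rank_zero, device=self._device
#   )
# Placeholder dicts stand in for those results here:
policy_opt_sd = {"state": {}, "param_groups": [{"lr": 1e-5}]}
value_opt_sd = {"state": {}, "param_groups": [{"lr": 1e-5}]}

# Namespace the two results in the checkpoint so their keys don't collide
opt_state_dict = {"policy": policy_opt_sd, "value": value_opt_sd}
```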

Or you could write your own function that slightly modifies the one in DCP to accommodate this use case.

joecummings avatar Apr 28 '25 17:04 joecummings