
Retrieving the Trained Model

dheerj188 opened this issue 9 months ago • 6 comments

Once we have trained using the pipe object and the GPipe scheduler, how can we get the trained model back as a normal nn.Module?

dheerj188 avatar Apr 29 '24 08:04 dheerj188

Also interested in this. Did you ever figure it out?

Xynonners avatar May 31 '24 06:05 Xynonners

Not as of now.

dheerj188 avatar May 31 '24 12:05 dheerj188

Sorry for replying late. We have migrated the PiPPy library to torch.distributed.pipelining. Here is our new documentation: https://pytorch.org/docs/main/distributed.pipelining.html.

In section "Option 2", you can see:

The Pipe object provides a method for retrieving the "model partitions": stage_mod: nn.Module = pipe.get_stage_module(stage_idx)

The returned object is an nn.Module, so you can save it as you would a regular module, for example:

torch.save(stage_mod, filepath)

or

torch.save(stage_mod.state_dict(), filepath)

(Reference: https://pytorch.org/tutorials/beginner/saving_loading_models.html)
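A minimal sketch of that save/restore pattern, using a toy nn.Linear as a stand-in for the stage module (a real stage module comes from pipe.get_stage_module and needs a distributed setup, but the saving mechanics are identical):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy stand-in for the nn.Module returned by pipe.get_stage_module(stage_idx).
stage_mod = nn.Linear(4, 4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "stage_0.pt")

    # Note: call state_dict() with parentheses; saving the bound method
    # itself would not serialize the weights.
    torch.save(stage_mod.state_dict(), path)

    # Later: rebuild a module with the same architecture and load the weights.
    restored = nn.Linear(4, 4)
    restored.load_state_dict(torch.load(path))

# The restored module carries the same parameters as the original.
assert torch.equal(restored.weight, stage_mod.weight)
```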

kwen2501 avatar Jun 10 '24 20:06 kwen2501


I think the question (at least for me) was whether we could turn the model back into the non-pipelined version for modification and saving?

Xynonners avatar Jun 11 '24 09:06 Xynonners

Hmm, do you mean getting back the full model at the end of training, but before saving the final checkpoint? That might be hard, I think, because each stage's updated weights now live on different ranks. So unless we do an all-gather, the weights in the pipe object would be only partially up to date.

That said, imagine we do a torch.load later; that would be a good time to glue the model back together, because: (i) we have the full, original model; and (ii) PP does not change the FQNs of the weights.

It is only a matter of loading from a single checkpoint file vs multiple checkpoint files. As far as I know, HF already uses multiple checkpoint files for large models.
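A minimal sketch of this "glue at load time" idea. The stage checkpoint files and the toy nn.Sequential are stand-ins (a real run would have one file saved per rank), but the key point carries over: since pipelining keeps the original FQNs, the per-stage state_dicts can simply be merged and loaded into the full model.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Full model; a two-stage split would place layer "0" on rank 0 and
# layer "1" on rank 1.
full = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp:
    # Pretend each rank saved its stage's state_dict. Because PP does not
    # change the FQNs, the keys already match the full model's keys.
    stage0 = {"0.weight": torch.randn(8, 4), "0.bias": torch.randn(8)}
    stage1 = {"1.weight": torch.randn(2, 8), "1.bias": torch.randn(2)}
    torch.save(stage0, os.path.join(tmp, "stage_0.pt"))
    torch.save(stage1, os.path.join(tmp, "stage_1.pt"))

    # Glue: merge all stage checkpoints into one state_dict and load it.
    merged = {}
    for fname in sorted(os.listdir(tmp)):
        merged.update(torch.load(os.path.join(tmp, fname)))
    full.load_state_dict(merged)

assert torch.equal(full[0].weight, stage0["0.weight"])
```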

kwen2501 avatar Jun 11 '24 15:06 kwen2501

OK, so here is what I want to do: obtain the gradients of each layer from each rank's stage of the pipe object and send them to the CPUs, apply some modifications to the gradients on the CPU, then move them back to the corresponding ranks of the pipeline stages and update the model with the modified gradients. Is this possible with PiPPy?
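The per-stage mechanics of this can at least be sketched with plain PyTorch, independent of PiPPy: after backward, each rank holds .grad tensors for its stage module, which can be moved to CPU, edited, and copied back before the optimizer step. A toy nn.Linear stands in for a stage module, and clamping is a placeholder for the real gradient transformation:

```python
import torch
import torch.nn as nn

# Toy stand-in for one pipeline stage's module on a given rank; the same
# per-parameter pattern would apply to each stage module.
stage_mod = nn.Linear(4, 2)
opt = torch.optim.SGD(stage_mod.parameters(), lr=0.1)

loss = stage_mod(torch.randn(3, 4)).sum()
loss.backward()

for p in stage_mod.parameters():
    # Move the gradient to CPU, modify it there (clamping is a stand-in
    # for the real transformation), then copy it back in place.
    g_cpu = p.grad.detach().to("cpu")
    g_cpu = g_cpu.clamp(-0.5, 0.5)
    p.grad.copy_(g_cpu.to(p.grad.device))

# The optimizer step now uses the modified gradients.
opt.step()

assert all(p.grad.abs().max() <= 0.5 for p in stage_mod.parameters())
```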

dheerj188 avatar Jun 12 '24 07:06 dheerj188