PiPPy
PipelineStage/Schedule issues
Just dumping issues here as I find them (applying PipelineStage to torchtrain) Stage
- fwd_inputs are all forced to have `requires_grad=True` -- why? what's our design here? `freqs_cis` could be passed from stage0 to stage1, but it is an input value from the dataloader and should not require grads
- complex numbers aren't supported by send/recv, but `freqs_cis` is a complex tensor. We can work around this in the model design by avoiding sending this value, but should PPStage support this automatically?
- manually computing input sizes for each stage is tedious. Doing shape inference inside init is also tedious, since the `inputs_meta` you pass in has to match the meta-ness (or cpu/cuda-ness) of the model you pass in. Can we just do shape inference lazily on the first forward call?
- dynamic shapes -- what if seqlen changes per batch?
- Loss fn is not yet supported
- backward isn't implemented correctly AFAIU. See the rewrite in the `whc/pp` branch; it fixes (a) `.grad()` won't set `.grad` on the W's, but `.backward()` will; (b) funny issues with requires-grad-ness on inputs, which disappeared after I simplified
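For the complex-tensor send/recv issue above, the model-side workaround can be sketched as follows. This is a minimal sketch, not part of the PipelineStage API; the helper names are hypothetical. The idea is to view a complex tensor as a real tensor of shape `(..., 2)` before sending, and view it back on the receiving side:

```python
import torch

def to_sendable(t):
    # If t is complex, view it as a real tensor with a trailing dim of 2,
    # so it can pass through send/recv backends that lack complex support.
    # The receiver needs the flag to know whether to undo the view.
    if t.is_complex():
        return torch.view_as_real(t), True
    return t, False

def from_sendable(t, was_complex):
    # Undo the real view on the receiving side, if it was applied.
    return torch.view_as_complex(t) if was_complex else t
```

`torch.view_as_real` / `torch.view_as_complex` are zero-copy views, so the round trip adds no extra allocation; the open question in the bullet above is whether the stage should do this bookkeeping automatically.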
Schedule
- loss value (for logging) is not returned from .step()
- it would be nicer to pass a List[Stage] to Schedule(), or to pass some kind of container (e.g. PiPPy's 'Pipe' concept)
- need support for FSDP/DDP - e.g. no_grad() for both and more for fsdp
Checkpointing
- how do we deal with checkpointing a list of stages per rank? (Does the user call .save / .load APIs on each chunk separately, or do we provide an API on a top-level container, e.g. like the Pipe?)
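One possible shape for the container-level option is sketched below. Everything here is hypothetical (these functions are not a PiPPy API): flatten the per-stage state dicts into a single per-rank dict, namespaced by stage index, so a rank that owns several virtual stages saves and loads one checkpoint.

```python
def save_rank_stages(stage_state_dicts):
    # Hypothetical container-level API: merge each stage's state_dict into
    # one per-rank checkpoint, namespacing keys by stage index.
    merged = {}
    for i, sd in enumerate(stage_state_dicts):
        for k, v in sd.items():
            merged[f"stage_{i}.{k}"] = v
    return merged

def load_rank_stages(merged, num_stages):
    # Split the merged per-rank checkpoint back into one dict per stage.
    out = [{} for _ in range(num_stages)]
    for key, v in merged.items():
        stage_part, k = key.split(".", 1)
        out[int(stage_part[len("stage_"):])][k] = v
    return out
```

The prefix scheme keeps the per-stage parameter names intact (splitting only on the first `.`), so resharding to a different number of stages per rank stays possible.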
Thanks for compiling this. I also hit a lot of the issues mentioned here, and Ke and I had some ideas, so we can discuss them and update this issue. Also some issues of mine:
Stage
- need to add mesh / process group support to simplify the stage constructor. Stages need to be updated to stop receiving from rank-1 and sending to rank+1, and instead use the pp dimension.
- improve erroring on `fwd()`. Usually we can tell from logs which rank threw the error, but with virtual pipeline stages it is hard to tell which stage_id / model chunk threw the error.
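For the first bullet, the send/recv targets can be derived from the pp mesh dimension instead of global rank arithmetic. A minimal sketch (how the rank list is obtained from the mesh/process group is assumed, not shown):

```python
def pp_neighbors(pp_ranks, my_rank):
    # pp_ranks: global ranks along the pp mesh dimension, in stage order
    # (e.g. taken from a device mesh), replacing the rank-1 / rank+1 assumption.
    idx = pp_ranks.index(my_rank)
    prev_rank = pp_ranks[idx - 1] if idx > 0 else None
    next_rank = pp_ranks[idx + 1] if idx < len(pp_ranks) - 1 else None
    return prev_rank, next_rank
```

With a 2D dp x pp layout, a pp group might be e.g. `[3, 7, 11, 15]`; rank 7 then sends to 11 and receives from 3, which rank+1 / rank-1 would get wrong.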
Schedule
- microbatches are currently expected in the form [mb_0, mb_1, mb_2, ..., mb_n], where one microbatch looks like (arg_0, arg_1, arg_2, ..., arg_n). Is there a way to make this less clunky and error-prone?
- data loading is only needed for stage 0 of the pipeline; the rest of the stages don't need microbatches passed in.
- some schedules take a list of stages and some take a single stage; consolidate this.
- How does the user create a "manual pipeline splitter" correctly? For more complicated schedules (interleaved 1F1B), the users will need to know which virtual stages belong on which rank and also make sure they are splitting their model in the correct order. Is there a way to reduce this overhead?
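For the last bullet, the stage-to-rank mapping that interleaved schedules typically expect can be sketched as below. This is the common round-robin convention, not necessarily what PiPPy's schedules mandate:

```python
def stage_ids_for_rank(rank, world_size, stages_per_rank):
    # Interleaved 1F1B convention: rank r owns virtual stages
    # r, r + world_size, r + 2 * world_size, ...
    # so consecutive model chunks are spread round-robin across ranks.
    return [rank + i * world_size for i in range(stages_per_rank)]
```

A helper like this (or a "manual pipeline splitter" built on it) would let users split the model in layer order and have the library decide which chunks land on which rank, instead of hand-computing the mapping.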
I've compiled these into a few concrete work items which could resolve multiple issues:
- https://github.com/pytorch/PiPPy/issues/945
- https://github.com/pytorch/PiPPy/issues/946
- https://github.com/pytorch/PiPPy/issues/947
- https://github.com/pytorch/PiPPy/issues/949
- https://github.com/pytorch/PiPPy/issues/948
- https://github.com/pytorch/PiPPy/issues/950