
PipelineStage/Schedule issues

Open · wconstab opened this issue · 2 comments

Just dumping issues here as I find them (applying PipelineStage to torchtrain).

Stage

  1. fwd_inputs are all forced to have 'requires_grad=True' -- why? What's our design here? freqs_cis could be passed from stage 0 to stage 1, but it is an input value from the dataloader and should not require grads.
  2. complex numbers aren't supported by send/recv, but freqs_cis is a complex tensor. We can work around this in the model design by avoiding sending this value, but should PipelineStage support it automatically? (A possible workaround is sketched after this list.)
  3. manually computing input sizes for each stage is tedious. Doing shape inference inside init is also tedious, since the 'inputs_meta' you pass in has to match the meta-ness (or CPU/CUDA-ness) of the model you pass in. Can we just do shape inference lazily on the first forward call?
  4. dynamic shapes: what if seqlen changes per batch?
  5. Loss fn is not yet supported
  6. backward isn't implemented correctly, as far as I understand. See the rewrite in the whc/pp branch, which fixes (a) .grad() won't set .grad on the weights, but .backward() will; (b) odd issues with the requires-grad state of inputs, which disappeared after I simplified.
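
A minimal sketch of a workaround for item 2, assuming NCCL is the backend (which rejects complex dtypes for send/recv): reinterpret the complex tensor as a real one with a trailing dimension of size 2 before sending, and rebuild it on the receiving side. `send_complex` / `recv_complex` are hypothetical helper names, not PiPPy APIs.

```python
import torch
import torch.distributed as dist

def send_complex(t: torch.Tensor, dst: int) -> None:
    # view_as_real reinterprets complex64 as float32 with a trailing
    # dim of 2 (complex128 becomes float64), which NCCL can send.
    dist.send(torch.view_as_real(t).contiguous(), dst)

def recv_complex(shape, dtype: torch.dtype, src: int, device) -> torch.Tensor:
    real_dtype = torch.float32 if dtype == torch.complex64 else torch.float64
    buf = torch.empty(*shape, 2, dtype=real_dtype, device=device)
    dist.recv(buf, src)
    # view_as_complex requires the last dim to have size 2 and stride 1,
    # which a freshly allocated contiguous buffer satisfies.
    return torch.view_as_complex(buf)
```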

Schedule

  1. the loss value (for logging) is not returned from .step() (see the sketch after this list)
  2. it would be nicer to pass a List[Stage] to Schedule(), or some kind of container (e.g. PiPPy's 'Pipe' concept)
  3. need support for FSDP/DDP - e.g. no_grad() for both, and more for FSDP
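
For item 1, a hypothetical sketch of the desired behavior: run the microbatches, keep the detached per-microbatch losses, and return them from the step call for logging. `forward_one` and the function signature here are assumptions for illustration, not the current Schedule API.

```python
import torch

def step(schedule, microbatches, targets, loss_fn):
    # Hypothetical: run each microbatch, accumulate grads, and return
    # the mean loss so the trainer can log it.
    losses = []
    for mb, tgt in zip(microbatches, targets):
        out = schedule.forward_one(*mb)  # assumed per-microbatch forward
        loss = loss_fn(out, tgt)
        loss.backward()
        losses.append(loss.detach())
    return torch.stack(losses).mean()
```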

Checkpointing

  1. how do we deal with checkpointing a list of stages per rank? (Does the user call .save / .load APIs on each chunk separately, or do we provide an API on a top-level container, e.g. like the Pipe? A sketch of the latter follows.)
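
One possible shape for the top-level-container option, sketched below. `StageGroup` is a hypothetical name, and it assumes each local stage chunk is an nn.Module.

```python
import torch
import torch.nn as nn

class StageGroup(nn.Module):
    """Hypothetical container: checkpoints every stage chunk on this
    rank through one state_dict / load_state_dict pair instead of
    per-chunk save/load calls."""

    def __init__(self, stage_modules: list[nn.Module]):
        super().__init__()
        # ModuleList registration keys each chunk's parameters under a
        # stable "stages.<i>." prefix, so nn.Module's built-in
        # state_dict / load_state_dict handle the whole rank at once.
        self.stages = nn.ModuleList(stage_modules)

# usage: torch.save(StageGroup(chunks).state_dict(), "rank0.pt")
```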

— wconstab, Feb 09 '24

Thanks for compiling this. I also hit a lot of the issues mentioned, and Ke and I had some ideas, so we can discuss them and update this issue. Also some issues of mine:

Stage

  1. need to add mesh / process group support to simplify the stage constructor. Stages need to move away from "recv from rank-1, send to rank+1" and instead use the pp mesh dimension (see the sketch after this list).
  2. improve error reporting in fwd(). Usually we can tell from the logs which rank threw the error, but with virtual pipeline stages it is hard to tell which stage_id / model chunk threw it.
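
A sketch of item 1 using the torch.distributed DeviceMesh API (available in recent PyTorch); how this plugs into the stage constructor is the open design question, and `build_pp_mesh` is a hypothetical helper.

```python
from torch.distributed.device_mesh import init_device_mesh

def build_pp_mesh(pp_degree: int, dp_degree: int):
    # Assumes torch.distributed is already initialized with
    # pp_degree * dp_degree ranks.
    mesh = init_device_mesh(
        "cuda", (pp_degree, dp_degree), mesh_dim_names=("pp", "dp")
    )
    pp_mesh = mesh["pp"]                # this rank's 1-D pipeline sub-mesh
    pp_rank = pp_mesh.get_local_rank()  # position within the pipeline
    prev_rank = pp_rank - 1 if pp_rank > 0 else None
    next_rank = pp_rank + 1 if pp_rank < pp_degree - 1 else None
    return pp_mesh.get_group(), prev_rank, next_rank
```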

Schedule

  1. microbatches are currently expected in the form [mb_0, mb_1, mb_2, ..., mb_n], where one microbatch looks like (arg_0, arg_1, arg_2, ..., arg_n). Is there a way to make this less clunky and error-prone? (A splitting helper is sketched after this list.)
  2. data loading is only needed for stage 0 of the pipeline; the rest of the stages don't need microbatches passed in.
  3. some schedules take a list of stages and some take a single stage; consolidate this.
  4. How does the user create a "manual pipeline splitter" correctly? For more complicated schedules (e.g. interleaved 1F1B), users need to know which virtual stages belong on which rank and also make sure they split their model in the correct order. Is there a way to reduce this overhead?
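
For item 1, a small helper along these lines could hide the nested-tuple format from users; `split_into_microbatches` is a hypothetical name, and it assumes every arg is a tensor batched along dim 0.

```python
import torch

def split_into_microbatches(args, n_microbatches):
    """Turn full-batch args (arg_0, ..., arg_k) into [mb_0, ..., mb_n],
    where each mb is a tuple of per-microbatch args."""
    chunked = [torch.chunk(a, n_microbatches, dim=0) for a in args]
    return list(zip(*chunked))
```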

— H-Huang, Feb 09 '24

I've compiled these into a few concrete work items which could resolve multiple issues:

  1. https://github.com/pytorch/PiPPy/issues/945
  2. https://github.com/pytorch/PiPPy/issues/946
  3. https://github.com/pytorch/PiPPy/issues/947
  4. https://github.com/pytorch/PiPPy/issues/949
  5. https://github.com/pytorch/PiPPy/issues/948
  6. https://github.com/pytorch/PiPPy/issues/950

— H-Huang, Feb 14 '24