
Add support for a loss function and update the PipelineStage output

Open · H-Huang opened this issue 1 year ago · 0 comments

The loss function is currently not implemented: https://github.com/pytorch/PiPPy/blob/f2e605d045cdc64cac31e2dd99a01706eb638a16/pippy/PipelineSchedule.py#L68-L73

We should add the loss function as an argument to PipelineSchedule.step(). This also means changing the output of forward():

  • For training (loss fn passed in): allow the user to choose whether the model output is a single loss value or a list of per-microbatch losses.
  • For inference (no loss fn passed in): .step() should return whatever the original, unpartitioned model would return. Only the last stage returns this output; all other stages should return None.
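To make the proposed return semantics concrete, here is a minimal pure-Python sketch. It is not PiPPy's implementation or API: step, split_microbatches, is_last_stage and return_microbatch_losses are hypothetical names, plain lists stand in for tensors, a single stage runs forward on every microbatch, and averaging the microbatch losses is an assumed reduction (a real schedule might sum them or weight them by microbatch size).

```python
from typing import Any, Callable, List, Optional, Sequence


def split_microbatches(batch: Sequence[Any], n_microbatches: int) -> List[List[Any]]:
    """Split a batch into at most n_microbatches contiguous chunks."""
    size = max(1, -(-len(batch) // n_microbatches))  # ceiling division
    return [list(batch[i:i + size]) for i in range(0, len(batch), size)]


def step(
    forward: Callable[[List[Any]], List[Any]],
    inputs: Sequence[Any],
    n_microbatches: int,
    is_last_stage: bool,
    loss_fn: Optional[Callable[[List[Any], List[Any]], float]] = None,
    targets: Optional[Sequence[Any]] = None,
    return_microbatch_losses: bool = False,
):
    """Hypothetical step() showing the return values proposed in the issue."""
    outputs = [forward(mb) for mb in split_microbatches(inputs, n_microbatches)]

    if not is_last_stage:
        # Non-final stages return None (a real schedule would send `outputs`
        # to the next stage instead of dropping them).
        return None

    if loss_fn is None:
        # Inference: merge microbatch outputs back into one full-batch result,
        # i.e. what the original, unpartitioned model would have returned.
        return [y for out in outputs for y in out]

    if targets is None:
        raise ValueError("targets are required when loss_fn is given")
    target_mbs = split_microbatches(targets, n_microbatches)
    losses = [loss_fn(out, tgt) for out, tgt in zip(outputs, target_mbs)]
    # Training: return per-microbatch losses, or reduce them (here: mean) to one value.
    return losses if return_microbatch_losses else sum(losses) / len(losses)


def double(mb: List[float]) -> List[float]:
    """Toy stage: doubles every element."""
    return [2 * x for x in mb]


def mse(out: List[float], tgt: List[float]) -> float:
    """Mean squared error over one microbatch."""
    return sum((o - t) ** 2 for o, t in zip(out, tgt)) / len(out)


print(step(double, [1, 2, 3, 4], 2, is_last_stage=False))  # None
print(step(double, [1, 2, 3, 4], 2, is_last_stage=True))  # [2, 4, 6, 8]
print(step(double, [1, 2, 3, 4], 2, True, mse, [2, 4, 6, 9]))  # 0.25
print(step(double, [1, 2, 3, 4], 2, True, mse, [2, 4, 6, 9], True))  # [0.0, 0.5]
```

Keeping the reduction behind a flag, rather than always returning a scalar, is what lets the caller choose between the two training outputs described above, while the inference path and the None return on non-final stages stay independent of that choice.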

H-Huang · Feb 14 '24 21:02