PiPPy
Add support for loss function, update the PipelineStage output
Loss function is currently not implemented: https://github.com/pytorch/PiPPy/blob/f2e605d045cdc64cac31e2dd99a01706eb638a16/pippy/PipelineSchedule.py#L68-L73
We should add the loss function as an argument to `PipelineSchedule.step()`. This also means we should change the output of `forward()`:
- For training (loss fn passed in): allow the user to choose whether the output is a single loss value or a list of per-microbatch losses.
- For inference (loss fn not passed in): `.step()` should return what the original full model would return. This applies only to the last stage; all other stages should return `None`.
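To make the proposed return contract concrete, here is a minimal single-process sketch in plain Python. Everything in it is hypothetical: `ToyPipelineSchedule`, the `step(stage_index, inputs, target, loss_fn, reduce_loss)` signature, and the `reduce_loss` flag are names invented for illustration, not PiPPy's API. It runs forward passes only (no backward pass, no inter-rank communication) and assumes the "single loss value" is the mean of the per-microbatch losses.

```python
from typing import Any, Callable, List, Optional, Sequence

# Hypothetical illustration only: these names are not PiPPy's actual API.
Stage = Callable[[List[float]], List[float]]
LossFn = Callable[[List[float], List[float]], float]


class ToyPipelineSchedule:
    """Single-process stand-in that runs every stage on every microbatch."""

    def __init__(self, stages: Sequence[Stage], n_microbatches: int) -> None:
        self.stages = list(stages)
        self.n_microbatches = n_microbatches

    def _split(self, xs: List[float]) -> List[List[float]]:
        # Assumes a non-empty batch that divides evenly into microbatches.
        assert xs and len(xs) % self.n_microbatches == 0
        size = len(xs) // self.n_microbatches
        return [xs[i:i + size] for i in range(0, len(xs), size)]

    def step(
        self,
        stage_index: int,
        inputs: List[float],
        target: Optional[List[float]] = None,
        loss_fn: Optional[LossFn] = None,
        reduce_loss: bool = True,
    ) -> Any:
        # Forward each microbatch through all stages in order.
        outputs = []
        for mb in self._split(inputs):
            x = mb
            for stage in self.stages:
                x = stage(x)
            outputs.append(x)

        # Proposed rule: only the last stage returns a value.
        if stage_index != len(self.stages) - 1:
            return None

        if loss_fn is None:
            # Inference: concatenate microbatch outputs so the result matches
            # what the original, unsplit model would return for the batch.
            return [y for out in outputs for y in out]

        # Training: one loss per microbatch, optionally reduced to a mean.
        assert target is not None, "training needs a target"
        losses = [loss_fn(out, tgt) for out, tgt in zip(outputs, self._split(target))]
        return sum(losses) / len(losses) if reduce_loss else losses


# Usage: a two-stage toy "model" and a mean-squared-error loss.
def double(xs: List[float]) -> List[float]:
    return [2 * x for x in xs]


def add_one(xs: List[float]) -> List[float]:
    return [x + 1 for x in xs]


def mse(pred: List[float], tgt: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, tgt)) / len(pred)


sched = ToyPipelineSchedule([double, add_one], n_microbatches=2)
batch = [1.0, 2.0, 3.0, 4.0]
target = [3.0, 5.0, 7.0, 10.0]

print(sched.step(1, batch))                          # [3.0, 5.0, 7.0, 9.0]
print(sched.step(0, batch))                          # None
print(sched.step(1, batch, target, mse))             # 0.25
print(sched.step(1, batch, target, mse, False))      # [0.0, 0.5]
```

Averaging per-microbatch losses matches a full-batch mean loss only when microbatches are equal-sized and the loss itself is a mean; a sum-reduced loss would call for summing instead, which is one reason to let the user pick the reduction.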