ColossalAI
Need more runtime hooks during a training step
Describe the feature
In typical PyTorch fashion, we usually train a model like this:

```python
for x, y in dataloader:
    ...  # do something before forward
    out = model(x)
    loss = criterion(out, y)
    ...  # do something between forward and backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ...  # do something after backward
```
In the trainer of Colossal-AI, hooks can only be added before and after a training step; users cannot customize the behavior between fetching an input batch and the forward pass, or between the forward and backward passes. Also, since OpHook is applied to modules recursively, it is not appropriate for this use case either. We may need to add at least the two extra hooks mentioned above.
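To make the request concrete, here is a minimal sketch of what such step-level hooks could look like. All names here (`StepHook`, `before_forward`, `after_forward`, `ScaleLossHook`) are hypothetical illustrations, not part of the actual Colossal-AI API:

```python
class StepHook:
    """Hypothetical step-level hook interface (a sketch, not real
    Colossal-AI API). A schedule would call these at the two points
    discussed above."""

    def before_forward(self, batch):
        # Called after fetching a batch, before the forward pass.
        # May transform and must return the (possibly modified) batch.
        return batch

    def after_forward(self, loss):
        # Called between the forward and backward passes.
        # May transform and must return the (possibly modified) loss.
        return loss


class ScaleLossHook(StepHook):
    """Example: scale the loss before backprop."""

    def __init__(self, scale):
        self.scale = scale

    def after_forward(self, loss):
        return loss * self.scale
```

A schedule could then simply iterate over its registered hooks at those two points inside its training-step implementation.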
Data fetching, forward pass and back prop are implemented in the schedule. Thus, I don't think they are trainer hooks. Is there any use case for such hooks?
Correct, and that is why I didn't call them trainer hooks. There are some cases where this can be helpful, like splitting the batch for tensor parallelism, applying mixup, etc. The main issue is that such customization is allowed by PyTorch but currently not by Colossal-AI.
I do agree that this is not supported by Colossal-AI. However, these use cases are not really related to the schedule even if we add hooks to it: splitting the batch can be done in the dataset/dataloader or in the first layer of the model, and applying mixup should be done in the dataset/dataloader.
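For instance, mixup can be applied in a custom collate function without any framework support. Below is a simplified sketch: inputs are plain lists of floats standing in for tensors (a real version would use torch tensors and soft labels), and all names here are illustrative:

```python
import random


def mixup_collate(batch, alpha=0.2, rng=random):
    """Apply mixup inside a (simplified) collate function.

    batch: list of (x, y) pairs, where x is a float "feature".
    Returns mixed features plus the two label sets and the mixing
    coefficient lam, as in standard mixup training.
    """
    xs = [x for x, _ in batch]
    ys = [y for _, y in batch]
    # Sample the mixing coefficient from Beta(alpha, alpha).
    lam = rng.betavariate(alpha, alpha)
    # Pair each example with a randomly permuted partner.
    perm = list(range(len(batch)))
    rng.shuffle(perm)
    mixed_x = [lam * x + (1 - lam) * xs[j] for x, j in zip(xs, perm)]
    # Labels are kept as (y_a, y_b, lam) so the loss can be mixed too.
    return mixed_x, ys, [ys[j] for j in perm], lam
```

Passing such a function as the dataloader's `collate_fn` keeps the schedule untouched, which is the point being made above.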
I am also not sure how to implement such hooks. Just open the issue to collect ideas.
I think if we can abstract this part, it will provide some flexibility and extensibility to the schedule class. For example, there is a batch_data_process_func parameter that allows some customization (e.g. applying mixup if a user really wants to).
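The parameter name `batch_data_process_func` comes from the thread above; the exact signature in Colossal-AI's schedule may differ. The following standalone sketch (with a hypothetical `ToySchedule` stand-in) only illustrates the design idea of a user-supplied batch callback:

```python
def default_process(batch):
    # Default behavior: assume the dataloader yields (data, label) pairs.
    data, label = batch
    return data, label


class ToySchedule:
    """Minimal stand-in for a schedule class (a sketch, not the real
    Colossal-AI Schedule API). A batch_data_process_func parameter lets
    users customize batch handling without modifying the schedule."""

    def __init__(self, batch_data_process_func=None):
        self.process = batch_data_process_func or default_process

    def load_batch(self, batch):
        # The schedule delegates batch preprocessing to the callback.
        return self.process(batch)


# A user could, for example, normalize pixel data in the callback:
def my_process(batch):
    data, label = batch
    data = [d / 255.0 for d in data]
    return data, label
```

With this design, `ToySchedule(batch_data_process_func=my_process)` applies the user's preprocessing to every batch, which is the kind of extensibility being proposed.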
We have updated a lot. This issue was closed due to inactivity. Thanks.