Merlin
[RMP] T4R quick fixes: MultiGPU data parallel training, multi-gpu .fit(), and Python based serving for Transformers4Rec
Problem:
We have customers who would like to use Transformers4Rec but are blocked by issues with our existing support for session-based models.
Goal:
- Unblock customer use cases so they can try out T4R to give us feedback
Constraints:
- We don't yet have Torchscript support (which is out of scope for this issue)
Starting Point:
- [ ] Enable Data Parallel training
  - [x] Next item prediction - https://github.com/NVIDIA-Merlin/Transformers4Rec/issues/473 - `DataParallel` works if the model is wrapped manually by the user (i.e. `model = torch.nn.DataParallel(model)` before training), but that wrapping should happen automatically by the HF Trainer here.
  - [ ] Binary classification - https://github.com/NVIDIA-Merlin/Transformers4Rec/issues/423 - This task is about supporting `DataParallel` when using `model.fit()`, since with `model.fit()` the customer was able to build a binary classification model (on a single GPU).
- [x] Fix the serving sections of the existing T4R notebooks
  - [x] https://github.com/NVIDIA-Merlin/NVTabular/pull/1628
  - [x] https://github.com/NVIDIA-Merlin/Transformers4Rec/pull/468
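For reference, a minimal sketch of the manual `DataParallel` workaround described above, assuming standard PyTorch APIs. The `nn.Linear` model here is a stand-in for illustration only, not a real T4R model; the point is the wrapping step that the HF Trainer should eventually do automatically.

```python
import torch
import torch.nn as nn

# Stand-in model (hypothetical; a Transformers4Rec model would go here).
model = nn.Linear(8, 2)

# Manual workaround from the thread: wrap in DataParallel when more than
# one GPU is visible. DataParallel splits each input batch across GPUs
# and gathers the outputs on the primary device.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# Training/inference then proceeds as usual; the call signature is unchanged.
out = model(torch.randn(4, 8))
print(tuple(out.shape))
```

On a single-GPU or CPU-only machine the wrapping is skipped and the model behaves exactly as before, which is why this workaround is safe to apply unconditionally.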
Distributed Data Parallel training is something we want to do, but I don't think it's part of this effort to fix the immediate blockers. Does that match what y'all understand, @EvenOldridge @viswa-nvidia?
Binary Classification DP training would unblock the issue, although DDP is preferred if more performant. More details are captured here
Either DistributedDataParallel training is part of the scope of the quick fixes or it isn't, and it sounds like it isn't, so we should track that work somewhere (but not here).
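To make the DP vs. DDP distinction concrete, here is a minimal single-process `DistributedDataParallel` sketch using the CPU `gloo` backend, assuming standard PyTorch APIs. This is illustrative only: a real multi-GPU DDP run would be launched with one process per GPU (e.g. via `torchrun`), each with its own rank, rather than the hardcoded rank 0 / world size 1 used here.

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn

# Single-process setup for illustration; a launcher like torchrun would
# normally set these environment variables and the rank/world size.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Stand-in model (hypothetical; a Transformers4Rec model would go here).
model = nn.Linear(8, 2)
ddp_model = nn.parallel.DistributedDataParallel(model)

# Each process runs the same script on its own shard of the data;
# gradients are averaged across processes during backward().
out = ddp_model(torch.randn(4, 8))
print(tuple(out.shape))

dist.destroy_process_group()
```

Unlike `DataParallel`, which replicates the model inside one process each step, DDP keeps one long-lived replica per process, which is why it is generally the more performant option mentioned above.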
@gabrielspmoreira , please create a ticket for DP training binary classification and link it here.
@viswa-nvidia @EvenOldridge
Me, @rnyak, @sararb and @nzarif met today about the issues related to DataParallel.
We tested DataParallel for Next Item Prediction on one of the examples and it is not working, unlike what Sara found some weeks ago with another example.
So currently neither Next Item Prediction nor Binary Classification works with DataParallel.
We have associated the issues for both in this RMP ticket description.
Should we remove the scope of DataParallel from this RMP and create another RMP ticket focused on DataParallel support (targeted for release 22.09)?
I don't think we should split the issue; let's just target this for 22.09.