
[RMP] Add PyTorch backend in Merlin Models

Open marcromeyn opened this issue 3 years ago • 3 comments

Problem:

We are currently in a situation where some customers use merlin-models and others use T4Rec to train models. The APIs of these two tools have diverged dramatically, and some features (like extracting embeddings from models) are only supported in Merlin Models. Both tools need work before their APIs are easy to use.

On the Merlin Models side, we are in an in-between state where (because of time pressure) there are a bunch of V1 & V2 classes. We would like to migrate all our users to the V2 classes (while removing "V2" from the names) and deprecate the old classes.

On the T4Rec side, we would like to keep using this project for session-based models in PyTorch because of the traction it has gained. The idea is to break out the core model-building parts (the block-API) and replace them with the PyTorch backend of Merlin Models. This roadmap-level ticket focuses on this new PyTorch backend; integration into T4Rec is left for later. The first major deliverable of this backend is support for retrieval models, because we typically frame session-based models as retrieval models.

Goal:

Reach feature parity & rough API parity between the TF & PyTorch backends in Merlin Models. This roadmap ticket focuses on PyTorch; a future roadmap ticket will focus on TF.

New Functionality

  • Models
    • PyTorch: New backend, built from the ground up based on the TF implementation. Port all the retrieval examples.

Constraints:

  • We focus on just retrieval-models. Ranking-models will be tackled in a future roadmap ticket.
  • Migrating T4Rec to the new Block-API is future work and will be captured in another roadmap-level ticket.

Starting Point:

To properly plan out the work, a dev branch was created to answer various design questions around building retrieval models in PyTorch. This has led to a rough MVP that contains all the major pieces. It has also given us a better idea of how to break down the work needed to turn the MVP into a fully fleshed-out product.

We are planning to have people work in parallel on 4 different major parts: inputs, outputs, models & masking.

Implement base-classes of block-API in PyTorch

People: @marcromeyn

Currently the block-API in T4Rec uses a design similar to Keras that allows modules to lazily initialize their variables. We would like to deprecate this in favor of a native PyTorch mechanism for achieving the same thing, which was launched recently.

  • [x] https://github.com/NVIDIA-Merlin/models/pull/1087
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1088
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1090
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1091
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1092
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1095
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1109
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1112
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1096
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1093
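To make the lazy-initialization pattern concrete, here is a minimal, framework-free sketch of the Keras-style design described above: a layer that only creates its weights on the first call, once the input dimension is known. The class name `LazyDense` and its methods are hypothetical and are not taken from T4Rec or Merlin Models.

```python
import random


class LazyDense:
    """Hypothetical Keras-style layer (illustration only): the weight matrix
    is created on the first call, once the input dimension is known."""

    def __init__(self, out_features):
        self.out_features = out_features
        self.weight = None  # not materialized until build() runs

    def build(self, in_features):
        # Create an (in_features x out_features) weight matrix lazily.
        rng = random.Random(0)
        self.weight = [
            [rng.uniform(-0.1, 0.1) for _ in range(self.out_features)]
            for _ in range(in_features)
        ]

    def __call__(self, rows):
        # rows: list of input vectors that all have the same length.
        if self.weight is None:
            self.build(len(rows[0]))  # infer in_features from the first batch
        return [
            [sum(x[i] * self.weight[i][j] for i in range(len(x)))
             for j in range(self.out_features)]
            for x in rows
        ]


layer = LazyDense(out_features=4)
assert layer.weight is None
out = layer([[1.0, 2.0, 3.0]])
print(len(layer.weight), len(layer.weight[0]))  # 3 4
```

The "native way" mentioned above presumably refers to PyTorch's lazy modules, for example `torch.nn.LazyLinear(out_features)`, whose `in_features` is inferred on the first forward pass, so no custom `build` step is needed.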

Masking

People: @sararb, @gabrielspmoreira & @marcromeyn

This work depends on answering the design question of how to handle ragged tensors.

Tasks: TODO
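As background for the ragged-tensor question, one common representation stores a batch of variable-length sequences as flat values plus row offsets; masking then needs to know which positions in a padded batch are real. The sketch below assumes that offsets-based layout and right padding; it is one possible option, not the design Merlin Models settles on, and `ragged_to_padded` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def ragged_to_padded(
    values: Sequence[int], offsets: Sequence[int], pad_value: int = 0
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Convert a ragged batch (flat values + row offsets) into a right-padded
    dense batch plus a boolean mask of real (non-padding) positions.

    offsets has length batch_size + 1; row i is values[offsets[i]:offsets[i + 1]].
    """
    rows = [list(values[offsets[i]:offsets[i + 1]]) for i in range(len(offsets) - 1)]
    max_len = max((len(r) for r in rows), default=0)
    padded = [r + [pad_value] * (max_len - len(r)) for r in rows]
    mask = [[True] * len(r) + [False] * (max_len - len(r)) for r in rows]
    return padded, mask


# Three sessions with 3, 1 and 2 item-ids respectively.
padded, mask = ragged_to_padded([5, 9, 2, 7, 3, 8], [0, 3, 4, 6])
print(padded)  # [[5, 9, 2], [7, 0, 0], [3, 8, 0]]
print(mask)    # [[True, True, True], [True, False, False], [True, True, False]]
```

Any masking scheme for session-based training would then only select target positions where the mask is `True`, so padding never leaks into the loss.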

Input-blocks

People: @marcromeyn

PyTorch

Starting point: MVP

  • [ ] Implement Continuous & Embeddings
  • [ ] Implement TabularInputBlock
  • [ ] Implement Encoder
  • [ ] Add support for sequential-features in input-blocks
  • [ ] Do performance testing of holding multiple features in a single embedding-table
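The last item above asks whether several categorical features can share one embedding table. A common way to do this is to give each feature a contiguous block of rows and shift its ids by that block's offset, so all lookups hit a single table. The sketch below assumes that offset scheme; the feature names and helper functions are hypothetical, and the actual performance trade-off is exactly what the task is meant to measure.

```python
from typing import Dict, List


def build_offsets(cardinalities: Dict[str, int]) -> Dict[str, int]:
    """Assign each feature a contiguous block of rows in one shared table."""
    offsets, start = {}, 0
    for name, cardinality in cardinalities.items():
        offsets[name] = start
        start += cardinality
    return offsets


def to_shared_ids(batch: Dict[str, List[int]], offsets: Dict[str, int]) -> Dict[str, List[int]]:
    """Shift each feature's local ids into the shared table's id space."""
    return {name: [i + offsets[name] for i in ids] for name, ids in batch.items()}


cards = {"item_id": 1000, "category": 50, "brand": 200}
offsets = build_offsets(cards)
print(offsets)  # {'item_id': 0, 'category': 1000, 'brand': 1050}
shared = to_shared_ids({"category": [0, 49], "brand": [3]}, offsets)
print(shared)   # {'category': [1000, 1049], 'brand': [1053]}
total_rows = sum(cards.values())  # one table of 1250 rows instead of three tables
```

In PyTorch, the shifted ids could be fed to a single `torch.nn.Embedding(total_rows, embedding_dim)` instead of one `nn.Embedding` per feature.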

Output-blocks

People: @edknv & @marcromeyn

  • [x] https://github.com/NVIDIA-Merlin/models/pull/1099
  • [x] https://github.com/NVIDIA-Merlin/models/pull/1115
  • [ ] Port CategoricalOutput
  • [ ] Port ContrastiveOutput + negative samplers
  • [ ] Port TopKOutput
  • [ ] Port OutputBlock (for multi-task learning)
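To illustrate what a contrastive output computes, here is a minimal sketch using in-batch negatives, which is one common negative-sampling strategy: for each query, the item at the same batch position is the positive and all other items in the batch act as negatives. This is a plain-Python illustration under that assumption, not the Merlin `ContrastiveOutput` implementation; the function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def in_batch_contrastive_loss(queries: List[Vector], items: List[Vector]) -> float:
    """Softmax cross-entropy where item i is the positive for query i and the
    other items in the batch serve as negatives (in-batch negatives)."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [dot(q, item) for item in items]
        m = max(logits)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[i]  # -log softmax(logits)[i]
    return total / len(queries)


queries = [[1.0, 0.0], [0.0, 1.0]]
items = [[1.0, 0.0], [0.0, 1.0]]
print(round(in_batch_contrastive_loss(queries, items), 4))  # 0.3133
```

Other negative samplers (for example, sampling from item popularity) would only change which item embeddings are scored against each query; the cross-entropy over the positive and its negatives stays the same.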

Models

People: @edknv & @marcromeyn

Starting point: MVP

One of the leading questions in the initial experimentation phase was whether we can leverage PyTorch Lightning for a high-level training API (similar to how we use Keras on the TF side). We are confident that PyTorch Lightning is the right path forward.

  • [ ] Implement Model class (using PyTorch Lightning)
  • [ ] Create a custom Trainer that can handle multi-GPU training with the data-loader
  • [ ] Implement RetrievalModel class
  • [ ] Port MatrixFactorizationModel, TwoTowerModel & YoutubeDNNRetrievalModel
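For orientation, the retrieval models listed above share a common scoring path: a query tower maps user or session features to an embedding, candidate items have their own embeddings, and retrieval returns the top-k items by dot product. The sketch below shows only that path in plain Python, assuming a single hypothetical linear layer as the query tower; it omits training (which the task assigns to PyTorch Lightning) and is not the Merlin Models implementation. In a matrix-factorization model, both towers would reduce to plain embedding lookups.

```python
import heapq
from typing import Dict, List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def linear_tower(x: Vector, weights: List[Vector]) -> Vector:
    """Hypothetical one-layer query tower: each weight row yields one output dim."""
    return [dot(row, x) for row in weights]


def top_k_items(query: Vector, item_embeddings: Dict[str, Vector], k: int) -> List[Tuple[str, float]]:
    """Score every candidate item with a dot product and keep the k best."""
    scores = ((item_id, dot(query, emb)) for item_id, emb in item_embeddings.items())
    return heapq.nlargest(k, scores, key=lambda pair: pair[1])


user_features = [1.0, 0.5]
query_weights = [[1.0, 0.0], [0.0, 2.0]]  # hypothetical "learned" weights
query = linear_tower(user_features, query_weights)
print(query)  # [1.0, 1.0]
items = {"a": [0.1, 0.2], "b": [1.0, 1.0], "c": [0.5, 0.0]}
print(top_k_items(query, items, k=2))  # [('b', 2.0), ('c', 0.5)]
```

Keeping item embeddings independent of the query is what allows them to be precomputed and indexed, which is the main reason retrieval models are built as two separate towers.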

Documentation

  • [ ] Create a migration guide from Transformers4Rec to Merlin Models session-based PyTorch API

marcromeyn, Apr 03 '23 11:04

@marcromeyn, please create the tasks for PyTorch and the corresponding tickets so that we can assign them.

viswa-nvidia, May 16 '23 17:05

@marcromeyn @gabrielspmoreira can you split this up into Ranking, Retrieval and Session-based?

EvenOldridge, Jun 27 '23 16:06

Ranking ticket is here: https://github.com/NVIDIA-Merlin/Merlin/issues/1044

marcromeyn, Jul 03 '23 16:07