
[Task] Port the transformer blocks from Transformers4Rec (PyTorch) to Merlin Models (TensorFlow)

Open karlhigley opened this issue 3 years ago • 1 comments

Problem:

The stable API in Transformers4Rec is based on PyTorch and includes all the components used to run the session-based research experiments reported in the T4Rec paper. Merlin Models, on the other hand, does not yet offer a stable PyTorch API.

Goal:

  • The goal of this work is to port all t4rec blocks needed to define transformer-based recommendation models (PyTorch implementation) into Merlin Models (TensorFlow implementation).

  • This work is not about improving the current t4rec API or adding blocks beyond the existing ones.

Constraints:

  • The PyTorch T4Rec API inherits from the HuggingFace Trainer class to support optimization techniques such as fp16, multi-GPU training, early stopping… We need to provide clear guidance on how to set up these techniques with the Keras fit method.

  • The current T4Rec implementation uses the schema class from the old merlin_standardlib. The migration to merlin-core should happen before porting to Merlin Models, to make sure all blocks work correctly with the new Schema class.

Starting Point:

  • [ ] Implement MaskingBlock : Causal LM, Masked LM, Permutation LM, and Replacement Token Detection

  • [ ] Port the Transformer-block class related to HuggingFace architectures adapted for next-item prediction (Link to HF layers)

  • [ ] Implement the Transformer block based on the configs defined in the HuggingFace transformers library. Note: This is an example of how Transformer-based architectures are implemented in HF (as Keras layers).

  • [ ] Support setting the masking task within model.compile()

  • [ ] Create an example and/or guidelines on how to train transformer-based models with the techniques supported in the PyTorch Trainer class: early stopping, fp16, LR scheduler, model checkpoints
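
To make the MaskingBlock item in the list above more concrete, here is a toy sketch of how inputs and targets could be built for two of the listed tasks, Causal LM and Masked LM, from a single session of item IDs. It uses plain Python lists rather than tensors, and every name in it (`MASK_ID`, `IGNORE`, `causal_lm_example`, `masked_lm_example`) is hypothetical. It is not the T4Rec or Merlin Models implementation, and it assumes real item IDs are positive integers.

```python
import random

MASK_ID = 0   # hypothetical ID reserved for the mask token
IGNORE = -1   # hypothetical label for positions excluded from the loss


def causal_lm_example(session):
    """Next-item prediction: each position predicts the item that follows it."""
    return session[:-1], session[1:]


def masked_lm_example(session, mask_prob=0.2, rng=None):
    """Hide random items; only the hidden positions carry a target."""
    rng = rng or random.Random()
    masked = [rng.random() < mask_prob for _ in session]
    if not any(masked):
        # Assumption of this sketch: always keep at least one target per session.
        masked[rng.randrange(len(session))] = True
    inputs = [MASK_ID if m else item for item, m in zip(session, masked)]
    targets = [item if m else IGNORE for item, m in zip(session, masked)]
    return inputs, targets


session = [12, 7, 33, 5]
print(causal_lm_example(session))   # → ([12, 7, 33], [7, 33, 5])
print(masked_lm_example(session, rng=random.Random(0)))
```

The causal variant shifts the sequence by one position, while the masked variant leaves the sequence length unchanged and marks non-masked positions with `IGNORE` so that a loss function can skip them.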

karlhigley avatar May 04 '22 16:05 karlhigley

@sararb, could you please help define this ticket?

viswa-nvidia avatar Jul 27 '22 17:07 viswa-nvidia

This issue is fixed in #771 and #780

sararb avatar Oct 06 '22 22:10 sararb

I moved the task of creating the example to a separate ticket NVIDIA-Merlin/models#791 (expected in 22.11)

sararb avatar Oct 06 '22 22:10 sararb