
Generation tutorial for Gemma model

pggPL opened this pull request 9 months ago · 8 comments

Description

I added tutorials on finetuning and generation for the Gemma model. I also added a few features that were necessary to make these tutorials work.

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [ ] New feature (non-breaking change which adds functionality)
  • [x] Breaking change (fix or feature that would cause existing functionality to not work as expected)

Changes

  • Two new notebooks in the docs: one on finetuning Gemma (analogous to the Llama tutorial) and one on generation with Gemma,
  • Generalized the rotary positional embedding kernel so that sequences can start at different encoding positions (see the sketch after this list),
  • Added a kernel to efficiently save keys and values to the kv_cache,
  • Expanded the InferenceParams class, which is responsible for caching k and v,
  • Changed DotProductAttention to run THD attention when there are ragged tensors in the kv_cache.
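
To illustrate what "different encoding positions" means for the generalized RoPE kernel: during generation, each sequence in the batch may already have a different number of cached tokens, so its new token has to be rotated with that sequence's own position offset. Below is a minimal plain-PyTorch sketch of that idea, not the fused kernel from this PR; all function names in it are made up for illustration.

```python
import torch

def rope_angles(positions: torch.Tensor, dim: int, base: float = 10000.0):
    """Return cos/sin tables for the given absolute positions (one per sequence)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = positions.float()[:, None] * inv_freq[None, :]          # [b, dim/2]
    return torch.cos(angles), torch.sin(angles)

def apply_rope_single_step(x: torch.Tensor, start_positions: torch.Tensor) -> torch.Tensor:
    """Rotate one new token per sequence. x: [b, h, d], start_positions: [b]."""
    b, h, d = x.shape
    cos, sin = rope_angles(start_positions, d)                        # [b, d/2] each
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = cos[:, None, :], sin[:, None, :]                       # broadcast over heads
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)

# Example: three sequences with 5, 2 and 9 tokens already in the KV cache.
q_new = torch.randn(3, 8, 64)
q_rot = apply_rope_single_step(q_new, torch.tensor([5, 2, 9]))
```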

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

Future work:

TransformerLayer does not support thd, and that is a problem. The current solution works as follows (a usage sketch is given after the list):

  • one needs to call setup_before_new_input before the forward pass to indicate the sequence lengths,
  • then one runs the forward pass of TransformerLayer with self_attn_format='thd' and padded sequences of shape bshd,
  • all layers (including attention) receive input in [b, s, *] format, not in [t, *],
  • InferenceParams.retrieve_from_kv_cache returns key_layer in bshd or thd format, depending on inference_params.qkv_format.
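
A hypothetical usage sketch of this workaround is below. The exact import paths, constructor arguments, and the signatures of setup_before_new_input and the self_attn_format argument are assumptions based on the description above and may not match the PR's code exactly.

```python
import torch
import transformer_engine.pytorch as te

# Build a layer as usual; hyperparameters here are arbitrary.
layer = te.TransformerLayer(
    hidden_size=2048,
    ffn_hidden_size=8192,
    num_attention_heads=16,
    self_attn_mask_type="padding_causal",
    params_dtype=torch.bfloat16,
    device="cuda",
)

# KV cache holder; check your TE version for the exact class location/signature.
inference_params = te.InferenceParams(max_batch_size=3, max_sequence_length=128)

# 1. Declare the true (unpadded) length of each sequence in the incoming batch.
seq_lens = torch.tensor([7, 3, 5], dtype=torch.int32, device="cuda")
inference_params.setup_before_new_input(seq_lens)  # argument layout assumed

# 2. Run the forward pass on *padded* [b, s, hidden] input, asking self-attention
#    to treat the KV cache as thd (argument name taken from the description above).
hidden = torch.randn(3, 7, 2048, dtype=torch.bfloat16, device="cuda")
out = layer(hidden, inference_params=inference_params, self_attn_format="thd")
```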

As can be seen, this is quite a messy workaround. Here is how I think it should be done in the future (a rough sketch follows this list):

  • TransformerLayer supports thd, so setup_before_new_input is not needed at all,
  • InferenceParams stores the lengths of the cached sequences for each layer,
  • on each TransformerLayer invocation, the provided sequences are copied into the cache and the cached sequence lengths for that layer are updated.
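
A rough sketch of the proposed bookkeeping, assuming a dict-like KV cache keyed by layer number with an sbhd layout (as described in the next paragraph); all attribute and method names here are illustrative, not the PR's API:

```python
import torch
from typing import Dict, Tuple

class InferenceParamsSketch:
    """Per-layer KV cache plus per-layer cached-sequence lengths."""

    def __init__(self, max_batch_size: int, max_sequence_length: int):
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        # layer_number -> (key cache, value cache), each of shape [s_max, b, h, d]
        self.key_value_memory_dict: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        # Proposed addition: per-layer lengths of the sequences already cached.
        self.cached_seq_lens: Dict[int, torch.Tensor] = {}

    def save_to_kv_cache(self, layer_number: int,
                         new_k: torch.Tensor, new_v: torch.Tensor,   # [s, b, h, d] (padded)
                         new_lens: torch.Tensor) -> None:            # [b]
        """Append the new keys/values behind the cached ones and bump the lengths,
        so no separate setup_before_new_input call is needed."""
        cache_k, cache_v = self.key_value_memory_dict[layer_number]
        offsets = self.cached_seq_lens[layer_number]
        for b in range(offsets.numel()):
            start, length = int(offsets[b]), int(new_lens[b])
            cache_k[start:start + length, b] = new_k[:length, b]
            cache_v[start:start + length, b] = new_v[:length, b]
        self.cached_seq_lens[layer_number] = offsets + new_lens
```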

To do this, one will need to remove the save_to_kv_cache() kernel and write save_to_kv_cache_sbhd() and save_to_kv_cache_thd() (no bshd variant, because the cache has shape sbhd for both bshd and sbhd inputs). The logic for updating sequence lengths needs to be moved from setup_before_new_input into save_to_kv_cache.
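
For reference, a pure-Python version of what the proposed save_to_kv_cache_thd could do (the real implementation would be a CUDA kernel); cu_seqlens and the sbhd cache layout follow the text above, the rest is assumed:

```python
import torch

def save_to_kv_cache_thd(cache_k: torch.Tensor, cache_v: torch.Tensor,      # [s_max, b, h, d]
                         new_k: torch.Tensor, new_v: torch.Tensor,          # [t, h, d] (ragged)
                         cu_seqlens: torch.Tensor,                          # [b + 1]
                         cached_lens: torch.Tensor) -> torch.Tensor:        # [b]
    """Append each ragged sequence to its slot in the sbhd cache; return updated lengths."""
    batch = cu_seqlens.numel() - 1
    for b in range(batch):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        offset = int(cached_lens[b])
        cache_k[offset:offset + (end - start), b] = new_k[start:end]
        cache_v[offset:offset + (end - start), b] = new_v[start:end]
    return cached_lens + (cu_seqlens[1:] - cu_seqlens[:-1])
```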

It is worth noting that we need to take care of backwards compatibility. Right now generation works only for bshd/sbhd, and one needs to manually update self.sequence_len_offset. I think we can write a setter that updates the per-layer statistics whenever sequence_len_offset is changed.
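
One possible shape of such a backwards-compatibility setter, assuming the per-layer length bookkeeping sketched earlier; names other than sequence_len_offset are illustrative:

```python
import torch
from typing import Dict

class InferenceParamsCompat:
    def __init__(self, max_batch_size: int, max_sequence_length: int):
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        self._sequence_len_offset = 0
        # Proposed per-layer lengths of the cached sequences.
        self.cached_seq_lens: Dict[int, torch.Tensor] = {}

    @property
    def sequence_len_offset(self) -> int:
        return self._sequence_len_offset

    @sequence_len_offset.setter
    def sequence_len_offset(self, value: int) -> None:
        # Existing bshd/sbhd generation loops just do `params.sequence_len_offset += step`;
        # the setter keeps the new per-layer bookkeeping in sync with that habit.
        self._sequence_len_offset = value
        for layer_number in self.cached_seq_lens:
            self.cached_seq_lens[layer_number].fill_(value)
```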

If thd support in TransformerLayer is not added in the near future, I propose writing the sequence lengths into inference_params.cu_seqlens; note that this would only be a stopgap (in the future cu_seqlens will probably be added as an argument to TransformerLayer). Then TransformerLayer is used with bshd. If MultiHeadAttention sees inference_params.cu_seqlens != None, it converts the padded bshd input into thd, calls save_to_kv_cache etc., runs DotProductAttention with thd, and then converts the output back to bshd.
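
The bshd ↔ thd round trip described above could look roughly like this in pure PyTorch (how it would be wired into MultiHeadAttention is not shown):

```python
import torch

def bshd_to_thd(x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Pack padded [b, s, h, d] into ragged [t, h, d] using cu_seqlens ([b + 1])."""
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    return torch.cat([x[b, : int(lens[b])] for b in range(x.shape[0])], dim=0)

def thd_to_bshd(x: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    """Scatter ragged [t, h, d] back into zero-padded [b, s, h, d]."""
    batch = cu_seqlens.numel() - 1
    out = x.new_zeros(batch, max_seqlen, *x.shape[1:])
    for b in range(batch):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        out[b, : end - start] = x[start:end]
    return out

# Round trip: 2 sequences of lengths 3 and 1, padded to length 4.
q = torch.randn(2, 4, 8, 64)
cu = torch.tensor([0, 3, 4])
assert torch.equal(thd_to_bshd(bshd_to_thd(q, cu), cu, 4)[0, :3], q[0, :3])
```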

pggPL · May 01 '24 18:05

/te-ci pytorch

sudhakarsingh27 · May 30 '24 00:05

/te-ci pytorch

sudhakarsingh27 · Jun 03 '24 18:06

/te-ci pytorch

sudhakarsingh27 · Jun 03 '24 21:06

@cyanguwa, need your help in reviewing the attention-related files (attention.py, test_fused_rope.py, etc.)

sudhakarsingh27 · Jun 03 '24 21:06

/te-ci pytorch

sudhakarsingh27 · Jun 04 '24 21:06

/te-ci pytorch

sudhakarsingh27 · Jun 04 '24 21:06

/te-ci pytorch

phu0ngng · Jun 05 '24 17:06

/te-ci pytorch

sudhakarsingh27 · Jun 07 '24 00:06