
[RFC] TransformerDecoder refactor

Open SalmanMohammadi opened this issue 9 months ago • 1 comment

TransformerDecoder Refactor

Authors:

  • @SalmanMohammadi with input from:
  • @kartikayk
  • @ebsmothers
  • @pbontrager

Summary

Refactoring TransformerDecoder to offer additional flexibility for new use-cases.

Motivation/Prior art

  • https://github.com/pytorch/torchtune/issues/968 - not sure if this is in scope.
  • https://github.com/pytorch/torchtune/pull/840

Currently, TransformerDecoder can only be used for language-modelling tasks. There is interest in additional use-cases, such as sequence classification (e.g. for reward models) and RLHF recipes like PPO.

Such a refactor could allow users to easily adapt a transformer backbone for a variety of downstream tasks; lm_human_preference_details demonstrates how HF's transformer backbone can be extended in just 8 lines. While this refactor initially targets recipes which will be provided within Torchtune, such as PPO or sequence-classification training recipes (e.g. for reward models), it would also allow users to write custom recipes for many fine-tuning tasks whilst utilising underlying Torchtune features.
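For illustration only, the pattern lm_human_preference_details follows amounts to a backbone plus a small head. A hedged sketch of what a scalar reward head over a hidden-state-returning backbone could look like (all names here are hypothetical, not existing Torchtune APIs):

from typing import Optional

import torch.nn as nn
from torch import Tensor


class RewardModel(nn.Module):
    """Sketch: a scalar reward head over a backbone that returns final
    hidden states of shape [batch_size, seq_len, embed_dim]."""

    def __init__(self, backbone: nn.Module, embed_dim: int) -> None:
        super().__init__()
        self.backbone = backbone
        self.reward_head = nn.Linear(embed_dim, 1, bias=False)

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        h = self.backbone(tokens, input_pos)  # [b, s, d]
        return self.reward_head(h)[:, -1]     # reward for the final token: [b, 1]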

Proposed Implementation

A small-scale implementation for Mistral models exists in this draft PR. In summary:

  • TransformerDecoder will refer to the underlying transformer backbone, agnostic to its downstream task. By default it will return the final hidden state as an output of shape [batch_size, sequence_len, embed_dim].
  • TransformerDecoder could support returning hidden states from arbitrary layers (or other useful outputs). Input on how we allow users to specify this would be helpful; a hypothetical sketch follows the TransformerLM definition below. We probably just want to return the last hidden state by default.
  • We define TransformerLM as follows:
from typing import Optional

import torch.nn as nn
from torch import Tensor


class TransformerLM(nn.Module):
    def __init__(self, decoder: TransformerDecoder, output: nn.Linear) -> None:
        super().__init__()
        self.decoder = decoder
        self.output = output

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # shape: [b, s, d]
        h = self.decoder(tokens, input_pos)
        # shape: [b, s, v]
        lm_output = self.output(h).float()
        return lm_output
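To make the hidden-states question above concrete, here is one possible shape for it, sketched against the attributes TransformerDecoder already has (tok_embeddings, layers, norm). The output_hidden_states parameter name and the tuple return are assumptions, not a settled API, and the layer call signature is simplified:

from typing import List, Optional, Tuple, Union

from torch import Tensor


# Sketch only: a possible TransformerDecoder.forward allowing callers to request
# hidden states from arbitrary layers.
def forward(
    self,
    tokens: Tensor,
    input_pos: Optional[Tensor] = None,
    output_hidden_states: Optional[List[int]] = None,
) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
    h = self.tok_embeddings(tokens)
    hidden_states = []
    for i, layer in enumerate(self.layers):
        h = layer(h, input_pos=input_pos)
        if output_hidden_states and i in output_hidden_states:
            hidden_states.append(h)
    h = self.norm(h)  # default output: final hidden state, shape [b, s, d]
    return h if not output_hidden_states else (h, hidden_states)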
  • Model builders and component builders* now return TransformerLM instead of TransformerDecoder. Component builders look like:
decoder = TransformerDecoder(
    tok_embeddings=tok_embeddings,
    layer=layer,
    num_layers=num_layers,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    head_dim=head_dim,
    norm=RMSNorm(embed_dim, eps=norm_eps),
)
output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
return TransformerLM(decoder=decoder, output=output_proj)
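For clarity, the intended call pattern for a model built this way would be roughly as follows (a sketch; model stands for the TransformerLM returned by the builder, and vocab_size/embed_dim refer to the builder arguments above):

import torch

# `model` is the TransformerLM returned by the component builder above.
bsz, seq_len = 2, 16
tokens = torch.randint(0, vocab_size, (bsz, seq_len))

logits = model(tokens)          # full LM: [bsz, seq_len, vocab_size]
hidden = model.decoder(tokens)  # backbone only: [bsz, seq_len, embed_dim]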

* mistral_classifier should now return an instance of TransformerClassifier.

* Gemma models define a GemmaTransformerDecoder which has a unique output projection, but shares the underlying logic of a TransformerDecoder. We can go two routes here:

  1. TransformerLM accepts a Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]] (or even just Callable[[torch.Tensor], torch.Tensor]) as output. Then, the Gemma component builder is:
tok_embeddings = nn.Embedding(...)
decoder = TransformerDecoder(...)
output = lambda a: F.linear(a, tok_embeddings.weight)
return TransformerLM(decoder, output=output)
  2. If that looks a bit clunky, or we don't like anonymous functions, we can just define a GemmaTransformerLM which looks like:
from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class GemmaTransformerLM(nn.Module):
    def __init__(self, decoder: TransformerDecoder) -> None:
        super().__init__()
        self.decoder = decoder

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # shape: [b, s, d]
        h = self.decoder(tokens, input_pos)
        # output projection is tied to the token embedding weights
        return F.linear(h, self.decoder.tok_embeddings.weight)
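For the mistral_classifier footnote above, TransformerClassifier could reuse exactly the same wrapper pattern. A sketch, assuming the head is an nn.Linear projecting to num_classes as in the existing classifier builder:

from typing import Optional

import torch.nn as nn
from torch import Tensor


class TransformerClassifier(nn.Module):
    """Sketch: same pattern as TransformerLM, with a classification head."""

    def __init__(self, decoder: TransformerDecoder, output: nn.Linear) -> None:
        super().__init__()
        self.decoder = decoder
        self.output = output  # e.g. nn.Linear(embed_dim, num_classes, bias=False)

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # shape: [b, s, d]
        h = self.decoder(tokens, input_pos)
        # shape: [b, s, num_classes]
        return self.output(h).float()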

Input on how this affects FSDP would be appreciated.

Components in the codebase I estimate will be impacted, and changes necessary, include:

  • Model state dict conversion mappings require changes: in torchtune.models.convert_weights, _FROM_META and _FROM_HF should prepend decoder. to destination keys, e.g. "model.layers.{}.self_attn.q_proj.weight": "decoder.layers.{}.attn.q_proj.weight" instead of "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight" (a sketch follows this list).**
  • All references, documentation, and type hints which expect a TransformerDecoder as a complete language-modelling transformer should now use TransformerLM. A quick search in VS Code shows ~100 references.
  • I think tests for TransformerDecoder should now test TransformerLM - input appreciated.
  • Additional tests for TransformerDecoder to test functionality without output projections should be added.
  • ... additional impact I may have missed.
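As a sketch of the convert_weights change above (the helper name is hypothetical, and it assumes the output projection stays a top-level output attribute on TransformerLM while everything else moves under decoder.):

from typing import Dict


def _prefix_destination_keys(mapping: Dict[str, str]) -> Dict[str, str]:
    """Hypothetical helper: prepend "decoder." to every destination key
    except the output projection, which remains on TransformerLM."""
    return {
        src: dst if dst.startswith("output.") else f"decoder.{dst}"
        for src, dst in mapping.items()
    }


# e.g. "model.layers.{}.self_attn.q_proj.weight" -> "decoder.layers.{}.attn.q_proj.weight"
_FROM_HF = _prefix_destination_keys(_FROM_HF)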

** A note on backwards compatibility

Users who have previously trained models with TransformerDecoder will have checkpoints saved with state dict keys in the original format (without the decoder prefix). Am I right in thinking they're going to have issues loading these checkpoints into our new models? This could be a pretty disruptive change - some users will have spent a lot of resources fine-tuning their models.

Could we provide some well-documented deprecation support for converting state dicts until some release version?
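One possible shape for that deprecation path (entirely hypothetical, not an existing Torchtune API) is a load-time shim that detects old-format keys, remaps them with a warning, and is removed after an announced release:

import warnings
from typing import Dict

from torch import Tensor


def _upgrade_legacy_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Hypothetical shim: remap pre-refactor checkpoint keys to the new
    TransformerLM layout by prefixing backbone keys with "decoder."."""
    if any(k.startswith("decoder.") for k in state_dict):
        return state_dict  # already in the new format
    warnings.warn(
        "This checkpoint predates the TransformerLM refactor; keys are being "
        "remapped. This shim will be removed in a future release.",
        FutureWarning,
    )
    return {
        (k if k.startswith("output.") else f"decoder.{k}"): v
        for k, v in state_dict.items()
    }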

SalmanMohammadi · May 24 '24 12:05