[RFC] TransformerDecoder refactor
Authors:
- @SalmanMohammadi with input from:
- @kartikayk
- @ebsmothers
- @pbontrager
## Summary

Refactoring `TransformerDecoder` to offer additional flexibility for new use-cases.
## Motivation/Prior art
- https://github.com/pytorch/torchtune/issues/968 - not sure if this is in scope.
- https://github.com/pytorch/torchtune/pull/840
Currently, `TransformerDecoder` can only be used for language-modelling tasks. There is interest in additional use-cases, such as:

- Classifier models akin to HF's `AutoModelForSequenceClassification`
- Value head models akin to TRL's `AutoModelForCausalLMWithValueHead`
- Reward models with custom initialisation as in lm_human_preference_details
- Using arbitrary hidden states from the transformer backbone
- ...
Such a refactor could allow users to easily adapt a transformer backbone to a variety of downstream tasks; lm_human_preference_details demonstrates how HF's transformer backbone can be extended in just 8 lines. While this refactor initially targets recipes which will be provided within torchtune, such as PPO or sequence-classification training recipes (e.g. for reward models), it would also allow users to write custom recipes for many fine-tuning tasks whilst utilising underlying torchtune features.
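To make the kind of extension concrete, here is a minimal sketch of a value-head model built on top of the proposed backbone; the class name, constructor arguments, and the assumption that the backbone returns final hidden states of shape `[b, s, d]` are illustrative only, not part of the draft PR:

```python
from typing import Optional

from torch import nn, Tensor


class TransformerWithValueHead(nn.Module):
    """Hypothetical value-head model for PPO-style recipes."""

    def __init__(self, decoder: nn.Module, embed_dim: int) -> None:
        super().__init__()
        self.decoder = decoder
        # one scalar value per token position
        self.value_head = nn.Linear(embed_dim, 1)

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # backbone returns the final hidden state, shape [b, s, d]
        h = self.decoder(tokens, input_pos)
        # shape [b, s, 1] -> [b, s]
        return self.value_head(h).squeeze(-1)
```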
## Proposed Implementation

A small-scale implementation for Mistral models exists in this draft PR. In summary:

- `TransformerDecoder` will refer to the underlying transformer backbone, agnostic to its downstream task. By default it will return the final hidden state as an output of shape `[batch_size, sequence_len, embed_dim]`.
- `TransformerDecoder` could support returning hidden states from arbitrary layers (or other useful outputs). Some input on how we allow users to specify this would be helpful; we probably just want to return the last hidden state by default. One possible shape for this API is sketched after the `TransformerLM` definition below.
- We define `TransformerLM` as follows:
```python
class TransformerLM(nn.Module):
    def __init__(self, decoder: TransformerDecoder, output: nn.Linear) -> None:
        super().__init__()
        self.decoder = decoder
        self.output = output

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # shape: [b, s, d]
        h = self.decoder(tokens, input_pos)
        # shape: [b, s, v]
        lm_output = self.output(h).float()
        return lm_output
```
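On the question of exposing hidden states from arbitrary layers, one possible shape for the API - purely a sketch, where `output_hidden_states` is a hypothetical constructor argument and mask handling is elided - is to collect the requested layer outputs during the backbone's forward pass:

```python
from typing import List, Optional, Union

from torch import Tensor


# Hypothetical: TransformerDecoder.__init__ takes
# output_hidden_states: Optional[List[int]] = None and stores it on self.
def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Union[Tensor, List[Tensor]]:
    # shape: [b, s, d]
    h = self.tok_embeddings(tokens)
    hidden_states = []
    for i, layer in enumerate(self.layers):
        h = layer(h, input_pos=input_pos)  # mask handling elided
        if self.output_hidden_states and i in self.output_hidden_states:
            hidden_states.append(h)
    h = self.norm(h)
    if not self.output_hidden_states:
        # default behaviour: just the final (normed) hidden state
        return h
    return hidden_states + [h]
```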
- Model builders and component builders* now return `TransformerLM` instead of `TransformerDecoder`. Component builders look like:
```python
decoder = TransformerDecoder(
    tok_embeddings=tok_embeddings,
    layer=layer,
    num_layers=num_layers,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    head_dim=head_dim,
    norm=RMSNorm(embed_dim, eps=norm_eps),
)
output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
return TransformerLM(decoder=decoder, output=output_proj)
```
\* `mistral_classifier` should now return an instance of `TransformerClassifier`.
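For concreteness, a minimal sketch of what `TransformerClassifier` could look like under this proposal (the exact name, constructor signature, and per-token output shape are assumptions):

```python
class TransformerClassifier(nn.Module):
    def __init__(self, decoder: TransformerDecoder, embed_dim: int, num_classes: int) -> None:
        super().__init__()
        self.decoder = decoder
        self.output = nn.Linear(embed_dim, num_classes, bias=False)

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # shape: [b, s, d]
        h = self.decoder(tokens, input_pos)
        # per-token class scores, shape: [b, s, num_classes]
        return self.output(h)
```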
\* Gemma models define a `GemmaTransformerDecoder` which has a unique output projection, but shares the underlying logic of a `TransformerDecoder`. We can go two routes here:

- `TransformerLM` accepts a `Union[nn.Module, Callable[[Tensor], Tensor]]` (or even just `Callable[[Tensor], Tensor]`) as `output`. Then, the Gemma component builder is:
```python
tok_embeddings = nn.Embedding(...)
decoder = TransformerDecoder(...)
output = lambda a: F.linear(a, tok_embeddings.weight)
return TransformerLM(decoder, output=output)
```
- If that looks a bit clunky/we don't like anonymous functions, we can just define a `GemmaTransformerLM` which looks like:
```python
class GemmaTransformerLM(nn.Module):
    def __init__(self, decoder: TransformerDecoder) -> None:
        super().__init__()
        self.decoder = decoder

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # shape: [b, s, d]
        h = self.decoder(tokens, input_pos)
        return F.linear(h, self.decoder.tok_embeddings.weight)
```
Input on how this affects FSDP would be appreciated.
Components in the codebase I estimate will be impacted, and the changes necessary, include:

- Model state dict conversion mappings require changes: in `torchtune/models/convert_weights.py`, `_FROM_META` and `_FROM_HF` should prepend `decoder` to destination keys, e.g. `"model.layers.{}.self_attn.q_proj.weight": "decoder.layers.{}.attn.q_proj.weight"` instead of `"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight"`.**
- All references, documentation, and type hints which expect a `TransformerDecoder` as a complete language-modelling transformer should now use `TransformerLM`. A quick search in VS Code shows ~100 references.
- I think tests for `TransformerDecoder` should now test `TransformerLM` - input appreciated.
- Additional tests for `TransformerDecoder` covering its use without an output projection should be added (a sketch of one such test follows this list).
- ... additional impact I may have missed.
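As a sketch of that last testing point (the dims are toy values and the construction of the backbone is elided):

```python
import torch


def test_transformer_decoder_returns_hidden_states():
    batch_size, seq_len, vocab_size, embed_dim = 2, 8, 32, 16
    decoder = TransformerDecoder(...)  # elided: a small backbone built from the toy dims above
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
    h = decoder(tokens)
    # the backbone should return hidden states, not vocab logits
    assert h.shape == (batch_size, seq_len, embed_dim)
```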
\*\* **A note on backwards compatibility**

Users who have previously trained models with `TransformerDecoder` will have checkpoints saved with state dict keys in the original format (without the `decoder` prefix). Am I right in thinking they're going to have issues loading these checkpoints into our new models? This could be a pretty disruptive change - some users will have spent a lot of resources fine-tuning their models.
Could we provide some well-documented deprecation support for converting state dicts until some release version?
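One possible shape for that deprecation path - just a sketch, with `convert_legacy_state_dict` as a hypothetical helper rather than a committed API - is a key-remapping utility applied when loading a pre-refactor checkpoint:

```python
from typing import Dict

import torch


def convert_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remap pre-refactor TransformerDecoder keys to the proposed TransformerLM layout."""
    converted = {}
    for key, value in state_dict.items():
        if key.startswith("output."):
            # the LM head keeps its name on the TransformerLM wrapper
            converted[key] = value
        else:
            # everything else now lives under the `decoder` submodule
            converted[f"decoder.{key}"] = value
    return converted
```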