GPTNeoX Flax support
@sanchit-gandhi
Hey @OhadRubin - sorry for the late reply here! How are you getting on with this PR? I see that a lot of the modelling code has already been implemented - I'm happy to do a first pass of this code if you'd like a preliminary review. We can also look at adding a test file and make sure all the imports are properly defined (see https://huggingface.co/docs/transformers/add_new_model#stepbystep-recipe-to-add-a-model-to-transformers)
Offer for a review still stands if you'd like me to take a look @OhadRubin!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Leaving this one open to the community to complete! Feel free to take up the PR if you come across this and are interested in a Flax model addition. @OhadRubin has made a nice start on porting the model; you can use the Flax GPT Neo code as a reference for the fast attention mechanism we use in Transformers Flax: https://github.com/huggingface/transformers/blob/7d150d68ff6eaecc75b446aa06160b6bc8466e38/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py#L108
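For context, the end goal of that fast attention path is cached autoregressive generation. Here is a minimal sketch of what that looks like from the user side, using Flax GPT Neo as the stand-in - the eventual FlaxGPTNeoXForCausalLM would be expected to expose the same generate/init_cache API, but that is an assumption until the port lands:

from transformers import AutoTokenizer, FlaxGPTNeoForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")

inputs = tokenizer("Hello, my name is", return_tensors="np")

# generate() pre-allocates the static key/value cache via model.init_cache(...)
# and then decodes one token per step through the cached attention path
output_ids = model.generate(inputs["input_ids"], max_length=20, do_sample=False).sequences
print(tokenizer.decode(output_ids[0]))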
So I suggest changing the __call__ method of FlaxGPTNeoXAttention to the following:
def __call__(
    self,
    hidden_states,
    attention_mask,
    position_ids,
    deterministic: bool = True,
    init_cache: bool = False,
    output_attentions: bool = False,
):
    # Compute QKV
    # Attention heads [batch, seq_len, hidden_size]
    #   --> [batch, seq_len, (num_heads * 3 * head_size)]
    qkv = self.query_key_value(hidden_states)
    batch, seq_len, _ = qkv.shape

    # [batch, seq_len, (num_heads * 3 * head_size)]
    #   --> [batch, seq_len, num_heads, 3, head_size]
    qkv = qkv.reshape([batch, seq_len, self.num_attention_heads, 3, self.head_size])
    # [batch, seq_len, num_heads, 3, head_size]
    #   --> [3, batch, seq_len, num_heads, head_size]
    qkv = jnp.moveaxis(qkv, source=-2, destination=0)
    # [3, batch, seq_len, num_heads, head_size]
    #   --> [3, batch, num_heads, seq_len, head_size]
    qkv = jnp.swapaxes(qkv, 3, 2)
    # [3, batch, num_heads, seq_len, head_size]
    #   --> 3 x [batch, num_heads, seq_len, head_size]
    query, key, value = qkv

    # split the head dim into the rotary part and the pass-through part
    query_rot = query[..., : self.rotary_ndims]
    query_pass = query[..., self.rotary_ndims :]
    key_rot = key[..., : self.rotary_ndims]
    key_pass = key[..., self.rotary_ndims :]

    # apply rotary position embeddings to the rotary part only
    cos, sin = self.rotary_emb(value, seq_len=seq_len)
    query, key = apply_rotary_pos_embNP(query_rot, key_rot, cos, sin, position_ids)
    query = jnp.concatenate((query, query_pass), axis=-1)
    key = jnp.concatenate((key, key_pass), axis=-1)

    # revert swap: back to [batch, seq_len, num_heads, head_size] for the Flax attention call
    query, key, value = (
        jnp.swapaxes(query, 1, 2),
        jnp.swapaxes(key, 1, 2),
        jnp.swapaxes(value, 1, 2),
    )

    query_length, key_length = query.shape[1], key.shape[1]

    if self.has_variable("cache", "cached_key"):
        # fast decoding: slice the static causal mask at the current cache index
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
        )
    else:
        causal_mask = self.causal_mask[:, :, :query_length, :key_length]

    batch_size = hidden_states.shape[0]
    causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
    attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
    attention_mask = combine_masks(attention_mask, causal_mask)

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.has_variable("cache", "cached_key") or init_cache:
        key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

    # transform boolean mask into float mask
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
    )

    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=None,
        # dropout_rate=self.config.attn_pdrop,
        deterministic=deterministic,
        dtype=jnp.promote_types(self.dtype, jnp.float32),
        precision=None,
    )

    attn_output = jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.dense(attn_output)

    outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
    return outputs
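For reference, the attributes used above would need to be defined in setup(). Here is a rough sketch of what that could look like, mirroring the PyTorch GPTNeoXAttention and the Flax GPT Neo conventions - FlaxGPTNeoXRotaryEmbedding is a placeholder for the rotary module that still needs porting, and the _merge_heads / _concatenate_to_cache helpers would be copied over from Flax GPT Neo:

import flax.linen as nn
import jax.numpy as jnp
from flax.linen import make_causal_mask
from transformers import GPTNeoXConfig


class FlaxGPTNeoXAttention(nn.Module):
    config: GPTNeoXConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        config = self.config
        self.num_attention_heads = config.num_attention_heads
        self.head_size = config.hidden_size // config.num_attention_heads
        self.rotary_ndims = int(self.head_size * config.rotary_pct)

        # fused QKV projection: hidden_size --> 3 * hidden_size
        self.query_key_value = nn.Dense(3 * config.hidden_size, dtype=self.dtype)
        self.dense = nn.Dense(config.hidden_size, dtype=self.dtype)

        # placeholder: rotary module still to be ported from the PyTorch GPTNeoXRotaryEmbedding
        self.rotary_emb = FlaxGPTNeoXRotaryEmbedding(
            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base, dtype=self.dtype
        )

        # static causal mask, sliced dynamically in __call__ when the cache is in use
        self.causal_mask = make_causal_mask(
            jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool"
        )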
This code doesn't differ much from FlaxGPTNeoSelfAttention. Which part is the fast attention mechanism? @sanchit-gandhi
The logic for constructing a static k/v cache and computing the attention weights efficiently is quite nicely summarised in the Flax GPT Neo attention layer: https://github.com/huggingface/transformers/blob/7d150d68ff6eaecc75b446aa06160b6bc8466e38/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py#L108
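Concretely, the cache update in that layer looks roughly like this (paraphrased from the linked file - defer to the source for the exact code; it lives inside the attention module alongside __call__ and uses jnp, lax, and flax.linen.combine_masks):

@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
    # detect whether this is the init_cache pass or an update of an existing cache
    is_initialized = self.has_variable("cache", "cached_key")
    cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
    cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
    cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

    if is_initialized:
        *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
        # write the new key/value slices into the pre-allocated cache at the current index
        cur_index = cache_index.value
        indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
        key = lax.dynamic_update_slice(cached_key.value, key, indices)
        value = lax.dynamic_update_slice(cached_value.value, value, indices)
        cached_key.value = key
        cached_value.value = value
        num_updated_cache_vectors = query.shape[1]
        cache_index.value = cache_index.value + num_updated_cache_vectors
        # mask out cache positions that haven't been written to yet
        pad_mask = jnp.broadcast_to(
            jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
            tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
        )
        attention_mask = combine_masks(pad_mask, attention_mask)
    return key, value, attention_mask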
We should strive to match this implementation as closely as possible (rather than optimising it again ourselves). It's largely inspired by the Flax attention implementation from T5x: https://github.com/google-research/t5x/blob/eb08ffbdec78e231aab1c747720ffb076f83bf18/t5x/examples/scalable_t5/layers.py#L196
This logic can be quite different from PyTorch attention layers, but it is much better suited to the static nature of Flax and leverages the Flax dot product attention call. It's great if the current code is by and large the same as the reference Flax GPT Neo code; that's a big green tick as far as I'm concerned!
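As a concrete pointer on the shapes: dot_product_attention_weights takes query/key as [batch, length, num_heads, head_dim] (hence the "revert swap" in the suggested __call__) and returns weights of shape [batch, num_heads, q_length, kv_length], which is exactly what the bhqk einsum consumes. A quick check:

import jax.numpy as jnp
from flax.linen.attention import dot_product_attention_weights

batch, q_len, kv_len, num_heads, head_dim = 2, 5, 5, 4, 8
query = jnp.ones((batch, q_len, num_heads, head_dim))
key = jnp.ones((batch, kv_len, num_heads, head_dim))

# weights come back as [batch, num_heads, q_length, kv_length]
weights = dot_product_attention_weights(query, key)
print(weights.shape)  # (2, 4, 5, 5)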