
GPTNeoX Flax support

Open · OhadRubin opened this issue 2 years ago

@sanchit-gandhi

OhadRubin avatar Apr 23 '23 16:04 OhadRubin


Hey @OhadRubin - sorry for the late reply here! How are you getting on with this PR? I see that a lot of the modelling code has already been implemented - happy to do a first pass of this code if you want a preliminary review. We can also look at adding a test file and make sure all the imports are properly defined (see https://huggingface.co/docs/transformers/add_new_model#stepbystep-recipe-to-add-a-model-to-transformers).
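
For reference, "making sure all the imports are properly defined" usually means registering the new Flax classes in the model's __init__.py behind an is_flax_available() guard. A rough sketch of that pattern, assuming the usual FlaxGPTNeoX* naming (the exact class list may differ in the PR):

    # src/transformers/models/gpt_neox/__init__.py - illustrative sketch only
    from typing import TYPE_CHECKING

    from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available

    _import_structure = {"configuration_gpt_neox": ["GPTNeoXConfig"]}

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Flax classes are only importable when jax/flax are installed
        _import_structure["modeling_flax_gpt_neox"] = [
            "FlaxGPTNeoXForCausalLM",
            "FlaxGPTNeoXModel",
            "FlaxGPTNeoXPreTrainedModel",
        ]

    if not TYPE_CHECKING:
        import sys

        sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)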

sanchit-gandhi avatar May 05 '23 15:05 sanchit-gandhi

Offer for a review still stands if you'd like me to take a look @OhadRubin!

sanchit-gandhi avatar May 30 '23 18:05 sanchit-gandhi

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jun 24 '23 15:06 github-actions[bot]

Leaving this one open to the community to complete! Feel free to take up the PR if you come across this and are interested in a Flax model addition. @OhadRubin has made a nice start on porting the model; you can use the Flax GPT Neo code as a reference for the fast attention mechanism we use in Transformers Flax: https://github.com/huggingface/transformers/blob/7d150d68ff6eaecc75b446aa06160b6bc8466e38/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py#L108

sanchit-gandhi avatar Jun 27 '23 17:06 sanchit-gandhi

So I suggest changing the __call__ method of FlaxGPTNeoXAttention to the version below:


    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):  
        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (num_heads * 3 * head_size)]
        qkv = self.query_key_value(hidden_states)
        batch, seq_len, _ = qkv.shape
        # [batch, seq_len, (num_heads * 3 * head_size)]
        #   --> [batch, seq_len, num_heads, 3, head_size]
        qkv = qkv.reshape([batch, seq_len, self.num_attention_heads, 3, self.head_size])
        # [batch, seq_len, num_heads, 3, head_size]
        #   --> [3, batch, seq_len, num_heads, head_size]
        qkv = jnp.moveaxis(qkv, source=-2, destination=0)
        # [3, batch, seq_len, num_heads, head_size]
        #   --> [3, batch, num_heads, seq_len, head_size]
        qkv = jnp.swapaxes(qkv, 3, 2)
        # [3, batch, num_heads, seq_len, head_size]
        #   --> 3 x [batch, num_heads, seq_len, head_size]
        query, key, value = qkv

        query_rot = query[..., : self.rotary_ndims]
        query_pass = query[..., self.rotary_ndims :]
        key_rot = key[..., : self.rotary_ndims]
        key_pass = key[..., self.rotary_ndims :]

        cos, sin = self.rotary_emb(value, seq_len=seq_len)
        query, key = apply_rotary_pos_embNP(query_rot, key_rot, cos, sin, position_ids)
        query = jnp.concatenate((query, query_pass), axis=-1)
        key = jnp.concatenate((key, key_pass), axis=-1)

        # revert swap: [batch, num_heads, seq_len, head_size] --> [batch, seq_len, num_heads, head_size]
        query, key, value = jnp.swapaxes(query, 1, 2), jnp.swapaxes(key, 1, 2), jnp.swapaxes(value, 1, 2)
        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
            )
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
            
        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
        )
        attn_weights = dot_product_attention_weights(
            query, #jnp.moveaxis(query, source=-3, destination=-2),
            key, #jnp.moveaxis(key, source=-3, destination=-2),
            bias=attention_bias,
            dropout_rng=None,
            # dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=jnp.promote_types(self.dtype, jnp.float32),
            precision=None,
        )
        attn_output = jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.dense(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs

HeegyuKim avatar Jul 12 '23 13:07 HeegyuKim

This code doesn't differ much from FlaxGPTNeoSelfAttention. Which part is the fast attention mechanism? @sanchit-gandhi

HeegyuKim avatar Jul 13 '23 13:07 HeegyuKim

The logic for constructing a static k/v cache and computing the attention weights efficiently is quite nicely summarised in the Flax GPT Neo attention layer: https://github.com/huggingface/transformers/blob/7d150d68ff6eaecc75b446aa06160b6bc8466e38/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py#L108
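
To make that concrete, here's a minimal standalone sketch of the cache update step (the names are illustrative, not the exact Transformers helpers): the key/value cache is pre-allocated at a fixed max_length, each decoding step writes its new slice in place with lax.dynamic_update_slice, and positions that haven't been written yet are masked out.

    import jax.numpy as jnp
    from jax import lax

    def update_kv_cache(cached_key, cached_value, cache_index, key, value):
        # cached_key / cached_value: [batch, max_length, num_heads, head_size], pre-allocated with zeros
        # key / value:               [batch, cur_len, num_heads, head_size], the new step(s)
        # cache_index:               number of positions already written
        cached_key = lax.dynamic_update_slice(cached_key, key, (0, cache_index, 0, 0))
        cached_value = lax.dynamic_update_slice(cached_value, value, (0, cache_index, 0, 0))
        new_index = cache_index + key.shape[1]
        # attention may only look at cache slots that have been filled so far
        max_length = cached_key.shape[1]
        pad_mask = jnp.arange(max_length) < new_index  # [max_length]; broadcast against the causal mask later
        return cached_key, cached_value, new_index, pad_mask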

We should strive to match this implementation as closely as possible (rather than optimising it again ourselves). It's largely inspired by the Flax attention implementation from T5x: https://github.com/google-research/t5x/blob/eb08ffbdec78e231aab1c747720ffb076f83bf18/t5x/examples/scalable_t5/layers.py#L196

This logic can be quite different from PyTorch attention layers, but it is much better suited to the static nature of Flax and leverages the Flax dot product attention call. It's great if the current code is by and large the same as the reference Flax GPT Neo code - that's a big green tick as far as I'm concerned!
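
One thing worth spelling out is the shape convention of that call: dot_product_attention_weights expects query/key laid out as [batch, length, num_heads, head_size] and returns weights as [batch, num_heads, q_length, kv_length], which is why the snippet above applies them to the values with the "bhqk,bkhd->bqhd" einsum. A small usage sketch (random values, shapes only):

    import jax
    import jax.numpy as jnp
    from flax.linen.attention import dot_product_attention_weights

    batch, q_len, kv_len, num_heads, head_size = 2, 5, 7, 4, 16
    rng = jax.random.PRNGKey(0)
    query = jax.random.normal(rng, (batch, q_len, num_heads, head_size))
    key = jax.random.normal(rng, (batch, kv_len, num_heads, head_size))
    value = jax.random.normal(rng, (batch, kv_len, num_heads, head_size))

    # weights come back as [batch, num_heads, q_len, kv_len]
    attn_weights = dot_product_attention_weights(query, key)
    attn_output = jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    assert attn_output.shape == (batch, q_len, num_heads, head_size)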

sanchit-gandhi avatar Jul 25 '23 13:07 sanchit-gandhi

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Nov 06 '23 08:11 github-actions[bot]