mlx

ALiBi slopes are incorrect when log2(num_heads) is not an integer (i.e. num_heads is not a power of 2)

Open · gyin94 opened this issue 1 year ago · 4 comments

Hugging Face (BLOOM) ALiBi reference: https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/bloom/modeling_bloom.py#L82

AXLearn ALiBi references:
- https://github.com/apple/axlearn/blob/b92f666f661e6bacd757d7a37f1691d4f8985655/axlearn/common/adapter_torch.py#L561
- https://github.com/apple/axlearn/blob/b92f666f661e6bacd757d7a37f1691d4f8985655/axlearn/common/attention.py#L3536

Current MLX implementation: https://github.com/ml-explore/mlx/blob/44c1ce5e6af2625571cd384e5be49e9778770ffc/python/mlx/nn/layers/positional_encoding.py#L184

— gyin94, Jan 02 '24 09:01

Candidate change (per-head slopes computed as in the original ALiBi codebase):

class ALiBi(Module):
    """Attention with Linear Biases (ALiBi).

    Adds a per-head linear distance penalty ``-|q_pos - k_pos| * slope[head]``
    to attention scores. The per-head slopes follow the schedule of the
    reference ALiBi codebase, which also defines slopes for head counts that
    are not a power of two.
    """

    @staticmethod
    def create_alibi_matrix(
        q_sequence_length: int,
        k_sequence_length: int,
        num_heads: int,
        offset: int,
        dtype=mx.float32,
    ):
        """Build the additive ALiBi bias.

        Args:
            q_sequence_length: End (exclusive) of the query positions; query
                positions are ``offset, ..., q_sequence_length - 1``.
            k_sequence_length: Number of key positions ``0, ..., k_sequence_length - 1``.
            num_heads: Number of attention heads (one slope per head).
            offset: Position of the first query.
            dtype: dtype of the returned bias.

        Returns:
            An array of shape
            ``(1, num_heads, q_sequence_length - offset, k_sequence_length)``
            holding ``-|q_pos - k_pos| * slope[head]``.
        """
        x1 = mx.arange(offset, q_sequence_length)
        x2 = mx.arange(0, k_sequence_length)
        # (1, 1, Q, K) matrix of non-positive query/key distances.
        distance_matrix = -mx.abs(
            mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
        )
        # Keep the slopes in float32 and cast only the final product. Rounding
        # the slopes to a low-precision dtype first would amplify that rounding
        # error by the (potentially large) distances.
        alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads, dtype=mx.float32)
        # (num_heads, 1, 1) slopes broadcast against (1, 1, Q, K) distances.
        alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
        return alibi_mask

    @staticmethod
    def alibi_get_slopes(num_heads: int):
        """Get the slopes for different attention heads defined in the ALiBi paper.

        This follows the ALiBi codebase.
        Ref:
        https://github.com/ofirpress/attention_with_linear_biases/tree/3b7c2eca/fairseq/models/transformer.py#L742-L752

        Args:
            num_heads: A positive integer for the number of attention heads.

        Returns:
            A list of ``num_heads`` floats. Each value is the slope for one
            attention head.

        Raises:
            ValueError: If ``num_heads`` is not positive.
        """
        # math.log2 would otherwise fail with an opaque "math domain error".
        if num_heads < 1:
            raise ValueError(f"num_heads must be a positive integer, got {num_heads}")

        def get_slopes_power_of_2(n: int):
            # Geometric sequence start**1, ..., start**n with start = 2**(-8/n).
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(num_heads).is_integer():
            return get_slopes_power_of_2(num_heads)

        # Not a power of two: take the slopes for the largest power of two
        # below num_heads, then fill the remaining heads with every other slope
        # of the next power of two. 2 * closest_power_of_2 is itself a power of
        # two, so the recursion is at most one level deep.
        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
        extra_slopes = ALiBi.alibi_get_slopes(2 * closest_power_of_2)[0::2]
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + extra_slopes[: num_heads - closest_power_of_2]
        )

    @staticmethod
    def create_alibi_slope(num_heads, dtype):
        """Return the per-head slopes as an array of shape ``(num_heads, 1, 1)``."""
        slopes = ALiBi.alibi_get_slopes(num_heads)
        out = mx.array(slopes, dtype=dtype)
        return mx.expand_dims(out, axis=(-1, -2))

    def __call__(self, attention_scores, offset=0, mask=None):
        """Add the ALiBi bias (and an optional additive mask) to attention scores.

        The head count is read from ``attention_scores.shape[1]``, the query
        length from ``shape[-2]`` and the key length from ``shape[-1]``
        (presumably a ``(batch, heads, q_len, k_len)`` layout — confirm with
        callers).

        Args:
            attention_scores: Raw attention logits.
            offset: Position of the first query (e.g. the cached-prefix length).
            mask: Optional additive mask, summed with the ALiBi bias.

        Returns:
            ``attention_scores + alibi_bias (+ mask)``.
        """
        alibi_mask = ALiBi.create_alibi_matrix(
            q_sequence_length=attention_scores.shape[-2] + offset,
            k_sequence_length=attention_scores.shape[-1],
            num_heads=attention_scores.shape[1],
            offset=offset,
            dtype=attention_scores.dtype,
        )
        if mask is not None:
            alibi_mask = alibi_mask + mask
        return attention_scores + alibi_mask

— gyin94, Jan 02 '24 10:01

Are you saying that it gives the wrong result when num_heads is not a power of 2? Do you want to send a PR?

— awni, Jan 02 '24 14:01