ALiBi error when log2(num_heads) is not an integer (i.e., num_heads is not a power of 2)
Hugging Face ALiBi reference: https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/bloom/modeling_bloom.py#L82
AXLearn ALiBi references: https://github.com/apple/axlearn/blob/b92f666f661e6bacd757d7a37f1691d4f8985655/axlearn/common/adapter_torch.py#L561 https://github.com/apple/axlearn/blob/b92f666f661e6bacd757d7a37f1691d4f8985655/axlearn/common/attention.py#L3536
Current MLX implementation: https://github.com/ml-explore/mlx/blob/44c1ce5e6af2625571cd384e5be49e9778770ffc/python/mlx/nn/layers/positional_encoding.py#L184
Candidate change:
class ALiBi(Module):
    """Attention with Linear Biases (ALiBi).

    Adds a per-head, distance-proportional penalty to attention scores. The
    per-head slopes follow the reference ALiBi implementation, including its
    interpolation scheme for head counts that are not a power of two.
    """

    @staticmethod
    def create_alibi_matrix(
        q_sequence_length: int,
        k_sequence_length: int,
        num_heads: int,
        offset: int,
        dtype=mx.float32,
    ):
        """Build the ALiBi bias for every head.

        Args:
            q_sequence_length: Exclusive end of the query positions; the query
                positions used are ``offset, ..., q_sequence_length - 1``.
            k_sequence_length: Number of key positions (``0 .. k_sequence_length - 1``).
            num_heads: Number of attention heads; must be positive.
            offset: Position of the first query.
            dtype: dtype of the returned bias.

        Returns:
            An array of shape ``(1, num_heads, q_len, k_len)`` whose entry
            ``[0, h, i, j]`` is ``-|q_pos_i - k_pos_j| * slope[h]``.
        """
        x1 = mx.arange(offset, q_sequence_length)
        x2 = mx.arange(0, k_sequence_length)
        # (1, 1, q_len, k_len) matrix of non-positive distances.
        distance_matrix = -mx.abs(
            mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
        )
        # Keep the slopes in float32 and cast only the final product: casting
        # the slopes to a low-precision dtype first would round them before
        # the multiply and then round the result a second time.
        alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads, dtype=mx.float32)
        # (1, 1, q, k) * (num_heads, 1, 1) broadcasts to (1, num_heads, q, k).
        alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
        return alibi_mask

    @staticmethod
    def alibi_get_slopes(num_heads: int):
        """Get the slopes for the attention heads as defined in the ALiBi paper.

        This mirrors the reference ALiBi codebase.
        Ref:
        https://github.com/ofirpress/attention_with_linear_biases/tree/3b7c2eca/fairseq/models/transformer.py#L742-L752

        Args:
            num_heads: Number of attention heads; must be a positive integer.

        Returns:
            A Python list of ``num_heads`` floats; element ``h`` is the slope
            for attention head ``h``.

        Raises:
            ValueError: If ``num_heads`` is smaller than 1.
        """
        if num_heads < 1:
            raise ValueError(
                f"num_heads must be a positive integer, got {num_heads}"
            )

        def get_slopes_power_of_2(n: int):
            # Geometric sequence whose first term and ratio are both
            # 2 ** (-8 / n).
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(num_heads).is_integer():
            return get_slopes_power_of_2(num_heads)

        # Not a power of two: use the slopes for the largest power of two
        # below num_heads, then fill the remaining heads with every other
        # slope from the sequence for twice that power of two.
        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
        interleaved = ALiBi.alibi_get_slopes(2 * closest_power_of_2)[0::2]
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + interleaved[: num_heads - closest_power_of_2]
        )

    @staticmethod
    def create_alibi_slope(num_heads, dtype):
        """Return the per-head slopes as an array of shape ``(num_heads, 1, 1)``.

        The trailing singleton axes let the slopes broadcast against a
        ``(..., q_len, k_len)`` distance matrix.
        """
        slopes = ALiBi.alibi_get_slopes(num_heads)
        out = mx.array(slopes, dtype=dtype)
        return mx.expand_dims(out, axis=(-1, -2))

    def __call__(self, attention_scores, offset=0, mask=None):
        """Add the ALiBi bias (and an optional mask) to ``attention_scores``.

        The head count is read from ``attention_scores.shape[1]`` and the
        query/key lengths from the last two axes, so the scores are expected
        to be laid out as ``(batch, num_heads, q_len, k_len)`` — NOTE(review):
        confirm callers always pass rank-4 scores.

        Args:
            attention_scores: Raw attention logits.
            offset: Position of the first query; queries span
                ``offset .. offset + q_len - 1``.
            mask: Optional additive mask, summed with the ALiBi bias.

        Returns:
            ``attention_scores`` plus the bias, in the scores' dtype.
        """
        alibi_mask = ALiBi.create_alibi_matrix(
            q_sequence_length=attention_scores.shape[-2] + offset,
            k_sequence_length=attention_scores.shape[-1],
            num_heads=attention_scores.shape[1],
            offset=offset,
            dtype=attention_scores.dtype,
        )
        if mask is not None:
            alibi_mask = alibi_mask + mask
        return attention_scores + alibi_mask
Are you saying that it gives the wrong result when num_heads is not a power of 2? Do you want to send a PR?