x-transformers
ALiBi: buffered bias slicing gets confusing when `i != j`
https://github.com/lucidrains/x-transformers/blame/f71f3279de539ddb0f58bd50f22b84b6920e0ef6/x_transformers/x_transformers.py#L336
Hi! I've noticed that slicing a buffered ALiBi bias can get confusing when `i != j`. Say we cache the bias when `j - i = 1`, so the zeros sit on the diagonal one above the main diagonal. If we then process a smaller sequence where `i = j`, the zeros should be on the main diagonal, but because the bias is cached they are still one element above. This works fine because the ALiBi encoding is shift invariant, but wouldn't it be clearer to always place the zeros on the main diagonal, as in the `i == j` case?
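To illustrate why the stale slice is harmless under a causal mask, here is a minimal numpy sketch (my own reconstruction, not the library's code; `alibi_bias` and the single `slope` are simplified assumptions). Under causal masking, the cached-and-sliced bias differs from a freshly computed one by a constant per row, and softmax is invariant to per-row constants:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def alibi_bias(i, j, slope=0.5):
    # Hypothetical sketch of the buffered bias: queries occupy the last
    # i of j positions, so the zeros land on the diagonal offset by (j - i).
    q_pos = np.arange(j - i, j)
    k_pos = np.arange(j)
    return -slope * np.abs(k_pos[None, :] - q_pos[:, None])

n = 4
cached = alibi_bias(n, n + 1)   # built when j - i == 1; zeros one above the diagonal
stale = cached[:n, :n]          # reused for an i == j call; zeros still offset by one
fresh = alibi_bias(n, n)        # recomputed; zeros on the main diagonal

rng = np.random.default_rng(0)
logits = rng.normal(size=(n, n))
causal = np.tril(np.ones((n, n), dtype=bool))
mask = np.where(causal, 0.0, -1e9)  # causal attention mask

# In the visible (lower-triangular) region, stale and fresh differ by the
# same constant on every entry of a row, so softmax over keys is identical.
a_stale = softmax(logits + stale + mask)
a_fresh = softmax(logits + fresh + mask)
print(np.allclose(a_stale, a_fresh))  # True: attention weights match
print(np.allclose(stale, fresh))      # False: the bias matrices themselves differ
```

So the cached slice produces the exact same attention weights even though the bias matrix itself is visibly "off by one", which is why the current behavior is correct but reads confusingly.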