torchscale
Introducing padding_mask to RetNet
Unlike the other architectures in this package, RetNet doesn't support padding as far as I'm aware. I think the best place to introduce it is alongside the positional (decay) mask. Here we don't have the luxury of a softmax, so we can't simply fill the padded positions with -inf: the decay mask multiplies the retention scores directly, so its entries at padded positions must be exactly 0.
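To make the difference concrete, here is a toy comparison for a single query. It assumes the log-decay is negative, which the `exp(mask * self.decay)` construction relies on. The decay rate `gamma = 0.9` and the scores are arbitrary illustrative values, not torchscale defaults:

```python
import math
import torch

gamma = 0.9  # illustrative decay rate, not a torchscale value
log_gamma = math.log(gamma)  # negative, playing the role of self.decay

scores = torch.tensor([2.0, 1.0, 3.0])  # q_i . k_j for one query i = 2
pad = torch.tensor([True, False, False])  # key j = 0 is padding

# Attention: softmax maps a -inf logit to an exact 0 weight.
attn = torch.softmax(scores.masked_fill(pad, float("-inf")), dim=-1)

# Retention: no softmax; scores are multiplied by gamma ** (i - j) directly,
# so the decay factor itself must be 0 at padded keys.
dist = torch.tensor([2.0, 1.0, 0.0])  # i - j for each key
# +inf exponent times a negative log-decay is -inf, and exp(-inf) = 0
decay = torch.exp(dist.masked_fill(pad, float("inf")) * log_gamma)
retention = scores * decay
```

Filling the exponent with +inf (rather than the decay itself with -inf) is what lets the padded entries come out as exact zeros after `exp`.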
From my attempt, the parallel code would look something like the following (assuming left padding and a boolean padding_mask of shape (bsz, seq_len) that is True at padded positions):
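```python
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
mask = torch.tril(torch.ones(slen, slen).to(self.decay))
mask = torch.masked_fill(
    index[:, None] - index[None, :], ~mask.bool(), float("inf")
)
# Mask padded *keys* (columns): (1, slen, slen) -> (bsz, slen, slen).
# With left padding, padded query rows then have no visible key either.
mask = torch.masked_fill(mask.unsqueeze(0), padding_mask.unsqueeze(1), float("inf"))
# inf * decay (log-decay < 0) = -inf, so exp() yields 0 at masked entries.
mask = torch.exp(mask.unsqueeze(1) * self.decay[:, None, None])  # (bsz, num_heads, slen, slen)
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
# Fully masked rows give 0 / 0 = nan; map them back to 0.
mask = torch.nan_to_num(mask)
retention_rel_pos = ((sin, cos), mask)
```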
This means the mask is expanded to (bsz, num_heads, slen, slen) here, instead of being broadcast over the batch dimension in the forward method.
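A self-contained way to check the parallel mask is to build it as a free function and compare a left-padded sequence against the same sequence without padding. The helper name `build_padded_decay_mask` is hypothetical, the rotary sin/cos part is omitted, and the log-decay values in the test are illustrative. The sketch masks key columns, so real queries never see padded keys:

```python
import torch


def build_padded_decay_mask(decay: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: normalized retention decay mask with key padding.

    decay: (num_heads,) log-decay values, all < 0.
    padding_mask: (bsz, slen) bool, True at padded positions.
    Returns a (bsz, num_heads, slen, slen) mask.
    """
    bsz, slen = padding_mask.shape
    index = torch.arange(slen).to(decay)
    causal = torch.tril(torch.ones(slen, slen, dtype=torch.bool))
    # Relative distance i - j, with future keys pushed to +inf.
    dist = (index[:, None] - index[None, :]).masked_fill(~causal, float("inf"))
    # Push padded keys (columns) to +inf as well, per batch element.
    dist = dist.expand(bsz, -1, -1).masked_fill(padding_mask.unsqueeze(1), float("inf"))
    mask = torch.exp(dist.unsqueeze(1) * decay[:, None, None])
    # Rows with no visible key (padded queries under left padding) give nan here.
    mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
    return torch.nan_to_num(mask)
```

Because the decay only depends on i - j, the rows of the real tokens in the padded mask should match the unpadded mask exactly, and every padded row and column should be 0.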
For the recurrent formulation, would masking the scaling factor accordingly work? Something like this, where padding_mask is the (bsz,) slice for the current step:
```python
def recurrent_forward(
    self, qr, kr, v, decay, padding_mask=None, incremental_state=None
):
    bsz = v.size(0)
    v = v.view(bsz, self.num_heads, self.head_dim, 1)
    kv = kr * v
    if "prev_key_value" in incremental_state:
        prev_kv = incremental_state["prev_key_value"]
        prev_scale = incremental_state["scale"]  # (bsz, num_heads)
        scale = prev_scale * decay + 1
        kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(
            bsz, self.num_heads, 1, 1
        ) + kv / scale.sqrt().view(bsz, self.num_heads, 1, 1)
        # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
    else:
        # Keep a per-sample scale of shape (bsz, num_heads) from the first step on.
        scale = torch.ones_like(decay).expand(bsz, -1)
    incremental_state["prev_key_value"] = kv
    if padding_mask is not None:
        # A zero scale at a padded step makes the next step drop prev_kv,
        # because prev_scale.sqrt() == 0 multiplies it.
        scale = scale.masked_fill(padding_mask.view(bsz, 1), 0)
    incremental_state["scale"] = scale
    output = torch.sum(qr * kv, dim=3)
    return output
```
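One way to sanity-check the scale-masking idea is to compare it with the parallel form on a toy retention. This sketch is a simplification, not torchscale's module: a single head, no rotary/xPos, no group norm or gating, and no extra row normalization of the score matrix. The function names are hypothetical. The recurrent loop uses the same update rule as `recurrent_forward` with the head dimension dropped:

```python
import math
import torch


def parallel_retention(q, k, v, log_gamma, padding_mask):
    # q, k: (bsz, slen, dk); v: (bsz, slen, dv); padding_mask: (bsz, slen) bool
    bsz, slen, _ = q.shape
    index = torch.arange(slen, dtype=q.dtype)
    causal = torch.tril(torch.ones(slen, slen, dtype=torch.bool))
    dist = (index[:, None] - index[None, :]).masked_fill(~causal, float("inf"))
    dist = dist.expand(bsz, -1, -1).masked_fill(padding_mask.unsqueeze(1), float("inf"))
    mask = torch.exp(dist * log_gamma)
    mask = torch.nan_to_num(mask / mask.sum(dim=-1, keepdim=True).sqrt())
    return (q @ k.transpose(-1, -2) * mask) @ v  # (bsz, slen, dv)


def recurrent_retention(q, k, v, log_gamma, padding_mask):
    bsz, slen, _ = q.shape
    gamma = math.exp(log_gamma)
    outputs, prev_kv, prev_scale = [], None, None
    for t in range(slen):
        kv = v[:, t, :, None] * k[:, t, None, :]  # (bsz, dv, dk)
        if prev_kv is None:
            scale = torch.ones(bsz, dtype=q.dtype)
        else:
            scale = prev_scale * gamma + 1
            kv = prev_kv * (prev_scale.sqrt() * gamma / scale.sqrt()).view(bsz, 1, 1) \
                + kv / scale.sqrt().view(bsz, 1, 1)
        # Padded step: zero the scale so the next step discards prev_kv.
        scale = scale.masked_fill(padding_mask[:, t], 0)
        prev_kv, prev_scale = kv, scale
        outputs.append(kv @ q[:, t, :, None])  # (bsz, dv, 1)
    return torch.cat(outputs, dim=-1).transpose(1, 2)  # (bsz, slen, dv)
```

With left padding the two should agree at every real position, while the recurrent outputs at padded positions are meaningless and should simply be ignored. Under these simplifications, masking the scale is enough to restart the state at the first real token.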
I'd appreciate some help with this. Perhaps the authors have a better approach? @donglixp @sunyt32