MoChA-pytorch
implementation of `safe_cumprod`
`cumprod` in the MoChA paper is defined to be exclusive, while the `safe_cumprod` in this repo is not. Shouldn't it be:
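```python
import torch

def safe_cumprod(self, x, exclusive=False):
    """Numerically stable cumulative product by cumulative sum in log-space"""
    bsz = x.size(0)
    # Clamp to [1e-20, 1] so log() stays finite, then sum the logs along dim 1
    logsum = torch.cumsum(torch.log(torch.clamp(x, min=1e-20, max=1)), dim=1)
    if exclusive:
        # Prepend log(1) = 0 and drop the last column: out[:, 0] = 1, out[:, j] = prod_{k<j} x[:, k]
        logsum = torch.cat([torch.zeros(bsz, 1).to(logsum), logsum], dim=1)[:, :-1]
    return torch.exp(logsum)
```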
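To check the proposed behaviour, here is a standalone sketch of the same function. It assumes `self` is dropped and `torch.zeros_like(logsum[:, :1])` replaces `torch.zeros(bsz, 1).to(logsum)` (same result, just shorter). The inclusive output should match `torch.cumprod`, and the exclusive output should be that result shifted right by one with a leading 1.

```python
import torch


def safe_cumprod(x: torch.Tensor, exclusive: bool = False) -> torch.Tensor:
    """Cumulative product along dim 1, computed as exp(cumsum(log x))."""
    # Clamp so log() never sees 0 (or values slightly above 1 from rounding)
    logsum = torch.cumsum(torch.log(torch.clamp(x, min=1e-20, max=1)), dim=1)
    if exclusive:
        # Shift right by one: prepend log(1) = 0 and drop the last column,
        # so out[:, 0] = 1 and out[:, j] = prod_{k<j} x[:, k]
        zeros = torch.zeros_like(logsum[:, :1])
        logsum = torch.cat([zeros, logsum], dim=1)[:, :-1]
    return torch.exp(logsum)


x = torch.tensor([[0.9, 0.5, 0.8, 0.25]])
print(safe_cumprod(x))                  # approx [0.9, 0.45, 0.36, 0.09]
print(safe_cumprod(x, exclusive=True))  # approx [1.0, 0.9, 0.45, 0.36]
```

Because the product goes through `exp`/`log`, the values agree with `torch.cumprod` only up to float rounding.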
And in the `soft()` function of `MonotonicAttention`:

```python
cumprod_1_minus_p = self.safe_cumprod(1 - p_select, exclusive=True)
```
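Why exclusivity matters in `soft()`: the parallel formula for the expected monotonic alignment, alpha_i = p_i * cumprod(1 - p_i) * cumsum(alpha_{i-1} / cumprod(1 - p_i)), equals the sequential recurrence q_j = (1 - p_{j-1}) q_{j-1} + alpha_{i-1,j}, alpha_{i,j} = p_{i,j} q_j only when `cumprod` is exclusive. The sketch below checks that numerically. The function names are hypothetical, it uses plain `torch.cumprod` instead of the log-space version, and it does not reproduce the repo's `soft()`.

```python
import torch


def exclusive_cumprod(x: torch.Tensor) -> torch.Tensor:
    """[1, x_1, x_1*x_2, ...] along dim 1 (plain cumprod, no log-space clamping)."""
    ones = torch.ones_like(x[:, :1])
    return torch.cat([ones, torch.cumprod(x, dim=1)[:, :-1]], dim=1)


def alpha_closed_form(p: torch.Tensor, alpha_prev: torch.Tensor,
                      exclusive: bool = True) -> torch.Tensor:
    """alpha_i = p_i * cumprod(1 - p_i) * cumsum(alpha_{i-1} / cumprod(1 - p_i))."""
    if exclusive:
        cp = exclusive_cumprod(1 - p)
    else:
        cp = torch.cumprod(1 - p, dim=1)  # inclusive variant, for comparison only
    return p * cp * torch.cumsum(alpha_prev / cp, dim=1)


def alpha_recurrence(p: torch.Tensor, alpha_prev: torch.Tensor) -> torch.Tensor:
    """Sequential reference: q_j = (1 - p_{j-1}) q_{j-1} + alpha_prev_j; alpha_j = p_j q_j."""
    bsz, T = p.shape
    q = torch.zeros(bsz)
    out = []
    for j in range(T):
        if j > 0:
            q = (1 - p[:, j - 1]) * q
        q = q + alpha_prev[:, j]
        out.append(p[:, j] * q)
    return torch.stack(out, dim=1)


p = torch.tensor([[0.2, 0.5, 0.7, 0.9]])
alpha_prev = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
print(alpha_recurrence(p, alpha_prev))                    # approx [0.2, 0.4, 0.28, 0.108]
print(alpha_closed_form(p, alpha_prev))                   # same as the recurrence
print(alpha_closed_form(p, alpha_prev, exclusive=False))  # approx [0.2, 0.25, 0.105, 0.0135]
```

With the inclusive version, the cumprod ratio for position j becomes prod_{l=k+1}^{j} (1 - p_l) instead of prod_{l=k}^{j-1} (1 - p_l), so the attention at each step is wrongly damped by its own selection probability.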
@bo-son I think you're right