
Possible numerical error in log-norm computation

Open maxwellzh opened this issue 2 years ago • 2 comments

In the current implementation, the emission and the prediction each subtract their own maximum value. But consider this case:

emission[0, 0] = [0, -1000]
prediction[0, 0] = [-1000, 0]
->
# current impl (here maxEs = maxPs = 0)
logNorm[0, 0, 0] = log(exp(emission[0, 0]-maxEs) @ exp(prediction[0, 0]-maxPs)) + maxEs + maxPs
                 = log(exp([0, -1000]) @ exp([-1000, 0]))
                 = log([1, exp(-1000)] @ [exp(-1000), 1])  <-- exp(-1000) underflows to 0 in FP32
                 = log(0)
                 = -inf

# correct result
logNorm[0, 0, 0] = logsumexp([0, -1000] + [-1000, 0]) = log(2) - 1000
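
For reference, a minimal PyTorch reproduction of the underflow (just a sketch of the failing computation, not the repo's actual code):

import torch

# the failing entries from the example above
emission = torch.tensor([0.0, -1000.0])    # emission[0, 0]
prediction = torch.tensor([-1000.0, 0.0])  # prediction[0, 0]

# current impl: subtract each tensor's own max (both maxima are 0 here)
maxEs, maxPs = emission.max(), prediction.max()
logNorm = torch.log(torch.exp(emission - maxEs) @ torch.exp(prediction - maxPs)) + maxEs + maxPs
print(logNorm)  # tensor(-inf): exp(-1000) underflows to 0 in FP32

# stable reference over the summed logits
print(torch.logsumexp(emission + prediction, dim=-1))  # tensor(-999.3069) = log(2) - 1000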

I also tried converting emission and prediction to FP64 before computing the logNorm, but it still didn't work in my ASR experiments. (That makes sense for this example: exp(-1000) is about 1e-435, which underflows even in FP64, whose smallest subnormal is about 5e-324.)

The broadcast-sum way is more numerically stable, but it consumes O(B*T*U*V) memory:

logNorm = torch.logsumexp(emission.unsqueeze(2) + prediction.unsqueeze(1), dim=-1)

https://github.com/awni/transducer/blob/e90c6f45f10ccb404befddb0a99463fa6cb2e753/transducer/torch_binding.py#L162-L167

maxwellzh avatar Feb 17 '23 06:02 maxwellzh

There is a similar loss function implementation in k2's fast_rnnt:

https://github.com/danpovey/fast_rnnt/blob/2c2dc4b96a6b9a8c0dbedada94cdee53a9337402/fast_rnnt/python/fast_rnnt/rnnt_loss.py#L159-L162

It seems they just add a small value before taking the log to avoid log(0), which also introduces error into the result.
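
To illustrate: clamping with any fixed small value turns the -inf into a finite but badly wrong number (torch.finfo(torch.float32).tiny below is just an assumed stand-in; the exact value k2 adds may differ):

import torch

eps = torch.finfo(torch.float32).tiny  # assumed stand-in for the added small value
p = torch.exp(torch.tensor(-1000.0))   # underflows to 0 in FP32
print(torch.log(p + eps))              # tensor(-87.3365) instead of the true -1000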

@pkufool @csukuangfj Could you take a look at this? I believe the implementation from k2 would also face this issue.

maxwellzh avatar Feb 17 '23 06:02 maxwellzh

The only way I could figure out is to implement a custom function logmmexp(a, b), where we need to compute a_k + b_k twice (since we don't want to store the huge intermediate tensor).

In the first pass, reduce a_k + b_k over k to obtain the maximum values, giving maxSum of shape (B, T, U); in the second pass, compute logsumexp(a_k + b_k - maxSum) + maxSum at each position.
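
A minimal sketch of that two-pass idea, assuming a is (B, T, V) and b is (B, V, U) (a real implementation would be a fused CUDA kernel rather than a Python loop over k):

import torch

def logmmexp(a, b):
    # stable log(exp(a) @ exp(b)) without storing the (B, T, U, V) intermediate
    # a: (B, T, V), b: (B, V, U) -> out: (B, T, U)
    B, T, V = a.shape
    U = b.shape[-1]
    # pass 1: maxSum[b, t, u] = max_k (a[b, t, k] + b[b, k, u])
    maxSum = a.new_full((B, T, U), float("-inf"))
    for k in range(V):
        maxSum = torch.maximum(maxSum, a[:, :, k:k+1] + b[:, k:k+1, :])
    # pass 2: recompute a_k + b_k, accumulate exp(a_k + b_k - maxSum), shift back
    acc = torch.zeros_like(maxSum)
    for k in range(V):
        acc = acc + torch.exp(a[:, :, k:k+1] + b[:, k:k+1, :] - maxSum)
    return torch.log(acc) + maxSum

On the example above, logmmexp returns log(2) - 1000 instead of -inf, since each term is shifted by maxSum = -1000 before exponentiation.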

Update: just found people working on this: https://github.com/pytorch/pytorch/issues/54064

maxwellzh avatar Feb 17 '23 06:02 maxwellzh