Possible numerical error in log-norm computation
In the current implementation, the emissions and the predictions each subtract their own maximum value before being exponentiated and multiplied. But consider this case:
```
emission[0, 0]   = [0, -1000]
prediction[0, 0] = [-1000, 0]
# here maxEs = maxPs = 0

# current impl
logNorm[0, 0, 0] = log(exp(emission[0, 0] - maxEs) @ exp(prediction[0, 0] - maxPs)) + maxEs + maxPs
                 = log(exp([0, -1000]) @ exp([-1000, 0]))
                 = log([1, exp(-1000)] @ [exp(-1000), 1])   <-- exp(-1000) underflows to 0 in FP32
                 = log(0)
                 = -inf

# correct result
logNorm[0, 0, 0] = log(2 * exp(-1000)) = log(2) - 1000
```
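In general terms, the issue is which shift is used. Writing a = emission[0, 0] and b = prediction[0, 0] (the same a_k + b_k notation used further down), the stable identity shifts by the maximum of the *sum*, while the current implementation shifts by the *sum of the maxima*:

$$
\log \sum_k e^{a_k + b_k} = M + \log \sum_k e^{a_k + b_k - M}, \qquad M = \max_k (a_k + b_k)
$$

With this shift the largest summand is exactly $e^0 = 1$, so the sum lies in $[1, V]$ and cannot underflow. Shifting instead by $M' = \max_k a_k + \max_k b_k$ gives $M' \ge M$, so every summand satisfies $e^{a_k + b_k - M'} \le e^{-(M' - M)}$. In the example $M' = 0$ and $M = -1000$, so every summand is at most $e^{-1000}$ and the whole sum underflows to 0.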
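I also tried converting emission and prediction to FP64 before computing logNorm, but it still didn't work in my ASR experiment. That is expected for gaps this large: the smallest positive (subnormal) FP64 value is about 4.9e-324 ≈ exp(-744.4), so exp(-1000) underflows to 0 in FP64 as well.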
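A small reproduction of the failure, as a sketch: `lognorm_separate_max` is a hypothetical re-creation of the scheme described above (each input shifted by its own max over V, then a matmul in probability space), not the actual library code, and the batch dimension is dropped for brevity. It runs the example in both FP32 and FP64.

```python
import math
import torch


def lognorm_separate_max(emission: torch.Tensor, prediction: torch.Tensor) -> torch.Tensor:
    # Hypothetical re-creation of the current scheme (assumption: per-row maxima over V).
    # emission: (T, V), prediction: (U, V) -> (T, U); batch dim omitted.
    max_e = emission.amax(dim=-1, keepdim=True)    # (T, 1)
    max_p = prediction.amax(dim=-1, keepdim=True)  # (U, 1)
    prod = torch.exp(emission - max_e) @ torch.exp(prediction - max_p).T  # (T, U)
    return torch.log(prod) + max_e + max_p.T


emission = torch.tensor([[0.0, -1000.0]])    # (T=1, V=2)
prediction = torch.tensor([[-1000.0, 0.0]])  # (U=1, V=2)

results = {}
for dtype in (torch.float32, torch.float64):
    out = lognorm_separate_max(emission.to(dtype), prediction.to(dtype))
    results[dtype] = out.item()
    print(dtype, results[dtype])  # -inf: exp(-1000) underflows to 0 in both dtypes

print("expected:", math.log(2) - 1000)  # about -999.307
```

Casting to FP64 does not help here because the gap between the two shifts (1000) exceeds what even FP64 exponents can represent.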
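The broadcast-sum approach is more numerically stable, but it materializes a (B, T, U, V) tensor, i.e. O(B*T*U*V) memory:

```python
# emission: (B, T, V), prediction: (B, U, V) -> (B, T, 1, V) + (B, 1, U, V) = (B, T, U, V)
# note: log_softmax returns the normalized log-probs; the log-normalizer alone is torch.logsumexp(..., dim=-1)
logNorm = torch.log_softmax(emission.unsqueeze(2) + prediction.unsqueeze(1), dim=-1)
```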
https://github.com/awni/transducer/blob/e90c6f45f10ccb404befddb0a99463fa6cb2e753/transducer/torch_binding.py#L162-L167
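For reference, a sketch of the broadcast-sum approach returning only the log-normalizer. Assumptions: emission is (B, T, V) and prediction is (B, U, V), which matches the (B, T, U) shape of maxSum mentioned below; `lognorm_broadcast` is a hypothetical name, and the memory estimate uses made-up ASR-scale sizes purely for illustration.

```python
import math
import torch


def lognorm_broadcast(emission: torch.Tensor, prediction: torch.Tensor) -> torch.Tensor:
    """emission: (B, T, V), prediction: (B, U, V) -> logNorm: (B, T, U).

    Stable because logsumexp shifts by the max of the *sum*, but it
    materializes the full (B, T, U, V) joint, hence O(B*T*U*V) memory.
    """
    joint = emission.unsqueeze(2) + prediction.unsqueeze(1)  # (B, T, U, V)
    return torch.logsumexp(joint, dim=-1)


emission = torch.tensor([[[0.0, -1000.0]]])    # (B=1, T=1, V=2)
prediction = torch.tensor([[[-1000.0, 0.0]]])  # (B=1, U=1, V=2)
out = lognorm_broadcast(emission, prediction)
print(out.shape)   # torch.Size([1, 1, 1])
print(out.item())  # about -999.307, i.e. log(2) - 1000

# Size of the FP32 joint tensor for hypothetical sizes B=16, T=500, U=100, V=5000.
B, T, U, V = 16, 500, 100, 5000
print(B * T * U * V * 4 / 2**30, "GiB")  # about 14.9 GiB
```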
There is a similar loss function implementation from k2 (fast_rnnt):
https://github.com/danpovey/fast_rnnt/blob/2c2dc4b96a6b9a8c0dbedada94cdee53a9337402/fast_rnnt/python/fast_rnnt/rnnt_loss.py#L159-L162
It seems they just add a small value to avoid log(0), which also introduces error: if the constant is added to the matmul result, the example above yields roughly log(eps) instead of log(2) - 1000.
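A sketch of what the epsilon workaround does on the same example. Assumptions: the constant is added to the product before the log, and `torch.finfo(torch.float32).tiny` is used only as an illustrative value; the constant in fast_rnnt may differ.

```python
import math
import torch

# Toy inputs from the example above: one (b, t, u) position with V = 2.
emission = torch.tensor([0.0, -1000.0])
prediction = torch.tensor([-1000.0, 0.0])

# Assumption: eps is added to the matmul result; the real constant may differ.
eps = torch.finfo(torch.float32).tiny

max_e, max_p = emission.max(), prediction.max()
prod = torch.exp(emission - max_e) @ torch.exp(prediction - max_p)  # 1-D @ 1-D -> scalar
with_eps = torch.log(prod + eps) + max_e + max_p
exact = math.log(2) - 1000

print(prod.item())      # 0.0: both terms underflowed
print(with_eps.item())  # about -87.34, i.e. log(eps), far from the exact value
print(exact)            # about -999.307
```

The -inf disappears, but the value is now determined by eps rather than by the inputs.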
@pkufool @csukuangfj Could you take a look at this? I believe the implementation in k2 also faces this issue.
The only way I could figure out is to implement a custom function `logmmexp(a, b)` that computes a_k + b_k twice, since we don't want to store the very large (B, T, U, V) intermediate tensor.
In the first pass, reduce max_k(a_k + b_k) to obtain maxSum with shape (B, T, U).
In the second pass, compute logsumexp_k(a_k + b_k - maxSum) + maxSum at each position.
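The two-pass idea can be sketched in pure PyTorch by chunking over V. This is an illustration under assumptions, not the proposed custom kernel: the `logmmexp` name comes from the proposal, but the signature, the `chunk` parameter, and the chunking itself are my choices. It keeps forward memory at O(B*T*U*chunk) only when run without autograd (e.g. under `torch.no_grad()`); with autograd, every chunk is saved for backward, so a real fix would need a custom `torch.autograd.Function` or CUDA kernel that also recomputes a_k + b_k in the backward pass.

```python
import math
import torch


def logmmexp(a: torch.Tensor, b: torch.Tensor, chunk: int = 64) -> torch.Tensor:
    """log(sum_k exp(a[:, t, k] + b[:, u, k])) without a full (B, T, U, V) tensor.

    a: (B, T, V), b: (B, U, V) -> (B, T, U). Hypothetical chunked sketch.
    """
    B, T, V = a.shape
    U = b.shape[1]
    # Pass 1: running max of a_k + b_k over k, one V-chunk at a time.
    max_sum = torch.full((B, T, U), float("-inf"), dtype=a.dtype, device=a.device)
    for start in range(0, V, chunk):
        s = a[:, :, None, start:start + chunk] + b[:, None, :, start:start + chunk]
        max_sum = torch.maximum(max_sum, s.amax(dim=-1))
    # Avoid (-inf) - (-inf) = NaN when every term at a position is -inf.
    safe_max = torch.where(torch.isinf(max_sum), torch.zeros_like(max_sum), max_sum)
    # Pass 2: recompute a_k + b_k and accumulate exp(a_k + b_k - maxSum).
    total = torch.zeros_like(max_sum)
    for start in range(0, V, chunk):
        s = a[:, :, None, start:start + chunk] + b[:, None, :, start:start + chunk]
        total = total + torch.exp(s - safe_max[..., None]).sum(dim=-1)
    return torch.log(total) + safe_max


# The example from the top of the issue, in FP32.
a = torch.tensor([[[0.0, -1000.0]]])
b = torch.tensor([[[-1000.0, 0.0]]])
with torch.no_grad():
    print(logmmexp(a, b).item())  # about -999.307, i.e. log(2) - 1000
```

Because the largest shifted term in pass 2 is exactly exp(0) = 1, each accumulated sum is at least 1 and the log never sees 0.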
Update: I just found that people are working on this: https://github.com/pytorch/pytorch/issues/54064