PyTorch module emits UserWarning
When running nara_wpe.torch_wpe.wpe_v6 with PyTorch 1.11, I'm seeing the following warning:
torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:766.)
G, _ = torch.solve(P, R)
It seems modifying G, _ = torch.solve(P, R) to G = torch.linalg.solve(R, P) does the trick.
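For reference, the argument order really is swapped between the two APIs: the old `torch.solve(B, A)` returned a `(solution, LU)` pair with `A @ X = B`, while `torch.linalg.solve(A, B)` takes the operands the other way around and returns only `X`. NumPy's `np.linalg.solve` follows the same `(A, B)` convention, so the semantics can be sanity-checked without torch installed (a minimal sketch, not taken from the nara_wpe code):

```python
import numpy as np

# torch.solve(B, A) solved A @ X = B and returned (X, LU);
# torch.linalg.solve(A, B) solves the same system with operands swapped
# and returns just X. np.linalg.solve uses the same (A, B) convention.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))
B = rng.standard_normal((4, 2)) + 1j * rng.standard_normal((4, 2))

X = np.linalg.solve(A, B)  # same call shape as torch.linalg.solve(A, B)
assert np.allclose(A @ X, B)
```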
It'd also be cool if you could release v0.0.8 on PyPI, which contains the torch_wpe submodule; I'm adding support for WPE as data augmentation in Lhotse, which leverages your library here: https://github.com/lhotse-speech/lhotse/pull/663
I published that torch code because the PyTorch people wanted an example that uses complex numbers. In #46 I accidentally merged the code and then forgot about it. I checked the code, and it does indeed contain an error: the Hermitian transpose operation does not conjugate. I will fix it.
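To illustrate the bug in isolation: a Hermitian transpose must conjugate as well as transpose, and for complex inputs a plain transpose gives a different result (a minimal NumPy sketch of the general issue, not the actual nara_wpe code path):

```python
import numpy as np

# A Hermitian transpose conjugates AND transposes; leaving out the
# conjugate (as the buggy torch code effectively did) is wrong for
# complex matrices.
Y = np.array([[1 + 2j, 3 - 1j],
              [0 + 1j, 2 + 0j]])

hermitian = Y.conj().T   # correct: conjugate transpose
transpose_only = Y.T     # the bug: conjugate is missing

assert not np.array_equal(hermitian, transpose_only)
assert np.array_equal(hermitian, np.array([[1 - 2j, 0 - 1j],
                                           [3 + 1j, 2 - 0j]]))
```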
In general, I would recommend using the numpy code instead of the torch code; it
- is much more tested and used,
- has different implementations that are faster or use less memory, depending on the situation,
- has a stabilization implemented (I am not sure, but I got the impression that numpy also has better implementations for solve, etc.),
- and is not constrained to be differentiable.
Nevertheless, I should fix the torch code in nara_wpe. With the newer torch versions, all required operations are implemented.
Thank you!
BTW, you might find this interesting: I ran a simple benchmark on a single utterance with Jupyter's %%timeit and saw that the numpy version took 300 ms on average, while the torch implementation took 130 ms on average (on CPU). It's likely partially explained by the missing conjugate, but it still seems worthy of attention.
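Outside a notebook, the same kind of measurement can be reproduced with the stdlib `timeit` module; the sketch below times a representative complex `solve` call (a hot spot in WPE) rather than the full pipeline, and the sizes and repeat counts are illustrative, not the original measurement:

```python
import timeit

import numpy as np

# Time a complex linear solve, a representative WPE kernel.
rng = np.random.default_rng(0)
A = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))
B = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))

n = 100
t = timeit.timeit(lambda: np.linalg.solve(A, B), number=n)
print(f"np.linalg.solve: {t / n * 1e3:.3f} ms per call")
```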
The conjugate should not have this effect; I think it comes from the solve operation or the memory view. PyTorch does not support negative strides, so I used a view that does not match the theory but produces the correct final result. Thanks for reporting this; I am not sure when I will find the time to investigate it. Is your benchmark a shareable toy example?
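The stride difference is easy to see in NumPy, where reversing an axis is a zero-copy view with a negative stride; PyTorch tensors cannot have negative strides, which is why the torch code has to arrange the data differently (e.g. `torch.flip` copies). A small sketch of the NumPy side:

```python
import numpy as np

# In NumPy, reversing an axis is a view with a negative stride
# and no data copy; PyTorch forbids negative strides entirely.
x = np.arange(6)
rev = x[::-1]

assert rev.strides[0] == -x.strides[0]  # negative stride
assert rev.base is x                    # shares memory, no copy
```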
I'll try to find a moment to post it tomorrow.
Check out https://drive.google.com/file/d/12hM2rpt6xzKDfPeRPvRAl0aWccIAMSUx/view?usp=sharing
I fixed the torch wpe code, removed several deprecation warnings and pushed a new version to PyPI.
Thanks for sharing the notebook. I am not yet sure when I will find the time to check it and see if I can speed up the numpy code.
Thanks!