tntorch
tntorch.metrics.dot() does not perform the calculation described in the API documentation
The documentation states: "Example: suppose t1 has shape 3 x 4 and t2 has shape 3 x 4 x 5 x 6. Then, tn.dot(t1, t2) will have shape 5 x 6."
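Instead of performing this contraction, the function flattens both inputs and calls dot() on the results (line 67 of metrics.py). This raises a RuntimeError because the flattened tensors have different numbers of elements (12 and 360):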
```
>>> import torch
>>> import tntorch as tn
>>> a = torch.randn(3, 4)
>>> b = torch.randn(3, 4, 5, 6)
>>> tn.metrics.dot(a, b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\anaconda\envs\TT-PINN\Lib\site-packages\tntorch\metrics.py", line 67, in dot
    return t1.flatten().dot(t2.flatten())
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: inconsistent tensor size, expected tensor [12] and src [360] to have the same number of elements, but got 12 and 360 elements respectively
```
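Until this is fixed, the documented behavior can be reproduced for plain torch tensors with torch.tensordot. The sketch below is a hypothetical helper (leading_dot is not part of tntorch). It assumes the intended semantics are "contract the k leading dimensions, defaulting to all dimensions of the lower-order tensor", as the documentation example suggests:

```python
import torch


def leading_dot(t1: torch.Tensor, t2: torch.Tensor, k=None) -> torch.Tensor:
    """Contract the k leading dimensions of t1 and t2 (hypothetical helper).

    If k is None, contract over all dimensions of the lower-order tensor,
    matching the documented example: shapes (3, 4) and (3, 4, 5, 6) -> (5, 6).
    Tensors of equal order therefore contract to a 0-dim (scalar) tensor.
    """
    if k is None:
        k = min(t1.dim(), t2.dim())
    if t1.shape[:k] != t2.shape[:k]:
        raise ValueError(
            f"leading dimensions differ: {tuple(t1.shape[:k])} vs {tuple(t2.shape[:k])}"
        )
    # tensordot pairs dims 0..k-1 of t1 with dims 0..k-1 of t2 and sums over them;
    # the remaining dims of t1 come first in the result, followed by those of t2.
    dims = (list(range(k)), list(range(k)))
    return torch.tensordot(t1, t2, dims=dims)


a = torch.randn(3, 4)
b = torch.randn(3, 4, 5, 6)
print(leading_dot(a, b).shape)  # torch.Size([5, 6])
```

The same contraction can be cross-checked with torch.einsum("ij,ijkl->kl", a, b), which spells out the summed indices explicitly.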