tntorch

tntorch.metrics.dot() does not perform the calculation described in the API documentation

Status: Open. zhaoran0072004 opened this issue 4 months ago (0 comments).


The API documentation gives this example: "Suppose t1 has shape 3 x 4 and t2 has shape 3 x 4 x 5 x 6. Then, tn.dot(t1, t2) will have shape 5 x 6."
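Read literally, the documented example contracts every mode of t1 against the leading modes of t2, i.e. result[k, l] = sum over i, j of t1[i, j] * t2[i, j, k, l]. A minimal sketch of that reading with dense PyTorch tensors (plain `torch.einsum`, not tntorch, and assuming this interpretation of the docstring) is:

```python
import torch

# Dense stand-ins with the shapes from the documented example.
t1 = torch.randn(3, 4)
t2 = torch.randn(3, 4, 5, 6)

# result[k, l] = sum_{i, j} t1[i, j] * t2[i, j, k, l]
result = torch.einsum("ij,ijkl->kl", t1, t2)
print(result.shape)  # torch.Size([5, 6])
```

The two contracted modes (3 and 4) disappear, leaving the trailing 5 x 6 modes of t2, which matches the shape the documentation promises.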

Instead of performing this contraction, the function raises a RuntimeError:

    import torch
    import tntorch as tn

    a = torch.randn(3, 4)
    b = torch.randn(3, 4, 5, 6)
    tn.metrics.dot(a, b)

    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "D:\anaconda\envs\TT-PINN\Lib\site-packages\tntorch\metrics.py", line 67, in dot
        return t1.flatten().dot(t2.flatten())
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: inconsistent tensor size, expected tensor [12] and src [360] to have the same number of elements, but got 12 and 360 elements respectively
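The traceback points at line 67 of metrics.py, which computes t1.flatten().dot(t2.flatten()). That is a full inner product: `torch.dot` requires two 1-D tensors of equal length, and here the flattened inputs have 3*4 = 12 and 3*4*5*6 = 360 elements. The sketch below, assuming dense `torch.Tensor` inputs only (tntorch's own tensor-train objects are not covered), reproduces the failure in plain PyTorch and shows one possible partial contraction via `torch.tensordot`. The helper name `dense_partial_dot` is hypothetical and is not part of tntorch.

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(3, 4, 5, 6)

# Flattening gives vectors of 3*4 = 12 and 3*4*5*6 = 360 elements.
print(a.flatten().numel(), b.flatten().numel())  # 12 360

# torch.dot needs two 1-D tensors of equal length, so this raises the
# same kind of RuntimeError as in the traceback.
raised = False
try:
    a.flatten().dot(b.flatten())
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)


def dense_partial_dot(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not a tntorch function): contract every mode of t1
    with the leading modes of t2, as the documented example describes."""
    n = t1.dim()
    if tuple(t2.shape[:n]) != tuple(t1.shape):
        raise ValueError(
            f"shape {tuple(t1.shape)} is not a prefix of {tuple(t2.shape)}"
        )
    # With an integer `dims`, tensordot contracts the last n modes of t1
    # (here: all of them) with the first n modes of t2.
    return torch.tensordot(t1, t2, dims=n)


print(dense_partial_dot(a, b).shape)  # torch.Size([5, 6])
```

When t1 and t2 have identical shapes, the same call contracts every mode and returns a 0-d tensor equal to the ordinary inner product, so the full dot product is just the special case where no modes are left over.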

zhaoran0072004, Feb 03 '24 22:02