ivy

[bug-fix] fixed einsum throwing unwanted shape error at torch frontends

Open akshatvishu opened this issue 1 year ago • 1 comments

Close #17243

At present, calling torch_frontend.einsum with the operands passed as a single list results in a shape error:

import ivy
import ivy.functional.frontends.torch as torch_frontend
import torch
import jax.numpy as jnp

# Native torch accepts the operands as a single list.
x = torch.rand((1, 1, 100, 64))
y = torch.rand((1, 1, 50, 64))
torch.einsum('bhlk,bhtk->bhlt', [x, y])

# The same call through the Ivy torch frontend (JAX backend) raises the shape error.
ivy.set_jax_backend()
x = jnp.array(x)
y = jnp.array(y)
torch_frontend.einsum('bhlk,bhtk->bhlt', [x, y])

The error occurs because the torch frontend in ivy does not handle operands passed as a single list/tuple argument. This PR fixes that; the example below demonstrates the fixed behavior.

import ivy
import ivy.functional.frontends.torch as torch_frontend
import torch
import jax.numpy as jnp

# Reference result from native torch.
x = torch.rand((1, 1, 100, 64))
y = torch.rand((1, 1, 50, 64))
a = torch.einsum('bhlk,bhtk->bhlt', [x, y])
print(a)

# Same call through the Ivy torch frontend with the JAX backend.
ivy.set_jax_backend()
x = jnp.array(x)
y = jnp.array(y)
b = torch_frontend.einsum('bhlk,bhtk->bhlt', [x, y])
print(b)

"""Print
tensor([[[[17.4691, 18.0361, 17.0545,  ..., 18.4826, 16.6745, 19.7370],
          [15.3284, 15.7768, 15.8166,  ..., 15.3946, 14.3021, 15.2536],
          [15.8438, 16.3410, 15.6118,  ..., 17.3207, 14.9726, 16.0644],
          ...,
          [14.1006, 15.1445, 14.2358,  ..., 14.6510, 13.7721, 15.1843],
          [16.4025, 16.8922, 16.8847,  ..., 17.2964, 15.5659, 16.8074],
          [13.8321, 12.8741, 12.5600,  ..., 13.0061, 11.3301, 13.7633]]]])
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
ivy.frontends.torch.Tensor([[[[17.46906853, 18.03613091, 17.05452919, ..., 18.48261452,
          16.67450714, 19.73703003],
         [15.32843208, 15.77681828, 15.8165741 , ..., 15.39459133,
          14.30214882, 15.25363159],
         [15.84376431, 16.34097481, 15.61175632, ..., 17.32069778,
          14.97262001, 16.06442642],
         ...,
         [14.10061169, 15.14452839, 14.23581696, ..., 14.65102863,
          13.77206898, 15.1842823 ],
         [16.40252113, 16.89216995, 16.884655  , ..., 17.29643822,
          15.56585598, 16.80744553],
         [13.83210373, 12.8740654 , 12.55998802, ..., 13.00606918,
          11.33005714, 13.76333237]]]])
"""
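
For readers who want to see the idea behind the fix, here is a minimal sketch of how a frontend function could accept operands either as separate arguments or as one list/tuple. It is not the code from this PR: `unpack_operands` and `einsum_like` are hypothetical names, and the sketch only shows the argument normalization, not the einsum computation itself.

```python
from typing import Any, Tuple


def unpack_operands(*operands: Any) -> Tuple[Any, ...]:
    """Hypothetical helper: normalize einsum-style operands.

    Accepts both call styles:
        einsum(eq, a, b)    -> operands == (a, b)
        einsum(eq, [a, b])  -> operands == ([a, b],)
    and always returns a flat tuple (a, b).
    """
    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
        # A single list/tuple argument holds the real operands; unpack it.
        return tuple(operands[0])
    return tuple(operands)


def einsum_like(equation: str, *operands: Any) -> Tuple[str, Tuple[Any, ...]]:
    """Hypothetical stand-in for a frontend einsum.

    It only returns what would be forwarded to the backend, so the
    normalization can be inspected without any array library installed.
    """
    return equation, unpack_operands(*operands)


if __name__ == "__main__":
    # Both call styles forward the same operands.
    print(einsum_like("bhlk,bhtk->bhlt", ["x", "y"]))
    print(einsum_like("bhlk,bhtk->bhlt", "x", "y"))
```

One caveat of this approach, under the assumptions above: a lone operand that is itself a plain Python list is indistinguishable from a list of operands, so such an input would need to be converted to an array before the call.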

akshatvishu • Jun 15 '23 18:06

If you are working on an open task, please edit the PR description to link to the issue you've created.

For more information, please check ToDo List Issues Guide.

Thank you :hugs:

ivy-leaves • Jun 19 '23 18:06

Hi, it seems the issue has already been fixed and closed.

zhumakhan • Jun 27 '23 05:06