`contract` fails to find the required methods in the `jax.numpy` backend
Hi!
I'm trying to use opt_einsum with jax, but it seems to fail to find the correct methods in the jax module. This results in the jax Array being converted to a NumPy ndarray, or in a traceback if jax is specified as the backend.
The following code reproduces the behavior:
import sys
import jax
import jax.numpy as jnp
import opt_einsum
print(sys.version)
print(jax.__version__)
print(opt_einsum.__version__)
x = jnp.linspace(0, 1, 32)
print(type(opt_einsum.contract("i->i", x)))
print(type(opt_einsum.contract("i->i", x, backend=jnp)))
issue_report$ python main.py
3.10.13 (main, Dec 15 2023, 19:01:59) [GCC 11.4.0]
0.4.24
v3.3.0
<class 'numpy.ndarray'>
Traceback (most recent call last):
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 65, in get_func
return _cached_funcs[func, backend]
KeyError: ('einsum', <module 'jax.numpy' from '/[...]/issue_report/venv/lib/python3.10/site-packages/jax/numpy/__init__.py'>)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 38, in _import_func
lib = importlib.import_module(_aliases.get(backend, backend))
File "/[...]/.pyenv/versions/3.10.13/lib/python3.10/importlib/__init__.py", line 117, in import_module
if name.startswith('.'):
AttributeError: module 'jax.numpy' has no attribute 'startswith'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/[...]/issue_report/main.py", line 14, in <module>
print(type(opt_einsum.contract("i->i", x, backend=jnp)))
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 507, in contract
return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 591, in _core_contract
new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/sharing.py", line 151, in cached_einsum
return einsum(*args, **kwargs)
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 337, in _einsum
fn = backends.get_func('einsum', kwargs.pop('backend', 'numpy'))
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 67, in get_func
fn = _import_func(func, backend, default)
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 44, in _import_func
raise AttributeError(error_msg.format(backend, func))
AttributeError: <module 'jax.numpy' from '/[...]/issue_report/venv/lib/python3.10/site-packages/jax/numpy/__init__.py'> doesn't seem to provide the function einsum - see https://optimized-einsum.readthedocs.io/en/latest/backends.html for details on which functions are required for which contractions.
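The chained errors above come from passing the `jax.numpy` module object where opt_einsum expects a backend name: per the traceback, `_import_func` hands it to `importlib.import_module`, which begins by calling `name.startswith('.')` on its argument. A minimal, standard-library-only sketch of that failure, using `json` as a stand-in for `jax.numpy` so neither jax nor opt_einsum is needed (the helper names are hypothetical):

```python
import importlib
import json


def demonstrate_module_object_failure():
    """Call importlib.import_module with a module object instead of a name.

    import_module expects a dotted module *name* (a str) and checks
    name.startswith('.') early on, so a module object fails right there.
    """
    try:
        importlib.import_module(json)  # json is a module object, not "json"
    except AttributeError as err:
        return err
    return None


def import_by_name(name: str):
    """Passing the module's name as a string works as intended."""
    return importlib.import_module(name)


err = demonstrate_module_object_failure()
print(type(err).__name__, err)
print(import_by_name("json") is json)  # already-imported module is returned
```

This mirrors the `AttributeError: module 'jax.numpy' has no attribute 'startswith'` frame; opt_einsum then catches it and re-raises the more generic "doesn't seem to provide the function einsum" message.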
I've replaced personal folder information with [...].
Thanks and best regards.
It looks like the backend dispatch aliases need an entry "jaxlib": "jax.numpy" for this to work automatically. Note that passing the module object directly is not supported; if you call it with backend="jax" instead, it should work.
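A simplified sketch of the alias idea, assuming (as the suggested "jaxlib" key implies) that the backend name is inferred from the top-level package of the array's class. `ALIASES`, `infer_backend`, `resolve_backend` and `FakeJaxArray` are hypothetical names for illustration, not opt_einsum's API, and the early `TypeError` for module objects is a choice made in this sketch, not something opt_einsum is known to do:

```python
import types

# Hypothetical alias table illustrating the suggested fix; it is NOT a copy
# of the table in opt_einsum/backends/dispatch.py.
ALIASES = {
    "jax": "jax.numpy",
    "jaxlib": "jax.numpy",  # the entry proposed in the reply above
}


def infer_backend(array) -> str:
    """Guess a backend name from the top-level package of the array's class."""
    return array.__class__.__module__.split(".")[0]


def resolve_backend(backend) -> str:
    """Map a backend *name* to an importable module path.

    Module objects are rejected up front with a clear message instead of
    failing deep inside importlib.
    """
    if isinstance(backend, types.ModuleType):
        raise TypeError(
            f"pass the backend as a string such as 'jax', "
            f"not the module {backend.__name__!r}"
        )
    return ALIASES.get(backend, backend)


class FakeJaxArray:
    """Hypothetical stand-in for a jax Array, so jax need not be installed."""


# Assumption: jax arrays report a class module under the "jaxlib" package.
FakeJaxArray.__module__ = "jaxlib.xla_extension"

name = infer_backend(FakeJaxArray())
print(name)                      # jaxlib
print(resolve_backend(name))     # jax.numpy
print(resolve_backend("numpy"))  # numpy
```

Without the "jaxlib" entry, the inferred name would not map to `jax.numpy`, which matches the reported behavior of the automatic path not dispatching to jax.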
Hi, indeed, it works better that way; I should have been more attentive. Thanks a lot!
@jcmgray Could you make an associated PR if you have a minute?
Should be closed by https://github.com/dgasmith/opt_einsum/pull/228.