
`contract` failed to find required methods from `jax.numpy` backend

Open zazbone opened this issue 1 year ago • 3 comments

Hi!

I'm trying to use opt_einsum with JAX, but it seems to fail to find the correct functions in the jax module. The result is either a silent conversion from a JAX Array to a NumPy ndarray or, if jax is specified as the backend, a traceback.
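The silent conversion in the first case is consistent with how opt_einsum 3.3 appears to pick a backend when `backend` is left at its default `"auto"`. It takes the top-level package name of the first operand's class. If that package, after alias lookup, does not provide `tensordot`, it falls back to `"numpy"`. The stdlib-only sketch below re-creates that logic under this assumption. `infer_backend`, `has_tensordot`, `parse_backend` and `ALIASES` are illustrative names, not a verified copy of the library code. For a JAX array the inferred name would be `"jaxlib"`, as the maintainer's reply in this thread implies. Without a `"jaxlib"` alias, that name resolves to a module with no `tensordot`.

```python
import importlib
from fractions import Fraction

# Hypothetical alias table (name -> module providing the array functions).
# There is deliberately no "jaxlib" entry, which models the reported bug.
ALIASES = {"jax": "jax.numpy"}


def infer_backend(x):
    # Top-level package of the operand's class, e.g. "numpy" for numpy.ndarray.
    return type(x).__module__.split(".")[0]


def has_tensordot(backend):
    # True only if the (aliased) module imports and exposes tensordot.
    try:
        lib = importlib.import_module(ALIASES.get(backend, backend))
    except ImportError:
        return False
    return hasattr(lib, "tensordot")


def parse_backend(arrays, backend="auto"):
    # An explicit string backend is used as-is.
    if backend != "auto":
        return backend
    inferred = infer_backend(arrays[0])
    # Packages without tensordot silently fall back to NumPy.
    if not has_tensordot(inferred):
        return "numpy"
    return inferred


if __name__ == "__main__":
    # Fraction is defined in the stdlib "fractions" module, which has no tensordot.
    print(infer_backend(Fraction(1, 2)))                   # fractions
    print(parse_backend([Fraction(1, 2)]))                 # numpy
    print(parse_backend([Fraction(1, 2)], backend="jax"))  # jax
```

Here `Fraction` stands in for an array whose package lacks `tensordot`, which is the situation a JAX array would be in without the alias.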

This script reproduces the behavior:

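```python
import sys

import jax
import jax.numpy as jnp
import opt_einsum

print(sys.version)
print(jax.__version__)
print(opt_einsum.__version__)

x = jnp.linspace(0, 1, 32)
# Default backend ("auto"): the result comes back as a numpy.ndarray
print(type(opt_einsum.contract("i->i", x)))
# Passing the jax.numpy module object as the backend raises an AttributeError
print(type(opt_einsum.contract("i->i", x, backend=jnp)))
```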

```text
issue_report$ python main.py 
3.10.13 (main, Dec 15 2023, 19:01:59) [GCC 11.4.0]
0.4.24
v3.3.0
<class 'numpy.ndarray'>
Traceback (most recent call last):
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 65, in get_func
    return _cached_funcs[func, backend]
KeyError: ('einsum', <module 'jax.numpy' from '/[...]/issue_report/venv/lib/python3.10/site-packages/jax/numpy/__init__.py'>)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 38, in _import_func
    lib = importlib.import_module(_aliases.get(backend, backend))
  File "/[...]/.pyenv/versions/3.10.13/lib/python3.10/importlib/__init__.py", line 117, in import_module
    if name.startswith('.'):
AttributeError: module 'jax.numpy' has no attribute 'startswith'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/[...]/issue_report/main.py", line 14, in <module>
    print(type(opt_einsum.contract("i->i", x, backend=jnp)))
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 507, in contract
    return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 591, in _core_contract
    new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/sharing.py", line 151, in cached_einsum
    return einsum(*args, **kwargs)
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 337, in _einsum
    fn = backends.get_func('einsum', kwargs.pop('backend', 'numpy'))
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 67, in get_func
    fn = _import_func(func, backend, default)
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 44, in _import_func
    raise AttributeError(error_msg.format(backend, func))
AttributeError: <module 'jax.numpy' from '/[...]/issue_report/venv/lib/python3.10/site-packages/jax/numpy/__init__.py'> doesn't seem to provide the function einsum - see https://optimized-einsum.readthedocs.io/en/latest/backends.html for details on which functions are required for which contractions.
```

I've only replaced personal folder information with [...].

Thanks and best regards.

zazbone, Feb 13 '24 09:02

Looks like the backend dispatch aliases need updating with `"jaxlib": "jax.numpy"` for it to work automatically. Note that passing the module object directly is not supported; instead, if you call it with `backend="jax"`, it should work.

jcmgray, Feb 13 '24 19:02

Hi, indeed, it works better that way. I should have been more attentive. Thanks a lot!

zazbone, Feb 14 '24 08:02

@jcmgray Could you make an associated PR if you have a minute?

dgasmith, May 05 '24 18:05

Should be closed by https://github.com/dgasmith/opt_einsum/pull/228.

dgasmith, May 14 '24 23:05