Find a better way to link BLAS routines from jaxlib
Currently, Catalyst links against the BLAS library shipped with SciPy rather than jaxlib, because jaxlib does not ship its BLAS library as a `.so` shared object file.
As a result, most numerical functions from jaxlib, such as those in `jax.scipy.linalg`, reference undefined symbols when they are used inside a qjit-compiled function. This is currently being fixed by manually adding the missing routines to `frontend/catalyst/utils/libcustom_calls.cpp` on a case-by-case basis, which does not scale. For example, `jax.scipy.linalg.expm` works:
```python
import jax.numpy as jnp
import jax.scipy.linalg
import pennylane as qml

@qml.qjit
def func(x):
    res = jax.scipy.linalg.expm(x)
    return res

y = jnp.array([[1, 0], [0, 1]])
x = func(y)
print(x)
# [[2.71828183 0.        ]
#  [0.         2.71828183]]
```
but `jax.scipy.linalg.sqrtm` fails when the compiled function is loaded:
```python
import jax.numpy as jnp
import jax.scipy.linalg
import pennylane as qml

@qml.qjit
def func(x):
    res = jax.scipy.linalg.sqrtm(x)
    return res

y = jnp.array([[1, 0], [0, 1]])
x = func(y)
```
```
Traceback (most recent call last):
  File "/home/paul.wang/small_playgrounds_dump/expmfix.py", line 56, in <module>
    x = func(y)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 110, in __call__
    requires_promotion = self.jit_compile(args)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 171, in jit_compile
    self.compiled_function, self.qir = self.compile()
  File "/home/paul.wang/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 278, in compile
    compiled_fn = CompiledFunction(shared_object, func_name, restype, self.compile_options)
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 132, in __init__
    self.shared_object = SharedObjectManager(shared_object_file, func_name)
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 61, in __init__
    self.open()
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 65, in open
    self.shared_object = ctypes.CDLL(self.shared_object_file)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /tmp/funcwym85xsl/func.so: undefined symbol: lapack_zgees
```
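This happens because the routines required by `expm` were added in #752, but those required by `sqrtm` are still missing. Here the unresolved symbol is `lapack_zgees`, the wrapper for LAPACK's `zgees` (complex Schur decomposition), which `sqrtm` relies on internally.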
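To find out which wrapper routines a compiled function still needs (that is, what would have to be added to `libcustom_calls.cpp` next), one could list the undefined dynamic symbols of the generated `.so` with GNU `nm`, which reads the file statically instead of loading it. The helpers below are a hypothetical sketch, not part of Catalyst; they assume binutils `nm` is installed and that the wrappers follow the `lapack_` prefix seen in the traceback.

```python
import subprocess


def undefined_lapack_symbols(nm_output: str) -> list[str]:
    """Return the sorted, de-duplicated `lapack_*` names that nm marks undefined ('U')."""
    names = set()
    for line in nm_output.splitlines():
        parts = line.split()
        # With --undefined-only there is no address column, so lines look like "U name".
        if len(parts) >= 2 and parts[-2] == "U":
            # Drop a symbol version suffix such as "memcpy@GLIBC_2.14".
            name = parts[-1].split("@", 1)[0]
            # Assumption: the missing wrappers all use the "lapack_" prefix.
            if name.startswith("lapack_"):
                names.add(name)
    return sorted(names)


def nm_undefined(shared_object: str) -> str:
    """Run GNU nm on a shared object and return its undefined dynamic symbols as text."""
    result = subprocess.run(
        ["nm", "-D", "--undefined-only", shared_object],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout
```

For instance, given the made-up `nm` output lines `U lapack_zgees` and `U memcpy@GLIBC_2.14`, `undefined_lapack_symbols` returns `['lapack_zgees']`. Keeping the parser separate from the `subprocess` call makes it easy to test without a compiled module at hand.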
See #752 for details.
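Going the other way, before pointing Catalyst at a different provider library (for example, one shipped with jaxlib or SciPy), one could check whether that library actually exports the required names. The hypothetical `missing_symbols` helper below relies on ctypes raising `AttributeError` when a symbol lookup fails. It has to be given the candidate provider library, not the compiled `func.so`, because the traceback shows that loading the latter already fails with `OSError`.

```python
import ctypes


def missing_symbols(library_path, symbols):
    """Return the names in `symbols` that the library at `library_path` does not export.

    On Linux, passing None probes the symbols already loaded into the current process.
    """
    lib = ctypes.CDLL(library_path)
    # Attribute access performs a symbol lookup; ctypes raises AttributeError
    # when the name cannot be found, so hasattr() doubles as an existence check.
    return [name for name in symbols if not hasattr(lib, name)]


if __name__ == "__main__":
    # strlen comes from libc; lapack_zgees is not expected in a plain Python process.
    print(missing_symbols(None, ["strlen", "lapack_zgees"]))
```

Probing through ctypes uses the same dynamic loader that later resolves `func.so`, so a name reported as present here should also be visible at load time, provided the library is loaded in a way that makes its symbols globally available.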