Find a better way to link BLAS routines from jaxlib

Open · paul0403 opened this issue · 2 comments

Currently, Catalyst links to the BLAS library shipped with SciPy rather than the one in jaxlib, because jaxlib does not ship its BLAS library as a `.so` shared object file.
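For reference, SciPy makes its LAPACK routines available to compiled code through the Cython module `scipy.linalg.cython_lapack`, whose `__pyx_capi__` dictionary maps routine names to PyCapsules holding C function pointers. The sketch below is only a diagnostic aid, not Catalyst's actual linking code: it lists which routines that table contains. The helper name `scipy_lapack_routines` is hypothetical.

```python
# Diagnostic sketch (not Catalyst code): list the LAPACK routines that SciPy
# exports as C function pointers. Assumes scipy.linalg.cython_lapack exposes
# them through the Cython-generated __pyx_capi__ dict (name -> PyCapsule).
from scipy.linalg import cython_lapack


def scipy_lapack_routines():
    """Return the sorted names of LAPACK routines exported by SciPy (hypothetical helper)."""
    return sorted(cython_lapack.__pyx_capi__.keys())


if __name__ == "__main__":
    routines = scipy_lapack_routines()
    print(len(routines), "routines exported")
    # Schur-decomposition drivers in each precision, e.g. zgees.
    print([name for name in routines if name.endswith("gees")])
```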

Because jaxlib's own routines are not linked, most numerical functions from jaxlib (for example, those in jax.scipy.linalg) call undefined symbols once compiled. The current workaround is to add the missing routines manually to frontend/catalyst/utils/libcustom_calls.cpp case by case, which does not scale. For example,

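```python
import jax.numpy as jnp
import jax.scipy.linalg
import pennylane as qml

@qml.qjit
def func(x):
    res = jax.scipy.linalg.expm(x)
    return res

y = jnp.array([[1, 0], [0, 1]])
x = func(y)
print(x)
```

Output:

```
[[2.71828183 0.        ]
 [0.         2.71828183]]
```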

but

```python
import jax.numpy as jnp
import jax.scipy.linalg
import pennylane as qml

@qml.qjit
def func(x):
    res = jax.scipy.linalg.sqrtm(x)
    return res

y = jnp.array([[1, 0], [0, 1]])
x = func(y)
```

```
Traceback (most recent call last):
  File "/home/paul.wang/small_playgrounds_dump/expmfix.py", line 56, in <module>
    x = func(y)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 110, in __call__
    requires_promotion = self.jit_compile(args)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 171, in jit_compile
    self.compiled_function, self.qir = self.compile()
  File "/home/paul.wang/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 278, in compile
    compiled_fn = CompiledFunction(shared_object, func_name, restype, self.compile_options)
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 132, in __init__
    self.shared_object = SharedObjectManager(shared_object_file, func_name)
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 61, in __init__
    self.open()
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 65, in open
    self.shared_object = ctypes.CDLL(self.shared_object_file)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /tmp/funcwym85xsl/func.so: undefined symbol: lapack_zgees
```

This happens because the routines required by expm were added in #752, but those required by sqrtm (such as lapack_zgees) are still missing.
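When a compiled function fails like this, a small helper can map the missing symbol back to the LAPACK routine it wraps and check whether SciPy exports that routine, i.e. whether it is a candidate for another case-by-case addition. This is a sketch under two assumptions: that the `lapack_<routine>` naming seen in the traceback holds generally, and that SciPy's `cython_lapack.__pyx_capi__` table is the relevant source of routines. `diagnose` is a hypothetical helper, not part of Catalyst.

```python
import re

from scipy.linalg import cython_lapack

# Assumption: missing symbols follow the "lapack_<routine>" pattern seen in
# errors such as "undefined symbol: lapack_zgees".
_UNDEFINED = re.compile(r"undefined symbol: (lapack_(\w+))")


def diagnose(error_message):
    """Return (symbol, routine, exported_by_scipy), or None if no symbol is found."""
    match = _UNDEFINED.search(error_message)
    if match is None:
        return None
    symbol, routine = match.group(1), match.group(2)
    return symbol, routine, routine in cython_lapack.__pyx_capi__


msg = "OSError: /tmp/funcwym85xsl/func.so: undefined symbol: lapack_zgees"
print(diagnose(msg))
```

Run on the sqrtm error, the helper identifies the missing routine as `zgees`; it only reports what is absent and does not fix the linking itself.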

See details in #752

paul0403, May 17 '24 16:05