Error management for large wire numbers
I have noticed that the behaviour of qjit-ing circuits with large, unsupported wire numbers is inconsistent. I have observed three different error messages with different circuits. It would be nice to unify this behaviour at some point.
import pennylane as qml
from catalyst import qjit

@qjit
def test_large_wires():
    dev = qml.device("lightning.qubit", wires=1000)

    @qml.qnode(dev)
    def circuit():
        qml.PauliX(wires=999)
        return qml.probs()

    return circuit()

result = test_large_wires()
would raise
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/playground.py", line 7, in <module>
@qjit
^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 502, in qjit
return QJIT(fn, CompileOptions(**kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 565, in __init__
self.aot_compile()
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 623, in aot_compile
self.mlir_module = self.generate_ir()
^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 776, in generate_ir
mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 631, in lower_jaxpr_to_mlir
mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 74, in jaxpr_to_mlir
module, context = custom_lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 143, in custom_lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1452, in lower_jaxpr_to_fun
output_types = map(aval_to_ir_type, jaxpr.out_avals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 228, in aval_to_ir_type
return ir_type_handlers[type(aval)](aval)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 213, in _array_ir_types
return ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: get(): incompatible function arguments. The following argument types are supported:
1. get(shape: collections.abc.Sequence[int], element_type: jaxlib.mlir._mlir_libs._mlir.ir.Type, encoding: jaxlib.mlir._mlir_libs._mlir.ir.Attribute | None = None, loc: mlir.ir.Location | None = None) -> jaxlib.mlir._mlir_libs._mlir.ir.RankedTensorType
Invoked with types: tuple, jaxlib.mlir._mlir_libs._mlir.ir.F64Type
@qjit
def test_large_wires():
    dev = qml.device("lightning.qubit")

    @qml.qnode(dev)
    def circuit():
        qml.PauliX(wires=999)
        return qml.probs()

    return circuit()

result = test_large_wires()
raises
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/playground.py", line 16, in <module>
result = test_large_wires()
^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 608, in __call__
return self.run(args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 835, in run
results = self.compiled_function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/compiled_functions.py", line 345, in __call__
result = CompiledFunction._exec(
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/compiled_functions.py", line 166, in _exec
retval = wrapper.wrap(lib.function, args, result_desc, lib.mem_transfer, numpy_dict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: vector
@qjit
def test_large_wires():
    dev = qml.device("lightning.qubit")

    @qml.qnode(dev)
    def circuit():
        qml.CNOT(wires=[0, 999])
        return qml.probs()

    return circuit()

result = test_large_wires()
would halt indefinitely
If a smaller number (like 40) is used instead of 999, it halts for some time and then raises:
[1] 95672 killed python /Users/mehrdad.malek/tmp/playground.py
What you're running into is an inconsistency in the Linux operating system :(
Allocation requests are not immediately fulfilled by the OS; instead, they may be fulfilled upon access to that memory. This makes it somewhat unpredictable when and how running out of memory will manifest. In the worst case, the out-of-memory killer will come out and kill your process without any way to catch and rectify the situation. Usually, with extremely large requests, you will be lucky and get a C++ exception upfront.
EDIT: If you're on Mac, I'm not sure; from what you posted, it appears to be similarly unpredictable :/
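To illustrate the point above, here is a minimal sketch (assuming NumPy on Linux with the default overcommit policy; the exact behaviour is OS- and configuration-dependent):
import numpy as np

# A huge zero-initialised request can appear to succeed because pages are
# only committed on first write; the OOM killer may then terminate the
# process when the memory is actually touched, rather than at allocation time.
buf = np.zeros(1 << 40, dtype=np.uint8)  # ~1 TiB request; may not fail here
buf[::4096] = 1                          # touching the pages is what triggers OOM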
Your first example is weird though, the error doesn't look memory-related at all?
dev = qml.device("lightning.qubit", wires=1000)
This line should fail during tracing, as Lightning tries to instantiate a statevector of size 2^1000; why it doesn't raise some kind of allocation error is certainly strange.
EDIT: I bet the problem is the use of qml.probs; its output shape would be astronomically large, and likely fails even as a type specification (a 1000-bit integer type is not available in MLIR).
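As a rough back-of-the-envelope check in plain Python (just to illustrate the scale): qml.probs over n wires has shape (2**n,), and for n = 1000 that extent no longer fits in a signed 64-bit integer, which is how MLIR stores tensor dimensions.
n = 1000
dim = 2 ** n                 # extent of the probs output over 1000 wires
print(dim.bit_length())      # 1001 -- far more bits than a 64-bit extent can hold
print(dim > 2**63 - 1)       # True: not representable as a signed 64-bit dimension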
Scenario number 2 is very likely a C++ exception happening in the runtime or device, since it does propagate to Python. The only shame is that it doesn't say much except vector, but I'm guessing this means a C++ vector failed to instantiate given the size of the request. This is something we could actually catch (in the device; Catalyst is not supposed to know anything about device internals), but it usually only happens with requests that are outrageously large.
Yes, the first was my main problem as well. I agree that it is unpredictable and very hard to know when during execution memory runs out, but we should be able to raise a nice error if we predict a memory issue during tracing. In the case of the first example, the failure happens within lower_jaxpr_to_fun, but it seems like the jaxpr has no idea of the limits of f64. Even if I rerun the example with 1500 qubits, it prints out 2^1500 in the jaxpr, which is beyond the limit of f64:
{ lambda ; . let
a:f64[35074662110434038747627587960280857993524015880330828824075798024790963850563322203657080886584969261653150406795437517399294548941469959754171038918004700847889956485329097264486802711583462946536682184340138629451355458264946342525383619389314960644665052551751442335509249173361130355796109709885580674313954210217657847432626760733004753275317192133674703563372783297041993227052663333668509952000175053355529058880434182538386715523683713208549376] = quantum_kernel[
call_jaxpr={ lambda ; . let
device_init[
auto_qubit_management=False
rtd_kwargs={'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
rtd_lib=/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib
rtd_name=LightningSimulator
] 0
b:AbstractQreg() = qalloc 1500
c:AbstractQbit() = qextract b 999
d:AbstractQbit() = qinst[
adjoint=False
ctrl_len=0
op=PauliX
params_len=0
qubits_len=1
] c
e:AbstractQreg() = qinsert b 999 d
f:AbstractObs(num_qubits=None,qreg=AbstractQreg(),primitive=compbasis) = compbasis[
qreg_available=True
] e
g:f64[35074662110434038747627587960280857993524015880330828824075798024790963850563322203657080886584969261653150406795437517399294548941469959754171038918004700847889956485329097264486802711583462946536682184340138629451355458264946342525383619389314960644665052551751442335509249173361130355796109709885580674313954210217657847432626760733004753275317192133674703563372783297041993227052663333668509952000175053355529058880434182538386715523683713208549376] = probs[
static_shape=(35074662110434038747627587960280857993524015880330828824075798024790963850563322203657080886584969261653150406795437517399294548941469959754171038918004700847889956485329097264486802711583462946536682184340138629451355458264946342525383619389314960644665052551751442335509249173361130355796109709885580674313954210217657847432626760733004753275317192133674703563372783297041993227052663333668509952000175053355529058880434182538386715523683713208549376,)
] f
qdealloc e
device_release
in (g,) }
pipeline=()
qnode=<QNode: device='<lightning.qubit device (wires=1500) at 0x138c83170>', interface='auto', diff_method='best'>
]
in (a,) }
I think this eventually causes the lowering to fail.
Scenario 3 is the unpredictable memory killer: for instance, the program might slow down initially as it attempts to use a page file, but ultimately the process just gets killed at some point. This is true for Linux, and apparently for Mac as well:
Mac OS X lets you allocate "a lot", though the system may flail frantically or be just fine, depending on how that allocated memory is used by any subsequent code. ... OpenBSD, by contrast, fails the malloc if you ask for more than is available.
Seems like a different OS would be required if we want better control of this^^
but we should be able to raise a nice error if we predict a memory issue during tracing
Happy to see some additional checks on our end regarding size in our primitives, but we can't check all type conversions since they mostly happen in JAX.
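For illustration only, a tracing-time guard in one of our primitives could look roughly like the sketch below; the function name, threshold, and error message are hypothetical, not existing Catalyst API, and it would only cover sizes we compute ourselves, not conversions happening inside JAX:
# Hypothetical sketch of a size check in a Catalyst primitive (not real API).
MAX_EXTENT = 2**63 - 1  # largest extent representable as a signed 64-bit integer

def check_result_shape(shape, num_wires):
    """Raise a friendly error for result shapes the MLIR lowering cannot represent."""
    for extent in shape:
        if extent > MAX_EXTENT:
            raise ValueError(
                f"A measurement over {num_wires} wires produces a result too "
                "large to represent (its extent exceeds 64 bits); consider "
                "measuring a subset of the wires."
            )

# e.g. probs() over n wires produces a (2**n,)-shaped array
check_result_shape((2**1000,), 1000)  # raises a friendly ValueError instead of a cryptic TypeError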