Error management for large wire numbers

Open mehrdad2m opened this issue 4 months ago • 6 comments

I have noticed that the behaviour of qjit-compiling circuits with large, unsupported wire numbers is inconsistent. I have observed three different outcomes with three different circuits. It would be nice to unify this behaviour at some point.

@qjit
def test_large_wires():
    dev = qml.device("lightning.qubit", wires=1000)
    @qml.qnode(dev)
    def circuit():
        qml.PauliX(wires=999)
        return qml.probs()
    return circuit()
result = test_large_wires()

would raise

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/playground.py", line 7, in <module>
    @qjit
     ^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 502, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 565, in __init__
    self.aot_compile()
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 623, in aot_compile
    self.mlir_module = self.generate_ir()
                       ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 776, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 631, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 74, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 143, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1452, in lower_jaxpr_to_fun
    output_types = map(aval_to_ir_type, jaxpr.out_avals)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 228, in aval_to_ir_type
    return ir_type_handlers[type(aval)](aval)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 213, in _array_ir_types
    return ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype))  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: get(): incompatible function arguments. The following argument types are supported:
    1. get(shape: collections.abc.Sequence[int], element_type: jaxlib.mlir._mlir_libs._mlir.ir.Type, encoding: jaxlib.mlir._mlir_libs._mlir.ir.Attribute | None = None, loc: mlir.ir.Location | None = None) -> jaxlib.mlir._mlir_libs._mlir.ir.RankedTensorType

Invoked with types: tuple, jaxlib.mlir._mlir_libs._mlir.ir.F64Type

@qjit
def test_large_wires():
    dev = qml.device("lightning.qubit")
    @qml.qnode(dev)
    def circuit():
        qml.PauliX(wires=999)
        return qml.probs()
    return circuit()
result = test_large_wires()

raises

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/playground.py", line 16, in <module>
    result = test_large_wires()
             ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 608, in __call__
    return self.run(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 835, in run
    results = self.compiled_function(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/compiled_functions.py", line 345, in __call__
    result = CompiledFunction._exec(
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/compiled_functions.py", line 166, in _exec
    retval = wrapper.wrap(lib.function, args, result_desc, lib.mem_transfer, numpy_dict)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: vector

@qjit
def test_large_wires():
    dev = qml.device("lightning.qubit")
    @qml.qnode(dev)
    def circuit():
        qml.CNOT(wires=[0, 999])
        return qml.probs()
    return circuit()
result = test_large_wires()

would hang indefinitely.

If a smaller number (like 40) is used instead of 999, it hangs for some time and then the process is killed:

[1]    95672 killed     python /Users/mehrdad.malek/tmp/playground.py

mehrdad2m avatar Jul 07 '25 20:07 mehrdad2m

What you're running into is an inconsistency in the Linux operating system :(

Allocation requests are not immediately fulfilled by the OS, instead they may be filled upon access to that memory. This makes it somewhat unpredictable when and how running out of memory will manifest. In the worst case, the out of memory killer will come out and kill your process without any way to catch and rectify the situation. Usually with extremely large requests you will be lucky and get a C++ exception upfront.

EDIT: If you're on macOS I'm not sure, but from what you posted it appears to be similarly unpredictable :/

dime10 avatar Jul 10 '25 20:07 dime10

Your first example is weird though, the error doesn't look memory related at all?

dev = qml.device("lightning.qubit", wires=1000)

This line should fail during tracing, since Lightning tries to instantiate a statevector of size 2^1000. Why it doesn't raise some kind of allocation error is certainly strange.

EDIT: I bet the problem is the use of qml.probs. Its output shape would be astronomically large, and it likely fails even as a type specification (a 1000-bit integer type is not available in MLIR).
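
To make the scale concrete, here is a minimal sketch in plain Python and NumPy (not Catalyst code) of why a probs result over 1000 wires cannot even be written down as a tensor type. It assumes the static shape ends up stored as signed 64-bit integers, which is how MLIR ranked tensor types hold their dimensions. The helper names probs_length and fits_in_int64 are made up for illustration.

```python
import numpy as np

# Largest value a signed 64-bit shape dimension can hold: 2**63 - 1.
INT64_MAX = int(np.iinfo(np.int64).max)


def probs_length(num_wires: int) -> int:
    """Number of entries in a probs() result over all wires: 2**num_wires."""
    return 2 ** num_wires


def fits_in_int64(value: int) -> bool:
    """True if the value can be stored as a signed 64-bit integer."""
    return value <= INT64_MAX


for n in (40, 62, 63, 1000):
    length = probs_length(n)
    # Python integers are arbitrary precision, so the jaxpr can print 2**1000
    # without complaint; the trouble only starts once a fixed-width type is needed.
    print(f"wires={n}: bits needed={length.bit_length()}, fits in int64={fits_in_int64(length)}")
```

Anything from 63 wires upward already overflows a 64-bit dimension, long before 1000 is reached.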

dime10 avatar Jul 10 '25 20:07 dime10

Scenario number 2 is very likely a C++ exception happening in the runtime or device, since it does propagate to Python. The only shame is that it doesn't say much beyond "vector", but I'm guessing this means a C++ vector failed to instantiate given the size of the request. This is something we could actually catch (in the device, since Catalyst is not supposed to know anything about device internals), but it usually only happens with requests that are outrageously large.
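
As a rough sketch of what catching this could look like (shown at the Python boundary purely for illustration, even though the device is the better place for it), the wrapper below turns the bare "vector" message into something actionable. Both fake_compiled_function and run_compiled are hypothetical stand-ins, not Catalyst or Lightning APIs, and the idea that the ValueError comes from a failed C++ vector allocation is the guess made above rather than a confirmed fact.

```python
def fake_compiled_function(num_wires: int) -> list:
    """Hypothetical stand-in that fails the same way scenario 2 does."""
    if num_wires > 30:
        # Mimics the opaque message seen in the traceback.
        raise ValueError("vector")
    # Probabilities of |0...0> for a small register.
    return [1.0] + [0.0] * (2 ** num_wires - 1)


def run_compiled(fn, num_wires: int):
    """Call fn and re-raise the opaque 'vector' error with some context."""
    try:
        return fn(num_wires)
    except ValueError as exc:
        if str(exc) == "vector":
            raise RuntimeError(
                f"Could not allocate device state for {num_wires} wires; "
                "the request is likely far too large for this machine."
            ) from exc
        raise  # unrelated ValueErrors pass through unchanged


print(run_compiled(fake_compiled_function, 2))
try:
    run_compiled(fake_compiled_function, 999)
except RuntimeError as err:
    print(err)
```

Chaining with "from exc" keeps the original exception in the traceback, so nothing is lost for debugging.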

dime10 avatar Jul 10 '25 20:07 dime10

Yes, the first was my main problem as well. I agree that it is unpredictable and very hard to know when memory runs out during execution, but we should be able to raise a nice error if we predict a memory issue during tracing. In the first example the failure happens within lower_jaxpr_to_fun, but it seems the jaxpr has no notion of the limit of f64. Even if I rerun the example with 1500 qubits, it prints 2^1500 in the jaxpr, which is beyond the limit of f64:

{ lambda ; . let
    a:f64[35074662110434038747627587960280857993524015880330828824075798024790963850563322203657080886584969261653150406795437517399294548941469959754171038918004700847889956485329097264486802711583462946536682184340138629451355458264946342525383619389314960644665052551751442335509249173361130355796109709885580674313954210217657847432626760733004753275317192133674703563372783297041993227052663333668509952000175053355529058880434182538386715523683713208549376] = quantum_kernel[
      call_jaxpr={ lambda ; . let
          device_init[
            auto_qubit_management=False
            rtd_kwargs={'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
            rtd_lib=/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib
            rtd_name=LightningSimulator
          ] 0
          b:AbstractQreg() = qalloc 1500
          c:AbstractQbit() = qextract b 999
          d:AbstractQbit() = qinst[
            adjoint=False
            ctrl_len=0
            op=PauliX
            params_len=0
            qubits_len=1
          ] c
          e:AbstractQreg() = qinsert b 999 d
          f:AbstractObs(num_qubits=None,qreg=AbstractQreg(),primitive=compbasis) = compbasis[
            qreg_available=True
          ] e
          g:f64[35074662110434038747627587960280857993524015880330828824075798024790963850563322203657080886584969261653150406795437517399294548941469959754171038918004700847889956485329097264486802711583462946536682184340138629451355458264946342525383619389314960644665052551751442335509249173361130355796109709885580674313954210217657847432626760733004753275317192133674703563372783297041993227052663333668509952000175053355529058880434182538386715523683713208549376] = probs[
            static_shape=(35074662110434038747627587960280857993524015880330828824075798024790963850563322203657080886584969261653150406795437517399294548941469959754171038918004700847889956485329097264486802711583462946536682184340138629451355458264946342525383619389314960644665052551751442335509249173361130355796109709885580674313954210217657847432626760733004753275317192133674703563372783297041993227052663333668509952000175053355529058880434182538386715523683713208549376,)
          ] f
          qdealloc e
          device_release 
        in (g,) }
      pipeline=()
      qnode=<QNode: device='<lightning.qubit device (wires=1500) at 0x138c83170>', interface='auto', diff_method='best'>
    ] 
  in (a,) }

I think this eventually causes the lowering to fail.

mehrdad2m avatar Jul 10 '25 20:07 mehrdad2m

Scenario 3 is the out-of-memory killer, which is unpredictable. For instance, the program might slow down initially as it attempts to use a page file, but ultimately the process just gets killed at some point. This is true for Linux, and apparently for macOS as well:

Mac OS X lets you allocate "a lot", though the system may flail frantically or be just fine, depending on how that allocated memory is used by any subsequent code. ... OpenBSD, by contrast, fails the malloc if you ask for more than is available.

Seems like a different OS would be required if we want better control of this^^
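
For a sense of why 40 wires ends with the process being killed, a back-of-the-envelope calculation helps. The sketch below assumes a dense statevector with one complex128 amplitude (16 bytes) per basis state, which I believe is the lightning.qubit default but is stated here as an assumption; statevector_bytes is a made-up helper name.

```python
BYTES_PER_AMPLITUDE = 16  # complex128: two 64-bit floats (assumed precision)


def statevector_bytes(num_wires: int) -> int:
    """Memory needed for a dense statevector over num_wires qubits."""
    return BYTES_PER_AMPLITUDE * 2 ** num_wires


for n in (20, 30, 40):
    gib = statevector_bytes(n) / 2 ** 30
    print(f"{n} wires: {gib:,.3f} GiB")
```

At 40 wires the request is 16 TiB, so the allocation may be granted lazily but can never be backed by real memory once it is touched.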

dime10 avatar Jul 10 '25 20:07 dime10

but we should be able to raise a nice error if we predict a memory issue during tracing

Happy to see some additional checks on our end regarding size in our primitives, but we can't check all type conversions since they mostly happen in JAX.
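
A size check of that kind might look roughly like the sketch below. Everything in it is hypothetical: check_probs_size, WireCountError and the default budget are invented for illustration and are not part of Catalyst. The point is only that the result size of probs is known from the wire count before any lowering happens, so an oversized request can be rejected with a readable message.

```python
INT64_MAX = 2 ** 63 - 1   # largest static dimension a 64-bit shape can hold
BYTES_PER_F64 = 8


class WireCountError(ValueError):
    """Hypothetical error raised when a circuit is too large to trace."""


def check_probs_size(num_wires: int, max_bytes: int = 8 * 2 ** 30) -> int:
    """Return the probs() length, or raise if it cannot be represented or stored.

    max_bytes is an arbitrary illustrative budget (8 GiB), not a real setting.
    """
    length = 2 ** num_wires
    if length > INT64_MAX:
        raise WireCountError(
            f"probs over {num_wires} wires needs a dimension of 2**{num_wires}, "
            "which does not fit in an int64 tensor shape."
        )
    nbytes = BYTES_PER_F64 * length
    if nbytes > max_bytes:
        raise WireCountError(
            f"probs over {num_wires} wires would return {nbytes} bytes, "
            f"above the {max_bytes} byte budget."
        )
    return length


print(check_probs_size(10))
for wires in (40, 1500):
    try:
        check_probs_size(wires)
    except WireCountError as err:
        print(err)
```

This only covers shapes produced by our own primitives; conversions that happen inside JAX would still be out of reach, as noted above.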

dime10 avatar Jul 10 '25 20:07 dime10