Compiler crash while computing circuit gradient (qjit and vmap) with lightning.qubit and lightning.kokkos
A forum user is reporting a crash when using jax.grad on top of a qjit program with vmap.
While there are known Enzyme issues with complex programs such as those using vmap, this one crashes at the MLIR layer, so it is likely a different problem.
The report is reproduced below:
I have an issue computing the ‘jit-ed’ and ‘vmap-ed’ gradient of a circuit. It seems that the compiler is crashing for some reason I do not understand. A minimal example and some further information are written below. Does anybody have an idea how to resolve this issue?
import catalyst
import pennylane as qml
import jax.numpy as jnp
dev = qml.device("lightning.qubit", wires=2)
@qml.qnode(dev, interface='jax')
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))
batch_circuit = catalyst.vmap(circuit)
compiled_circuit = catalyst.qjit(batch_circuit)
grad_compiled = catalyst.grad(compiled_circuit) # note this dispatches to jax.grad, which inserts catalyst.jvp within the compiled program
batch_input = jnp.array([0.1, 0.2, 0.3])
print(compiled_circuit(batch_input))
print(grad_compiled(batch_input))
Traceback
2025-07-17 09:07:09.765846: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver’s CUDA version is 12.4 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
error code -6: catalyst: /__w/catalyst/catalyst/mlir/llvm-project/mlir/lib/Analysis/Liveness.cpp:45: {anonymous}::BlockInfoBuilder::BlockInfoBuilder(mlir::Block*)::<lambda(mlir::Value)>: Assertion `ownerBlock && "Use leaves the current parent region"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0. Program arguments: /opt/conda/bin/catalyst -o /tmp/deriv_vmap.circuitzredee5n/deriv_vmap.circuit.ll --module-name deriv_vmap.circuit --workspace /tmp/deriv_vmap.circuitzredee5n -verify-each=false --catalyst-pipeline EnforceRuntimeInvariantsPass(split-multiple-tapes;builtin.module(apply-transform-sequence);inline-nested-module),HLOLoweringPass(canonicalize;func.func(chlo-legalize-to-hlo);stablehlo-legalize-to-hlo;func.func(mhlo-legalize-control-flow);func.func(hlo-legalize-to-linalg);func.func(mhlo-legalize-to-std);func.func(hlo-legalize-sort);convert-to-signless;canonicalize;scatter-lowering;hlo-custom-call-lowering;cse;func.func(linalg-detensorize{aggressive-mode});detensorize-scf;canonicalize),QuantumCompilationPass(annotate-function;lower-mitigation;lower-gradients;adjoint-lowering),BufferizationPass(one-shot-bufferize{dialect-filter=memref};inline;gradient-preprocess;gradient-bufferize;scf-bufferize;convert-tensor-to-linalg;convert-elementwise-to-linalg;arith-bufferize;empty-tensor-to-alloc-tensor;func.func(bufferization-bufferize);func.func(tensor-bufferize);catalyst-bufferize;func.func(linalg-bufferize);func.func(tensor-bufferize);quantum-bufferize;func-bufferize;func.func(finalizing-bufferize);canonicalize;gradient-postprocess;func.func(buffer-hoisting);func.func(buffer-loop-hoisting);func.func(buffer-deallocation);convert-arraylist-to-memref;convert-bufferization-to-memref;canonicalize;cp-global-memref),MLIRToLLVMDialect(expand-realloc;convert-gradient-to-llvm;memrefcpy-to-linalgcpy;func.func(convert-linalg-to-loops);convert-scf-to-cf;expand-strided-metadata;lower-affine;arith-expand;convert-complex-to-standard;convert-complex-to-llvm;convert-math-to-llvm;convert-math-to-libm;convert-arith-to-llvm;memref-to-llvm-tbaa;finalize-memref-to-llvm{use-generic-functions};convert-index-to-llvm;convert-catalyst-to-llvm;convert-quantum-to-llvm;emit-catalyst-py-interface;canonicalize;reconcile-unrealized-casts;gep-inbounds;register-inactive-callback), /tmp/deriv_vmap.circuitzredee5n/tmpk2xja2xe.mlir
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0 catalyst 0x00000000098d62ab llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 43
1 catalyst 0x00000000098d36cb llvm::sys::RunSignalHandlers() + 43
2 catalyst 0x00000000098d37f5
3 libc.so.6 0x00007fa8625b1520
4 libc.so.6 0x00007fa8626059fc pthread_kill + 300
5 libc.so.6 0x00007fa8625b1476 raise + 22
6 libc.so.6 0x00007fa8625977f3 abort + 211
7 libc.so.6 0x00007fa86259771b
8 libc.so.6 0x00007fa8625a8e96
9 catalyst 0x0000000009562997
10 catalyst 0x0000000009561f0a
11 catalyst 0x000000000956d2ee mlir::Liveness::build() + 174
12 catalyst 0x000000000513f0b9
13 catalyst 0x000000000952c5ee mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 1038
14 catalyst 0x000000000952caa8 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
15 catalyst 0x000000000952d25d mlir::detail::OpToOpPassAdaptor::runOnOperationImpl(bool) + 461
16 catalyst 0x000000000952c42a mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 586
17 catalyst 0x000000000952caa8 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
18 catalyst 0x000000000952da05 mlir::PassManager::run(mlir::Operation*) + 1205
19 catalyst 0x00000000036012c3 runPipeline(mlir::PassManager&, catalyst::driver::CompilerOptions const&, catalyst::driver::CompilerOutput&, catalyst::driver::Pipeline&, bool, mlir::ModuleOp) + 259
20 catalyst 0x000000000360172d runLowering(catalyst::driver::CompilerOptions const&, mlir::MLIRContext*, mlir::ModuleOp, catalyst::driver::CompilerOutput&, mlir::TimingScope&) + 445
21 catalyst 0x0000000003603fe4 QuantumDriverMain(catalyst::driver::CompilerOptions const&, catalyst::driver::CompilerOutput&, mlir::DialectRegistry&) + 6676
22 catalyst 0x0000000003608a0b QuantumDriverMainFromCL(int, char**) + 10395
23 libc.so.6 0x00007fa862598d90
24 libc.so.6 0x00007fa862598e40 __libc_start_main + 128
25 catalyst 0x00000000035e012e _start + 46
Simplifying the program a little further, it really appears to be another issue with grad + vmap (the jax integration is not involved):
import jax.numpy as jnp
import pennylane as qml
from catalyst import jacobian, qjit, vmap

@qjit
@jacobian
@vmap
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

print(circuit(jnp.array([0.1, 0.2, 0.3])))
The error also appears if one uses grad instead of jacobian:
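import jax.numpy as jnp
import pennylane as qml
from catalyst import grad, qjit, vmap

@vmap
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

@qjit
@grad
def loss(x):
    # Reduce the batched expectation values to a scalar so grad applies.
    return jnp.sum(circuit(x))

print(loss(jnp.array([0.1, 0.2, 0.3])))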
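For anyone checking a fix, here is a rough reference sketch of the values these snippets should produce once compilation succeeds. It uses plain Python only (no Catalyst, PennyLane, or JAX) and assumes the ideal, noiseless result for this circuit: RX(x) applied to |0> gives an expectation value of Z equal to cos(x). All helper names are hypothetical and exist only for this illustration.

```python
import math

def expval_z_after_rx(x):
    # Assumption: ideal simulation, <Z> = cos^2(x/2) - sin^2(x/2) = cos(x).
    return math.cos(x)

def batched_expval(xs):
    # What the vmapped circuit should return: one expectation value per input.
    return [expval_z_after_rx(x) for x in xs]

def expected_jacobian(xs):
    # Each output depends only on its own input, so the Jacobian is diagonal
    # with entries d cos(x_i) / d x_i = -sin(x_i).
    n = len(xs)
    return [[-math.sin(xs[i]) if i == j else 0.0 for j in range(n)]
            for i in range(n)]

def expected_grad_of_sum(xs):
    # Gradient of sum_i cos(x_i) with respect to each x_i.
    return [-math.sin(x) for x in xs]

def finite_difference_grad(f, xs, eps=1e-6):
    # Central differences, used only as an independent sanity check.
    grads = []
    for i in range(len(xs)):
        up = list(xs)
        up[i] += eps
        down = list(xs)
        down[i] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads

batch_input = [0.1, 0.2, 0.3]
print(batched_expval(batch_input))        # reference for the vmapped circuit
print(expected_jacobian(batch_input))     # reference for the jacobian variant
print(expected_grad_of_sum(batch_input))  # reference for the grad-of-sum variant
```

The diagonal Jacobian corresponds to the `jacobian` variant and the vector of `-sin(x_i)` to the `grad` of the summed loss; `finite_difference_grad` can be applied to `lambda xs: sum(batched_expval(xs))` as a cross-check.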
Actually, it's quite possible that this issue is the same as the well-known issue with differentiating vmap / for loops. This old PR fixed a related error in the compiler, but we didn't merge it because the underlying Enzyme issues still led to incorrect results (the thinking was that an error is better than wrong results):
https://github.com/PennyLaneAI/catalyst/pull/332