
Failed to run with JAX 0.4.18

Open · yuehaowang opened this issue · 1 comment

I was running RawNeRF with the latest JAX (0.4.18) but encountered the error message below after training for ~300 iterations:

INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: Failed to update gpu graph: Graph update result=kNodeTypeChanged: Failed to update CUDA graph: CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE: the graph update was not performed because it included changes which violated constraints specific to instantiated graph update; current profiling annotation: XlaModule:#hlo_module=pmap_train_step,program_id=237#.

After downgrading JAX from 0.4.18 to 0.4.16, the error was gone.

I was using CUDA 11.8 and installed JAX via jax[cuda11_local]. The installed packages were jax 0.4.18 and jaxlib 0.4.18+cuda11.cudnn86. I'm not sure whether this is caused by a conflict with other packages.
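Until the root cause is known, one way to avoid losing ~300 iterations of training to this crash is to check the installed version before training starts. The sketch below is a hypothetical pre-flight guard, not part of multinerf: the names `KNOWN_BAD_VERSIONS`, `is_known_bad` and `check_jax_version` are illustrative, and the only version it flags is 0.4.18, the release reported failing in this thread (0.4.16 was reported to work).

```python
# Hypothetical pre-flight guard (not part of multinerf): refuse to start
# training on the JAX release reported to fail in this issue.
import importlib.metadata

# Release reported in this thread to trigger the CUDA graph update failure.
KNOWN_BAD_VERSIONS = {"0.4.18"}


def is_known_bad(version: str) -> bool:
    """Return True if `version` matches a release reported as failing.

    Local build suffixes such as "+cuda11.cudnn86" (as on jaxlib) are
    stripped before comparing.
    """
    base = version.split("+", 1)[0]
    return base in KNOWN_BAD_VERSIONS


def check_jax_version() -> None:
    """Raise RuntimeError if the installed jax or jaxlib is a known-bad release."""
    for package in ("jax", "jaxlib"):
        try:
            installed = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            continue  # package not installed; nothing to check
        if is_known_bad(installed):
            raise RuntimeError(
                f"{package} {installed} was reported to fail during RawNeRF "
                "training; try pinning jax and jaxlib to 0.4.16"
            )
```

Calling `check_jax_version()` at the top of the training script turns the mid-training XLA failure into an immediate, explicit error. Matching on exact version strings is deliberately conservative: it only flags what was actually reported here, rather than guessing which later releases are affected.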

yuehaowang · Oct 08 '23 16:10

I encountered this error:

ValueError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: CaptureGpuGraph failed (the requested functionality is not supported; current tracing scope: custom-call.119): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current profiling annotation: XlaModule:#hlo_module=jit_body_fn,program_id=214#.

It was also solved by downgrading to 0.4.16 (thanks @yuehaowang). I'm copy-pasting it here so it's google-able.

deoxyribose · Oct 17 '23 12:10