jax
XLA Check failed: common_utilization <= producer_output_utilization
Description
When I run a longer algorithm, execution fails with the following error, which gives no indication of where in my code the problem occurs:
F external/xla/xla/service/gpu/gpu_performance_model.cc:119] Check failed: common_utilization <= producer_output_utilization (500.867 vs. 500.867)
The error only occurs for some configurations of the algorithm; slightly changing parameters such as the number of iterations or the batch size sometimes makes it go away.
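The log prints both operands as equal (500.867 vs. 500.867), yet the `<=` check still fails. That is consistent with the two values differing only beyond the six significant digits a C++ stream prints by default, for example because they were computed along different arithmetic paths. The Python sketch below only illustrates this floating-point effect; it does not reproduce XLA's actual computation, and `le_with_tolerance` is a hypothetical helper showing what a tolerance-aware comparison could look like.

```python
import math

# The same mathematical value, obtained two different ways.
computed = 0.1 + 0.2   # 0.30000000000000004 in IEEE double precision
direct = 0.3           # 0.29999999999999998889776975...

# A strict check in the spirit of CHECK_LE(computed, direct) fails...
strict_ok = computed <= direct
# ...even though both values print identically at 6 significant digits,
# which is the default precision of C++ output streams.
printed = (format(computed, ".6g"), format(direct, ".6g"))


def le_with_tolerance(a, b, rel_tol=1e-9):
    """Hypothetical tolerance-aware version of `a <= b`."""
    return a <= b or math.isclose(a, b, rel_tol=rel_tol)


print(strict_ok)                              # False
print(printed)                                # ('0.3', '0.3')
print(le_with_tolerance(computed, direct))    # True
```

This also fits the observation that small parameter changes make the error appear or disappear: they change the exact floating-point values the performance model compares.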
What jax/jaxlib version are you using?
jax v0.4.9, jaxlib v0.4.9+cuda12.cudnn88
Which accelerator(s) are you using?
GPU
Additional system info
python 3.8
NVIDIA GPU info
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
I am running the code with parallel compilation disabled:
TF_USE_NVLINK_FOR_PARALLEL_COMPILATION=0
Well, that sounds like an XLA bug.
First, can you try with the latest jaxlib release (0.4.16)? The bug may already be fixed, so this is the first thing to try. You will need to update your Python version to 3.9 or newer to do this.
If that doesn't work, can you please provide instructions to reproduce? If that's hard, one option is to provide an HLO dump from XLA. You can get one by setting XLA_FLAGS=--xla_dump_to=/somewhere and JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0, running your script, and then zipping up the output in /somewhere and attaching it to this issue.
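As a minimal sketch of those steps, the helper below builds an environment with both variables set and runs a script under it. The function names `xla_dump_env` and `run_with_dump` are hypothetical, `/tmp/xladump` is just an example path, and the script is whatever reproduces the crash.

```python
import os
import shutil
import subprocess
import sys


def xla_dump_env(dump_dir, base_env=None):
    """Return a copy of the environment with XLA HLO dumping enabled.

    XLA_FLAGS holds space-separated flags, so any flags already set are
    kept and --xla_dump_to is appended.
    """
    env = dict(os.environ if base_env is None else base_env)
    existing = env.get("XLA_FLAGS", "").strip()
    env["XLA_FLAGS"] = f"{existing} --xla_dump_to={dump_dir}".strip()
    # Enable JAX's detailed compiler logging for modules of any size.
    env["JAX_COMPILER_DETAILED_LOGGING_MIN_OPS"] = "0"
    return env


def run_with_dump(script, dump_dir="/tmp/xladump"):
    """Run `script` with dumping enabled, then zip whatever was dumped.

    The child may abort on the failed CHECK; files written before the
    abort remain in dump_dir, so the archive is created regardless.
    """
    result = subprocess.run([sys.executable, script], env=xla_dump_env(dump_dir))
    archive = shutil.make_archive("xladump", "zip", root_dir=dump_dir)
    return result.returncode, archive
```

For example, `run_with_dump("tests/tests.py")` returns the child's exit code (on POSIX, a negative value such as -6 when the process is killed by SIGABRT) together with the path of `xladump.zip`, which can then be attached to the issue.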
Any updates? Can you share instructions to reproduce?
I upgraded the jax and jaxlib versions, but the error persists.
Unfortunately, I could not trace the error to a specific part of the code. However, I followed the steps you described and attached the dump.
Hi Peter Hawkins,
Do you have any updates on this issue?
Best Sebastien
I've run into the same error
Description
(jax) -bash-4.2$ JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0 XLA_FLAGS=--xla_dump_to=/tmp/xladump PYTHONPATH="./" python tests/tests.py
.F1127 20:16:06.642083 26806 gpu_performance_model.cc:358] Check failed: common_utilization <= producer_output_utilization (2.2 vs. 2.2)
*** Check failure stack trace: ***
@ 0x7f901feca23e absl::lts_20230802::log_internal::LogMessage::Flush()
@ 0x7f901feca349 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f90178ba3f4 xla::gpu::GpuPerformanceModel::ProducerInputAccessTime()
@ 0x7f9019f9f203 xla::gpu::GpuPerformanceModel::EstimateRunTimes()
@ 0x7f9019f8b521 xla::gpu::FusionInstructionMerger::ShouldFuse()
@ 0x7f9019f8cc76 xla::gpu::FusionInstructionMerger::Run()
@ 0x7f9019f8d4f4 xla::gpu::FusionMerger::Run()
@ 0x7f901da2daf5 xla::HloPassPipeline::RunPassesInternal<>()
@ 0x7f901da2e8e7 xla::HloPassPipeline::Run()
@ 0x7f9018e54c71 xla::HloPassInterface::Run()
@ 0x7f9018e6a5dd xla::gpu::GpuCompiler::OptimizeHloModule()
@ 0x7f9018e6ee61 xla::gpu::GpuCompiler::RunHloPasses()
@ 0x7f9018d8d3a9 xla::Service::BuildExecutable()
@ 0x7f9018b472ad xla::LocalService::CompileExecutables()
@ 0x7f9018b41f82 xla::LocalClient::Compile()
@ 0x7f9018b00d7c xla::PjRtStreamExecutorClient::Compile()
@ 0x7f9018adcc9f xla::StreamExecutorGpuClient::Compile()
@ 0x7f9018b1400a xla::PjRtStreamExecutorClient::Compile()
@ 0x7f9018a2e66f xla::ifrt::PjRtLoadedExecutable::Create()
@ 0x7f9018a24d14 xla::ifrt::PjRtCompiler::Compile()
@ 0x7f9017fa9a52 xla::PyClient::Compile()
@ 0x7f9017cdb0b3 pybind11::detail::argument_loader<>::call_impl<>()
@ 0x7f9017cdb560 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7f9017c90768 pybind11::cpp_function::dispatcher()
@ 0x525d17 cfunction_call
Aborted (core dumped)
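The stack trace shows the abort happens inside `GpuCompiler::OptimizeHloModule`, i.e. while XLA's HLO passes are running on one particular module. With `--xla_dump_to` set, one way to narrow down which jitted function triggers it is to look for a module that was dumped before optimization but never after. The sketch below assumes XLA's usual dump naming (`module_NNNN.<name>.before_optimizations.txt` and a matching `...after_optimizations.txt`); the helper name is hypothetical.

```python
import os


def modules_without_optimized_dump(dump_dir):
    """Return dump prefixes that have a before-optimizations HLO file but
    no after-optimizations file.

    Assumed naming: `module_0001.<name>.before_optimizations.txt` and
    `module_0001.<name>.<platform>_after_optimizations.txt`. A module whose
    compilation aborted mid-pipeline should have only the former.
    """
    before_suffix = ".before_optimizations.txt"
    files = os.listdir(dump_dir)
    suspects = []
    for name in sorted(files):
        if not name.endswith(before_suffix):
            continue
        prefix = name[: -len(before_suffix)]
        # Look for any after-optimizations dump belonging to the same module.
        has_after = any(
            other.startswith(prefix + ".") and other.endswith("after_optimizations.txt")
            for other in files
        )
        if not has_after:
            suspects.append(prefix)
    return suspects
```

Running it on `/tmp/xladump` after the crash should leave only a few candidates; the `<name>` part typically has the form `jit_<function name>`, which points back to the Python function being compiled.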
The code in question is heavy on integer arithmetic, which may be part of the problem. You can find it here
Jax version
Jax/Jaxlib: jax-0.4.20 jaxlib-0.4.20+cuda12.cudnn89
System info
Using GPU on an RTX 4090. My code runs successfully on CPU.
Python 3.11.4, installed through anaconda
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
Same problem here, running on GPU. Are there any solutions? Whether it occurs depends on the hyperparameters.
EDIT: I originally thought I had the same problem, but it looks like I'm failing a different XLA check, so I opened a separate issue: https://github.com/google/jax/issues/20024
Any updates on this?
Running into a similar error on an A6000 GPU:
F external/xla/xla/service/gpu/model/gpu_performance_model.cc:540] Check failed: common_utilization <= producer_output_utilization (1.4 vs. 1.4)
jax: 0.4.24, jaxlib: 0.4.24+cuda12.cudnn89, CUDA installed via pip wheels
F0320 13:19:47.472121 32626 gpu_performance_model.cc:358] Check failed: common_utilization <= producer_output_utilization (1.4 vs. 1.4)
*** Check failure stack trace: ***
@ 0x7fd84c68423e absl::lts_20230802::log_internal::LogMessage::Flush()
@ 0x7fd84c684349 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7fd8440743f4 xla::gpu::GpuPerformanceModel::ProducerInputAccessTime()
@ 0x7fd846759203 xla::gpu::GpuPerformanceModel::EstimateRunTimes()
@ 0x7fd84674e678 xla::gpu::GpuMultiOutputFusion::DoMultiOutputFusion()
@ 0x7fd8467503cc xla::gpu::GpuMultiOutputFusion::Run()
@ 0x7fd84a1e7af5 xla::HloPassPipeline::RunPassesInternal<>()
@ 0x7fd84a1e88e7 xla::HloPassPipeline::Run()
@ 0x7fd84560ec71 xla::HloPassInterface::Run()
@ 0x7fd8456245dd xla::gpu::GpuCompiler::OptimizeHloModule()
@ 0x7fd845628e61 xla::gpu::GpuCompiler::RunHloPasses()
@ 0x7fd8455473a9 xla::Service::BuildExecutable()
@ 0x7fd8453012ad xla::LocalService::CompileExecutables()
@ 0x7fd8452fbf82 xla::LocalClient::Compile()
@ 0x7fd8452bad7c xla::PjRtStreamExecutorClient::Compile()
@ 0x7fd845296c9f xla::StreamExecutorGpuClient::Compile()
@ 0x7fd8452ce00a xla::PjRtStreamExecutorClient::Compile()
@ 0x7fd8451e866f xla::ifrt::PjRtLoadedExecutable::Create()
@ 0x7fd8451ded14 xla::ifrt::PjRtCompiler::Compile()
@ 0x7fd844763a52 xla::PyClient::Compile()
@ 0x7fd8444950b3 pybind11::detail::argument_loader<>::call_impl<>()
@ 0x7fd844495560 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7fd84444a768 pybind11::cpp_function::dispatcher()
@ 0x525d17 cfunction_call