`inference_microbenchmark.py` always exits with 1
With https://github.com/AI-Hypercomputer/maxtext/pull/1457, main now always returns a dict (inference_microbenchmark_sweep.py expects that), which results in a final exit code of 1 from app.run for inference_microbenchmark.py. I guess a bit of massaging is needed to ensure that running inference_microbenchmark.py individually does not exit with 1.
Minimal reproducer:
cd /opt/maxtext
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"
--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer="FUSION"
--xla_disable_hlo_passes=rematerialization"}
export XLA_FLAGS="$BASE_XLA_FLAGS $XLA_FLAGS "
RUN_SETTINGS="/opt/maxtext/MaxText/configs/base.yml run_name=logdir base_output_directory=/opt/maxtext/local_inference per_device_batch_size=40 model_name=llama2-7b ici_autoregressive_parallelism=1 max_prefill_predict_length=1024 max_target_length=2048 attention=dot_product scan_layers=False async_checkpointing=False tokenizer_path=/opt/maxtext/assets/tokenizer.llama2"
CUDA_VISIBLE_DEVICES=0 python3 -m MaxText.inference_microbenchmark $RUN_SETTINGS
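To see why returning a dict produces exit code 1, here is a standalone sketch of the underlying CPython behavior that absl's app.run relies on: app.run passes main's return value to sys.exit, and sys.exit maps any truthy non-integer object to exit status 1 (and None to 0):

```python
import subprocess
import sys

# Exiting with a dict (truthy, non-int) yields status 1 -- the behavior
# app.run triggers when main() returns the benchmark results dict.
dict_exit = subprocess.run(
    [sys.executable, "-c", "import sys; sys.exit({'prefill': 1})"],
    stderr=subprocess.DEVNULL,
).returncode
print(dict_exit)  # 1

# Exiting with None (a main() that returns nothing) yields status 0.
none_exit = subprocess.run(
    [sys.executable, "-c", "import sys; sys.exit(None)"],
).returncode
print(none_exit)  # 0
```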
Thanks for the bug report.
@gpupuck Hmm, I'm not able to replicate. Can you tell me if either of these solutions resolves the issue for you?
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index 667b6530a..01a4686a6 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
+import sys
"""Inference microbenchmark for prefill and autoregressive steps."""
import datetime
@@ -433,3 +434,4 @@ def main(config, **kwargs):
if __name__ == "__main__":
app.run(main)
+    sys.exit(0)
Or:
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index 667b6530a..3c198546b 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -428,7 +428,7 @@ def run_benchmarks(config):
def main(config, **kwargs):
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
- return run_benchmarks(pyconfig.initialize(config, **kwargs))
+ print(run_benchmarks(pyconfig.initialize(config, **kwargs)))
if __name__ == "__main__":
The second solution will fix the issue for sure. But if you remove the return value from the main function, I guess you can no longer call inference_microbenchmark.main in inference_microbenchmark_sweep.py and expect a return value?
FYI, to reproduce the issue:
docker run -it --rm --gpus=all --shm-size=2g ghcr.io/nvidia/jax:maxtext-2025-05-15
bash minimal.sh
echo $?
You will see that it prints 1 at the end.
@gpupuck No, that should be fine: inference_microbenchmark_sweep.py calls MaxText.inference_microbenchmark.run_benchmarks, not MaxText.inference_microbenchmark.main.
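A minimal sketch of that call pattern (the function bodies are placeholders, not the real MaxText code): the sweep script keeps using run_benchmarks for its dict, while main only prints it, so app.run receives None and exits 0:

```python
def run_benchmarks(config):
    # In MaxText this collects prefill/decode timings; a placeholder dict here.
    return {"prefill_time_ms": 12.3}

def main(config):
    # Printing instead of returning keeps app.run()'s exit status at 0,
    # since main() now returns None.
    print(run_benchmarks(config))

# inference_microbenchmark_sweep.py-style usage: call run_benchmarks directly
# and consume the returned dict.
results = run_benchmarks({})
print(results["prefill_time_ms"])
```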
Sending a PR now.