
`inference_microbenchmark.py` always exits with 1

Open gpupuck opened this issue 7 months ago • 2 comments

With https://github.com/AI-Hypercomputer/maxtext/pull/1457, `main` now always returns a dict (which inference_microbenchmark_sweep.py expects), so `app.run` ends with exit code 1 whenever inference_microbenchmark.py is run directly. I guess a bit of massaging is needed to ensure that running inference_microbenchmark.py on its own does not exit with 1.
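For context, here is a minimal sketch of the mechanism. It assumes absl's documented behavior that `app.run(main)` passes `main`'s return value to `sys.exit`. Python's `sys.exit` treats `None` as success (status 0) and an integer as the status itself. Any other object, such as a dict, is printed to stderr and yields status 1. The snippet uses only the standard library, with a child interpreter standing in for MaxText. The metrics dict is invented for illustration.

```python
import subprocess
import sys

# Child programs that mimic `sys.exit(main(argv))`, which is what absl's app.run
# does with main's return value. The metrics dict is invented for illustration.
CHILD_RETURNS_DICT = "import sys; sys.exit({'prefill_time_ms': 12.3})"
CHILD_RETURNS_NONE = "import sys; sys.exit(None)"


def exit_status(child_source: str) -> subprocess.CompletedProcess:
    """Run `child_source` in a fresh interpreter and capture its result."""
    return subprocess.run(
        [sys.executable, "-c", child_source],
        capture_output=True,
        text=True,
    )


if __name__ == "__main__":
    as_dict = exit_status(CHILD_RETURNS_DICT)
    as_none = exit_status(CHILD_RETURNS_NONE)
    # A dict is printed to stderr and the process exits with status 1.
    print("dict ->", as_dict.returncode, as_dict.stderr.strip())
    # None is treated as success (status 0).
    print("None ->", as_none.returncode)
```

This is why the script fails even though the benchmark itself completes: the non-integer return value is what turns into the nonzero exit status.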

Minimal reproducer:

```bash
cd /opt/maxtext

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"
            --xla_gpu_enable_latency_hiding_scheduler=true
            --xla_gpu_enable_command_buffer="FUSION"
            --xla_disable_hlo_passes=rematerialization"}
export XLA_FLAGS="$BASE_XLA_FLAGS $XLA_FLAGS "

RUN_SETTINGS="/opt/maxtext/MaxText/configs/base.yml run_name=logdir base_output_directory=/opt/maxtext/local_inference per_device_batch_size=40 model_name=llama2-7b ici_autoregressive_parallelism=1 max_prefill_predict_length=1024 max_target_length=2048 attention=dot_product scan_layers=False async_checkpointing=False tokenizer_path=/opt/maxtext/assets/tokenizer.llama2"

CUDA_VISIBLE_DEVICES=0 python3 -m MaxText.inference_microbenchmark $RUN_SETTINGS
```

gpupuck, May 14 '25 21:05

Thanks for the bug report.

@gpupuck Hmm, I'm not able to replicate this. Can you tell me whether either of these solutions resolves the issue for you?

```diff
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index 667b6530a..01a4686a6 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+import sys
 
 """Inference microbenchmark for prefill and autoregressive steps."""
 import datetime
@@ -433,3 +434,4 @@ def main(config, **kwargs):
 
 if __name__ == "__main__":
   app.run(main)
+  sys.exit(os.R_OK)
```
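A note on this first option. absl's `app.run` passes `main`'s return value to `sys.exit`, so it does not return normally, and a statement placed after `app.run(main)` never executes. Separately, `os.R_OK` is the read-permission flag for `os.access`, not an exit status; even if that line were reached, it would exit with a nonzero code. The sketch below shows the unreachable-line behavior. It uses a hypothetical `run_like_absl` stand-in instead of absl itself so it runs with the standard library alone, and the metrics dict is made up.

```python
import os
import sys


def run_like_absl(main):
    """Hypothetical stand-in for absl.app.run: it hands main's result to sys.exit."""
    sys.exit(main())


def main():
    # Mimics the post-PR behavior: main returns a metrics dict (values invented).
    return {"prefill_time_ms": 12.3}


def simulate_entry_point():
    """Return (exit argument, whether the line after the runner was reached)."""
    reached_after_run = False
    try:
        run_like_absl(main)
        reached_after_run = True  # where `sys.exit(os.R_OK)` would sit
    except SystemExit as exc:
        return exc.code, reached_after_run
    return None, reached_after_run


if __name__ == "__main__":
    code, reached = simulate_entry_point()
    print("exit argument:", code, "| line after run reached:", reached)
    print("os.R_OK =", os.R_OK)  # a permission-bit constant, not an exit status
```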

Or:

```diff
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index 667b6530a..3c198546b 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -428,7 +428,7 @@ def run_benchmarks(config):
 
 def main(config, **kwargs):
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
-  return run_benchmarks(pyconfig.initialize(config, **kwargs))
+  print(run_benchmarks(pyconfig.initialize(config, **kwargs)))
 
 
 if __name__ == "__main__":
```
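The second option works because `main` now returns `None`, which `sys.exit` treats as success. The dict is still produced by `run_benchmarks` and printed by `main`. Below is a stdlib-only sketch of that split. It uses a hypothetical `run_benchmarks` stub with made-up metrics and runs the entry point in a child interpreter to check for exit status 0.

```python
import subprocess
import sys
import textwrap

# Child program mirroring the patched module: run_benchmarks (a stub here) still
# returns a dict for callers such as a sweep script, while main only prints it.
CHILD = textwrap.dedent(
    """
    import sys

    def run_benchmarks():
        # Hypothetical stub; the real function runs the MaxText benchmarks.
        return {"prefill_time_ms": 12.3}

    def main():
        print(run_benchmarks())  # main now implicitly returns None

    sys.exit(main())  # what absl's app.run does with main's return value
    """
)


def run_child() -> subprocess.CompletedProcess:
    """Execute the child program and capture its exit status and output."""
    return subprocess.run([sys.executable, "-c", CHILD], capture_output=True, text=True)


if __name__ == "__main__":
    result = run_child()
    print("exit status:", result.returncode)
    print("stdout:", result.stdout.strip())
```

With this split, code that needs the numbers calls `run_benchmarks` directly, and the command-line entry point stays free to exit cleanly.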

SamuelMarks, May 16 '25 18:05

The second solution will fix the issue for sure. But if you remove the return value from the main function, I guess inference_microbenchmark_sweep.py can no longer call inference_microbenchmark.main and expect a return value?

FYI, to reproduce the issue:

```bash
docker run -it --rm --gpus=all --shm-size=2g ghcr.io/nvidia/jax:maxtext-2025-05-15
bash minimal.sh
echo $?
```

You will see it prints 1 at the end.

gpupuck, May 16 '25 22:05

@gpupuck No, that should be fine: inference_microbenchmark_sweep.py calls MaxText.inference_microbenchmark.run_benchmarks, not MaxText.inference_microbenchmark.main.

Sending a PR now.

SamuelMarks, May 18 '25 05:05