Pathways: pass XLA flags correctly
Pathways requires XLA flags to be passed to the pathways-proxy, while MXLA flags need to be set on each worker.
This change also allows overriding or adding XLA flags by passing, for example, `--pathways_xla_flags="xla_tpu_x=24,megascale_y=true"`.
Note that Pathways currently only supports passing these flags as command-line arguments, not as environment variables.
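For reference, here is a minimal sketch of how such a flag string could be split into proxy-side XLA options and per-worker megascale (MXLA) options. This is illustrative only; the helper name and value handling are assumptions, not the actual axlearn implementation:

```python
# Illustrative sketch only (not the actual axlearn implementation).
# Splits a --pathways_xla_flags value ("k1=v1,k2=v2") into XLA options for the
# pathways-proxy and megascale (MXLA) options for each worker.
def parse_pathways_xla_flags(flags: str) -> tuple[dict, dict]:
    xla_options, megascale_options = {}, {}
    for item in flags.split(","):
        if not item:
            continue
        key, _, value = item.partition("=")
        # Flags prefixed with "megascale_" are MXLA flags for the workers;
        # everything else is treated as an XLA flag for the pathways-proxy.
        target = megascale_options if key.startswith("megascale_") else xla_options
        target[key] = value
    return xla_options, megascale_options


# The example value from the description above.
xla, mxla = parse_pathways_xla_flags("xla_tpu_x=24,megascale_y=true")
assert xla == {"xla_tpu_x": "24"}
assert mxla == {"megascale_y": "true"}
```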
I verified that the XLA flags are effective by comparing step times for Fuji v2 7B on v6e-16:
- Step time in Pathways without the XLA flags being passed correctly: 2.259768711319994 seconds
- Step time in Pathways with the XLA flags being passed correctly: 2.1042977840899404 seconds
- Step time in McJax with the XLA flags being passed correctly: 2.087459141649306 seconds
I also verified that the flags are set correctly by manually inspecting the pathways-proxy container on the head pod:

```bash
axlearn gcp launch run --cluster=$CLUSTER \
  --runner_name gke_tpu_pathways --pathways_xla_flags=xla_jf_crs_combiner_threshold_count=11 \
  --name=$USER \
  --instance_type=tpu-v6e-16 \
  --num_replicas=1 \
  --bundler_spec=allow_dirty=True \
  --bundler_type=artifactregistry --bundler_spec=image=tpu \
  --bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu \
  -- sleep infinity;
```
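For completeness, a hedged sketch of that manual check using the official Kubernetes Python client. The pod name and namespace are placeholders; only the `pathways-proxy` container name and the `xla_jf_crs_combiner_threshold_count=11` flag come from the launch command above, and the exact formatting of the proxy arguments may differ:

```python
# Hedged sketch of the manual verification (pod name and namespace are placeholders).
from kubernetes import client, config


def proxy_args(pod_name: str, namespace: str = "default") -> list[str]:
    """Returns the command-line args of the pathways-proxy container in a pod."""
    config.load_kube_config()
    pod = client.CoreV1Api().read_namespaced_pod(name=pod_name, namespace=namespace)
    for container in pod.spec.containers:
        if container.name == "pathways-proxy":
            return container.args or []
    raise ValueError(f"No pathways-proxy container found in pod {pod_name!r}.")


# Example usage; the pod name is hypothetical and the arg format may differ.
args = proxy_args("my-pathways-head-pod")
assert any("xla_jf_crs_combiner_threshold_count=11" in arg for arg in args)
```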
I think the CI error is unrelated to my PR: https://github.com/apple/axlearn/actions/runs/14895076691/job/41907295505?pr=1163#step:8:9480

```
#22 475.0 ==================================== ERRORS ====================================
#22 475.0 _______ ERROR collecting axlearn/open_api/metrics/code_contests_test.py ________
#22 475.0 ImportError while importing test module '/root/axlearn/open_api/metrics/code_contests_test.py'.
#22 475.0 Hint: make sure your test modules/packages have valid Python names.
#22 475.0 Traceback:
#22 475.0 /usr/local/lib/python3.10/importlib/__init__.py:126: in import_module
#22 475.0 return _bootstrap._gcd_import(name[level:], package, level)
#22 475.0 axlearn/open_api/metrics/code_contests_test.py:18: in <module>
#22 475.0 from axlearn.open_api.metrics import code_contests
#22 475.0 axlearn/open_api/metrics/code_contests.py:14: in <module>
#22 475.0 from axlearn.open_api.common import (
#22 475.0 axlearn/open_api/common.py:27: in <module>
#22 475.0 from openai.types.chat.chat_completion_message import ChatCompletionMessage
#22 475.0 E ModuleNotFoundError: No module named 'openai'
#22 475.0 -------------- generated xml file: /root/test-results/testing.xml --------------
```
@Ethanlm Thanks for reviewing and approving! Do we need anyone else to approve this?
Yes, code owner approvals are required.
@markblee could you review again please? I've added the docstring and responded to your comment about why one of the XLA flags was removed intentionally.
I will test it in our internal env as a final validation
Chatted with @Ethanlm; the issue with the current PR is that it doesn't allow users to modify XLA flags in axlearn code. Please hold off on merging for now. I will discuss with Shaurya how to address this.
@Ethanlm can you please review the PR again now that I've added the ability to add and override any of the existing XLA flags?
@markblee could I get a final review from you so we can get it merged?
@markblee could you review once more please? I've addressed all your comments.
@markblee could you please review once more? I moved `get_xla_options` and `get_megascale_options` to `pathways_utils.py`. Nothing else changed. I also re-ran my manual test.
Thanks, I still see them in `compiler_options`; were they accidentally copied?
Good catch! Apologies for somehow messing up the deletion in `compiler_options.py`.
I reviewed the final changes and also moved the `parse_xla_flag` function to `pathways_utils.py`. Please review once more.
Seems the error is unrelated to my PR: https://github.com/apple/axlearn/actions/runs/15166874891/job/42647986834?pr=1163#step:8:5949

```
#22 155.3 /opt/venv/lib/python3.10/site-packages/transformers/integrations/tensor_parallel.py:465: in __init__
#22 155.3 self.input_layouts = (input_layouts or Replicate(),)
#22 155.3 E NameError: name 'Replicate' is not defined
```
Merged with latest main in the hopes that it would fix the CI issue. Could you re-trigger?
Specifically, #1205 may fix the CI failure in my run.
@markblee checks are now passing after merging latest main. Are we good to get this merged?
The failure seems unrelated to my PR:
> The hosted runner: GitHub Actions 314 lost communication with the server. Anything in your workflow that terminates the runner process, starves it for CPU/Memory, or blocks its network access can cause this error.
@markblee not sure if there is anything else I should do.
Let me try again