A workaround for the JIT compilation issue in MNMG training
Description
:memo: Remove `jit_compile=True` and annotate the whole `train_step` function with `tf.function`.
- Please also include relevant motivation and context.
When training a model with `MultiWorkerMirroredStrategy`, `jit_compile=True` causes compilation issues on some graph nodes. As a workaround, we can remove `jit_compile=True` and annotate the whole `train_step` function with `tf.function`; XLA can then be enabled by setting the relevant environment variables.
- List any dependencies that are required for this change.
No dependencies
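A minimal sketch of the change described above, assuming a custom training loop. The PR text does not include its diff, so `compute_grads`, `train_step`, the toy variable `w`, and `LEARNING_RATE` are hypothetical stand-ins for the real ViT trainer. The "before" form is shown only as a comment. The PR also does not name the environment variables; launching with `TF_XLA_FLAGS=--tf_xla_auto_jit=2` (XLA auto-clustering) is an assumption here, not something the author stated.

```python
import tensorflow as tf

# The cluster layout comes from the TF_CONFIG environment variable on each
# node; without TF_CONFIG the strategy runs as a single local worker.
# Create the strategy before running any other TensorFlow ops.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    # Hypothetical linear model standing in for the real ViT task.
    w = tf.Variable(tf.zeros([4, 1]), name="w")

LEARNING_RATE = 0.1  # hypothetical

def compute_grads(x, y):
    """Per-replica forward and backward pass (runs inside strategy.run)."""
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    (grad,) = tape.gradient(loss, [w])
    return loss, grad

# Before: the per-replica function was compiled on its own, e.g.
#   @tf.function(jit_compile=True)
#   def compute_grads(x, y): ...
# which the PR reports fails to compile on some graph nodes under MWMS.
#
# After: no jit_compile; the whole step (strategy.run, the cross-replica
# reduction and the update) is a single tf.function. XLA is then enabled
# from the environment instead (assumption: TF_XLA_FLAGS=--tf_xla_auto_jit=2).
@tf.function
def train_step(x, y):
    loss, grad = strategy.run(compute_grads, args=(x, y))
    mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
    mean_grad = strategy.reduce(tf.distribute.ReduceOp.MEAN, grad, axis=None)
    w.assign_sub(LEARNING_RATE * mean_grad)
    return mean_loss
```

Keeping the collective reduction inside the same `tf.function` as the replica step means the whole step is traced as one graph, and XLA (if enabled from the environment) chooses what to cluster rather than being forced onto one function.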
Type of change
For a new feature or function, please create an issue first to discuss it with us before submitting a pull request.
Note: Please delete options that are not relevant.
- [x] Bug fix (non-breaking change which fixes an issue)
Tests
:memo: Please describe the tests that you ran to verify your changes.
- Provide instructions so we can reproduce.
Here is the command I ran on 2 nodes with TF_CONFIG correctly set up:
/usr/local/bin/python3.9 official/projects/vit/train.py --experiment vit_imagenet_pretrain --mode train --model_dir /opt/ml/model --params_override runtime.enable_xla=True,runtime.num_gpus=8,runtime.distribution_strategy=multi_worker_mirrored,runtime.all_reduce_alg=nccl,runtime.mixed_precision_dtype=float16,task.train_data.global_batch_size=1024,task.train_data.input_path=/opt/ml/input/data/training/train*,task.train_data.cache=True,trainer.train_steps=6347,trainer.steps_per_loop=634,trainer.summary_interval=634,trainer.checkpoint_interval=6347,task.model.backbone.type=vit
- Please also list any relevant details for your test configuration.
Test Configuration:
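The 2-node run depends on each node knowing the cluster layout. By default, `MultiWorkerMirroredStrategy` reads it from the `TF_CONFIG` environment variable: a JSON string whose `cluster` section lists every worker and whose `task` section gives this node's role and index. A sketch of building it for a 2-node cluster, assuming hypothetical host names, port, and helper name `make_tf_config` (the PR does not show its actual setup):

```python
import json

# Hypothetical host names and port; replace with the two real nodes.
WORKERS = ["node0.example.com:12345", "node1.example.com:12345"]

def make_tf_config(task_index: int) -> str:
    """Return the TF_CONFIG JSON string for one worker of a 2-node cluster."""
    if not 0 <= task_index < len(WORKERS):
        raise ValueError(f"task_index must be in [0, {len(WORKERS)})")
    return json.dumps({
        "cluster": {"worker": WORKERS},
        "task": {"type": "worker", "index": task_index},
    })
```

On each node the resulting string would be exported as `TF_CONFIG` before launching `train.py`, with index 0 on one node and index 1 on the other; the `cluster` section must be identical on both.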
Checklist
- [X] I have signed the Contributor License Agreement.
- [X] I have read guidelines for pull request.
- [X] My code follows the coding guidelines.
- [X] I have performed a self code review of my own code.
- [X] I have commented my code, particularly in hard-to-understand areas.
- [X] I have made corresponding changes to the documentation.
- [X] My changes generate no new warnings.
- [X] I have added tests that prove my fix is effective or that my feature works.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
This sounds like a bug in tf.distribute and XLA. Could you file an issue in the TF repo?
@yuefengz @cheshire
Hi @YuchengT,
It looks like this is a bug in tf.distribute and XLA, so there is no action item on the official models side.
Could you please close this PR?
Thanks.
Closing this PR, as it is not related to the official models.