xla
[DDP] Add a test case to test a larger model
Summary: This commit adds a DDP test case that uses a larger model, one large enough to trigger multiple all_reduces during gradient synchronization instead of a single one.
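For context, here is a rough sketch of why model size affects the number of all_reduces. It assumes DDP groups gradients into size-capped buckets and issues one all_reduce per bucket, with a 25 MiB cap (the default of torch's `bucket_cap_mb`). The function `count_buckets` and both layer lists are hypothetical and are not taken from the test in this commit; the real bucketing logic has more rules than this greedy approximation.

```python
def count_buckets(param_numels, bytes_per_elem=4, bucket_cap_bytes=25 * 1024 * 1024):
    """Greedy estimate of gradient buckets; each bucket maps to one all_reduce.

    Simplified model (assumption): parameters are visited in reverse
    registration order and a new bucket starts whenever the cap would be exceeded.
    """
    buckets = 0
    current = 0  # bytes accumulated in the bucket being filled
    for numel in reversed(param_numels):
        size = numel * bytes_per_elem
        if current > 0 and current + size > bucket_cap_bytes:
            buckets += 1  # close the bucket that is already in progress
            current = 0
        current += size
    if current > 0:
        buckets += 1  # close the last, partially filled bucket
    return buckets


# Hypothetical parameter shapes, float32 (4 bytes per element).
small_net = [10 * 10, 10]              # one tiny linear layer: weight + bias
large_net = [4096 * 4096, 4096] * 4    # four wide linear layers

print(count_buckets(small_net))  # fits in a single bucket
print(count_buckets(large_net))  # spills into several buckets
```

Under these assumptions the small model needs one all_reduce while the large one needs several, which is the code path the new test is meant to exercise.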
Test Plan:
XRT: MASTER_ADDR=localhost MASTER_PORT=6000 python test/test_ddp.py TestXrtDistributedDataParallel.test_ddp_correctness_large_net
PJRT: PJRT_DEVICE=TPU python test/pjrt/test_ddp.py TestPjRtDistributedDataParallel.test_ddp_correctness_large_net
Thanks Will for approving it. I will fix all the CI issues before merging.