Ronghang Hu

46 comments by Ronghang Hu

I just made some changes. It should be fixed now.

Really looking forward to this PR!

> Simple test case to sanity-check that collectives work as expected:

It would be great to also resolve the `reduce_scatter`, `all_gather`, and `all_to_all` collective...

Thanks @will-cromar, I'll submit a new issue for all-gather, all-reduce, and all-to-all under PJRT.

> I found that the PjRt device IDs are ordered ['TPU:0', 'TPU:2', 'TPU:3', 'TPU:1']. I bet...

> Also, `xm.rendezvous` doesn't work yet, but we had another early tester tell us that they were able to work around it by creating a `gloo` process group and using `dist.barrier`...
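For reference, a minimal sketch of that workaround as I understand it: a CPU-side `gloo` process group used purely for synchronization, standing in for `xm.rendezvous`. The `MASTER_ADDR`/`MASTER_PORT` values and helper names here are placeholders, not from the PR.

```python
# Sketch: a gloo process group on CPU provides a barrier while XLA/PJRT
# handles the model computation. Address/port are illustrative assumptions.
import os

import torch.distributed as dist
import torch_xla.core.xla_model as xm


def init_gloo_group():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(
        "gloo", rank=xm.get_ordinal(), world_size=xm.xrt_world_size())


def rendezvous_workaround():
    # Every process blocks here until all members of the gloo group arrive,
    # which is the behavior we would otherwise get from xm.rendezvous(tag).
    dist.barrier()
```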

@will-cromar I created an issue in https://github.com/pytorch/xla/issues/3824 with a simple test example for all-gather, reduce-scatter, and all-to-all (but I cannot assign the issue to you since I don't have edit...
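Roughly the kind of per-device sanity check I have in mind (a sketch, not the exact script from the issue; shapes, the scale, and the launch setup are illustrative):

```python
# Each process drives one TPU core via xmp.spawn and exercises all_gather,
# reduce_scatter, and all_to_all on a small vector filled with its ordinal.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    rank = xm.get_ordinal()
    world_size = xm.xrt_world_size()

    # Each device contributes a vector filled with its own ordinal.
    t = torch.full((world_size,), float(rank), device=device)

    gathered = xm.all_gather(t, dim=0)
    scattered = xm.reduce_scatter(
        xm.REDUCE_SUM, t, scale=1.0, scatter_dim=0, shard_count=world_size)
    exchanged = xm.all_to_all(
        t, split_dimension=0, concat_dimension=0, split_count=world_size)

    xm.mark_step()
    print(rank, gathered.cpu(), scattered.cpu(), exchanged.cpu())


if __name__ == "__main__":
    xmp.spawn(_mp_fn)
```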

> @ronghanghu I will give you write access 😄

Great, thank you!

This would be a great feature. Looking forward to it!

Tests showed that it worked correctly on the examples above. However, a context manager that patches `torch.nn.functional.linear` is not thread-safe under PJRT (i.e., one thread might exit the context scope...
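To make the hazard concrete, here is a stripped-down sketch of the patch pattern in question (not the actual code from the PR):

```python
# The patch mutates module-level state, so it is visible to every thread.
# Under PJRT's one-process-many-threads setup, thread B may still be inside
# its `with` block when thread A exits and restores the original function.
import contextlib

import torch.nn.functional as F


@contextlib.contextmanager
def patched_linear(new_impl):
    original = F.linear
    F.linear = new_impl      # global side effect, seen by all threads
    try:
        yield
    finally:
        F.linear = original  # may race with other threads still in scope
```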

@hjm-aws I added a new commit that introduces the option `shard_param_on_dim_0` (default `False`). When `shard_param_on_dim_0` is set to `True`, the parameter tensors are sharded only along their first dimension (dim 0)...
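In other words, each rank keeps a dim-0 slice of every parameter instead of a slice of the flattened 1D tensor. A simplified sketch of the intended behavior (the helper name and padding logic are illustrative, not the actual implementation):

```python
# Illustrative dim-0 sharding: rank r keeps rows [r*shard, (r+1)*shard) of
# each parameter, padding the last shard so all ranks hold the same shape.
import torch


def shard_on_dim_0(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    dim0 = param.size(0)
    shard_size = (dim0 + world_size - 1) // world_size  # ceil division
    start, end = rank * shard_size, min((rank + 1) * shard_size, dim0)
    shard = param[start:end]
    if shard.size(0) < shard_size:
        pad = shard.new_zeros((shard_size - shard.size(0),) + param.shape[1:])
        shard = torch.cat([shard, pad], dim=0)
    return shard.clone()
```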

@JackCaoG This is the issue for the `einsum` op we just mentioned. From our earlier profiling, lowering it could potentially bring a 5%+ speedup to several models. However, as mentioned in this...
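As a quick way to see whether `einsum` currently falls back instead of being lowered, one can check the debug metrics for an `aten::einsum` counter (a sketch; the shapes are arbitrary):

```python
# If einsum is not lowered, torch_xla routes it through the CPU fallback and
# records an aten:: counter for it in the debug metrics report.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
a = torch.randn(8, 16, device=device)
b = torch.randn(16, 32, device=device)
c = torch.einsum("ij,jk->ik", a, b)
xm.mark_step()

fallback_counters = [n for n in met.counter_names() if n.startswith("aten::")]
print("CPU-fallback ops:", fallback_counters)
```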