jax
Is there a way to control pjit so that it dynamically partitions input IDs across different ranks?
I need to build a model like this:
There is a very large, distributed, dynamically shaped embedding, which can be viewed as a hash table.
On every data-parallel (DP) rank, when a worker receives its input IDs, it has to send them to the worker that owns the corresponding table entries so the lookup can happen there.
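XLA, and therefore pjit, compiles for static shapes, so variable-length per-destination ID lists cannot be passed around directly. A common workaround is to pad each destination's bucket to a fixed capacity. The sketch below shows that idea under stated assumptions: the modulo ownership rule, `CAPACITY`, the `-1` padding value and the helper `bucket_by_owner` are all hypothetical, not an existing JAX API.

```python
import jax
import jax.numpy as jnp

NUM_RANKS = 2   # assumption: two DP ranks
CAPACITY = 4    # assumption: max IDs one rank sends to one destination per step
PAD = -1        # assumption: -1 is never a valid ID

def owner_of(ids):
    # Hypothetical sharding rule: key k lives on rank k % NUM_RANKS
    # (rank 0 holds 0, 2, ...; rank 1 holds 1, 3, ...).
    return ids % NUM_RANKS

@jax.jit
def bucket_by_owner(ids):
    """Group IDs into a static [NUM_RANKS, CAPACITY] buffer; row r = IDs owned by rank r."""
    owners = owner_of(ids)
    rows, counts = [], []
    for r in range(NUM_RANKS):  # unrolled at trace time because NUM_RANKS is static
        mask = owners == r
        # size= keeps the output shape static; missing slots get index 0, masked below.
        # IDs beyond CAPACITY are silently dropped; check `counts` to detect overflow.
        (idx,) = jnp.nonzero(mask, size=CAPACITY, fill_value=0)
        n = mask.sum()
        valid = jnp.arange(CAPACITY) < n
        rows.append(jnp.where(valid, ids[idx], PAD))
        counts.append(n)
    return jnp.stack(rows), jnp.stack(counts)
```

For example, `bucket_by_owner(jnp.array([0, 1, 2, 2]))` returns buckets `[[0, 2, 2, -1], [1, -1, -1, -1]]` and counts `[3, 1]`. The capacity is a trade-off: too small loses IDs, too large wastes bandwidth on padding.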
For example, rank 0 gets [0,1,2,2] and rank 1 gets [0,3,2,1]. However, the hash table on rank 0 stores only keys 0, 2, ..., and the one on rank 1 stores keys 1, 3, .... An alltoallv operator is therefore needed, in which rank 0 sends [1] and receives [0,2], while rank 1 sends [0,2] and receives [1]. The looked-up values are then sent back to the ranks the IDs came from.
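The routing in this example can be checked with a plain-Python simulation (no JAX, no real communication). The sharding rule `k % NUM_RANKS`, the helpers `alltoallv` and `assemble`, and the stand-in embedding values `10.0 * k` are hypothetical and only serve to make the two exchanges concrete.

```python
NUM_RANKS = 2

# Hypothetical sharded table: rank r stores keys k with k % NUM_RANKS == r.
# The "embedding" for key k is just 10.0 * k, for illustration only.
tables = [{k: 10.0 * k for k in range(8) if k % NUM_RANKS == r} for r in range(NUM_RANKS)]

inputs = [[0, 1, 2, 2], [0, 3, 2, 1]]  # per-rank input IDs from the example

def alltoallv(send):
    """send[src][dst] is a variable-length list; returns recv[dst][src]."""
    return [[send[src][dst] for src in range(NUM_RANKS)] for dst in range(NUM_RANKS)]

# 1) Each rank buckets its remote IDs by owner; local IDs stay in place.
send_ids = [
    [[k for k in ids if k % NUM_RANKS == dst] if dst != src else [] for dst in range(NUM_RANKS)]
    for src, ids in enumerate(inputs)
]
recv_ids = alltoallv(send_ids)

# 2) Each owner looks up the IDs it received, 3) and the values go back to the requester.
send_vals = [
    [[tables[owner][k] for k in recv_ids[owner][req]] for req in range(NUM_RANKS)]
    for owner in range(NUM_RANKS)
]
recv_vals = alltoallv(send_vals)

def assemble(rank):
    """Combine local lookups with values returned by other ranks, in input order."""
    remote = {}
    for owner in range(NUM_RANKS):
        remote.update(zip(send_ids[rank][owner], recv_vals[rank][owner]))
    return [tables[rank][k] if k % NUM_RANKS == rank else remote[k] for k in inputs[rank]]
```

Here `send_ids[0][1]` is `[1]` and `recv_ids[0][1]` is `[0, 2]`, matching the description, and `assemble(0)` yields `[0.0, 10.0, 20.0, 20.0]`.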
Note: In practice, this procedure uses the send_recv operator to implement an asynchronous ring all-to-all, so that communication overlaps with the table-lookup time.
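One way to express a stepwise all-to-all in JAX is a sequence of `jax.lax.ppermute` calls over fixed-size padded buckets. This is a minimal sketch under assumptions, not an existing JAX all-to-all-v: it is a shift-based exchange (at step `s` every rank sends directly to the rank `s` positions ahead) rather than a strict neighbour-to-neighbour ring. `shift_all_to_all`, `N` and the `-1` padding are hypothetical. `jax.vmap` with `axis_name` is used only so that the sketch runs on a single device.

```python
import jax
import jax.numpy as jnp
from jax import lax

N = 2      # assumption: number of ranks on the named axis
PAD = -1   # assumption: padding value, never a real ID

def shift_all_to_all(buckets, axis_name):
    """Fixed-shape all-to-all built from N - 1 ppermute steps.

    buckets: [N, C] per rank; row d holds the IDs this rank sends to rank d.
    Returns [N, C]; row s holds the IDs this rank received from rank s.
    """
    me = lax.axis_index(axis_name)
    out = jnp.full_like(buckets, PAD)
    out = out.at[me].set(buckets[me])  # the local chunk never leaves the rank
    for s in range(1, N):  # unrolled at trace time
        perm = [(i, (i + s) % N) for i in range(N)]  # every rank i sends to rank i + s
        got = lax.ppermute(buckets[(me + s) % N], axis_name, perm)
        out = out.at[(me - s) % N].set(got)  # `got` came from rank me - s
        # Hypothetical hook: a lookup on `got` issued here could overlap with the
        # next step's transfer, but the actual overlap is up to XLA's scheduler.
    return out

# Simulate two ranks on one device; on real hardware the same body could run
# under jax.pmap(..., axis_name="ranks") instead.
run = jax.vmap(lambda b: shift_all_to_all(b, "ranks"), axis_name="ranks")
buckets = jnp.array([
    [[0, 2, 2, PAD], [1, PAD, PAD, PAD]],   # rank 0, input [0, 1, 2, 2]
    [[0, 2, PAD, PAD], [3, 1, PAD, PAD]],   # rank 1, input [0, 3, 2, 1]
])
received = run(buckets)
```

With these inputs, rank 0's row from rank 1 is `[0, 2, -1, -1]` and rank 1's row from rank 0 is `[1, -1, -1, -1]`, matching the example. Returning the looked-up values would be the same exchange with the permutations reversed.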