Is there a way to control the pjit to dynamically partition the input IDs into different ranks?

Open MoFHeka opened this issue 8 months ago • 8 comments

I need to build a model like this: there is a very large, distributed, dynamically shaped embedding, which can be viewed as a hash table. In every DP rank, when a worker receives its input IDs, it needs to send them to the worker that owns them for the table lookup. For example, rank 0 gets [0,1,2,2] and rank 1 gets [0,3,2,1], but the hash table on rank 0 stores only the keys 0,2,..., and the one on rank 1 stores only 1,3,.... This is handled with an alltoallv operator: rank 0 sends [1] and receives [0,2], while rank 1 sends [0,2] and receives [1]. The looked-up values are then sent back to the rank that requested them.
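
To make the routing concrete, here is a minimal single-process sketch in plain Python that simulates the exchange described above. It is not JAX or pjit code. The round-robin ownership rule (`key % world_size`), the function names, and the string "embeddings" are assumptions chosen for illustration so that the result matches the example (rank 0 owns 0,2,..., rank 1 owns 1,3,...).

```python
def owner_of(key, world_size):
    # Assumption: keys are sharded round-robin, so rank r owns
    # every key with key % world_size == r.
    return key % world_size


def route_ids(ids_per_rank):
    """Simulate the alltoallv step.

    Returns recv[dst][src]: the sorted, de-duplicated IDs that rank
    `src` asks rank `dst` to look up.
    """
    world_size = len(ids_per_rank)
    recv = [[[] for _ in range(world_size)] for _ in range(world_size)]
    for src, ids in enumerate(ids_per_rank):
        for key in sorted(set(ids)):
            recv[owner_of(key, world_size)][src].append(key)
    return recv


def distributed_lookup(ids_per_rank, tables):
    """Route IDs to their owners, look them up, and return values to the requester."""
    world_size = len(ids_per_rank)
    recv = route_ids(ids_per_rank)
    # answers[src] collects the values sent back to rank `src`.
    answers = [dict() for _ in range(world_size)]
    for dst in range(world_size):
        for src in range(world_size):
            for key in recv[dst][src]:
                answers[src][key] = tables[dst][key]
    # Each rank restores its original ID order, duplicates included.
    return [[answers[r][k] for k in ids] for r, ids in enumerate(ids_per_rank)]


ids_per_rank = [[0, 1, 2, 2], [0, 3, 2, 1]]
tables = [{0: "e0", 2: "e2"}, {1: "e1", 3: "e3"}]  # hypothetical table shards

recv = route_ids(ids_per_rank)
result = distributed_lookup(ids_per_rank, tables)
```

Here `recv[1][0]` is what rank 0 sends to rank 1, and `recv[0][1]` is what rank 0 receives from rank 1. The per-destination lists have data-dependent lengths, which is the dynamic-shape part that makes this hard to express under a statically shaped partitioner.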

Note: In practice, this procedure uses the send_recv operator to implement an asynchronous ring alltoall, so that communication overlaps with the table lookup.
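
As a rough illustration of the ring schedule mentioned in the note, the sketch below simulates an alltoall built from `world_size - 1` shifted send/receive steps. It runs sequentially in one process and only shows the communication pattern; the function name and the schedule (rank `r` sends to `(r + step) % n` at each step) are assumptions, not the actual implementation.

```python
def ring_alltoall(send):
    """Simulate an alltoall as a sequence of shifted send/recv steps.

    send[src][dst] is the payload rank `src` has for rank `dst`.
    Returns recv[dst][src].
    """
    n = len(send)
    recv = [[None] * n for _ in range(n)]
    for r in range(n):
        # The local bucket needs no communication.
        recv[r][r] = send[r][r]
    for step in range(1, n):
        for r in range(n):
            dst = (r + step) % n  # at this step, rank r sends to rank r + step
            recv[dst][r] = send[r][dst]
        # In a real asynchronous version, each rank could look up the
        # bucket it received at this step while the next transfer is in flight.
    return recv


# Three ranks, each with a labelled payload for every destination.
send = [[f"{s}->{d}" for d in range(3)] for s in range(3)]
recv = ring_alltoall(send)
```

Splitting the exchange into steps is what creates the opportunity for overlap: the lookup for one peer's IDs can proceed while the next peer's IDs are still being transferred.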

Please:

  • [x] Check for duplicate requests.
  • [x] Describe your goal, and if possible provide a code snippet with a motivating example.

MoFHeka • Jun 16 '24 19:06