Shardy support

Open nsmithtt opened this issue 6 months ago • 22 comments

🚀 Feature

Support Shardy instead of GSPMD.

Motivation

GSPMD will soon become a deprecated IR in OpenXLA in favor of Shardy.

Pitch

There are probably two paths forward:

  • Support the GSPMD -> Shardy conversion pass from OpenXLA; this would probably get Shardy-based IR working sooner.
  • Support Shardy as a first-class backend IR in torch-xla.

nsmithtt, Jun 12 '25 13:06

Would love to get some input from torch-xla maintainers, who'd have a much better feel for the scope of work here. Also, whether this has already been considered.

nsmithtt, Jun 12 '25 13:06

Hello, @GleasonK and I both took a look at this a while ago. I will add here what I know, and I will let Kevin correct me or add anything:

The biggest hurdle to adopting Shardy is that it consumes MLIR. Currently we output both HLO and StableHLO. Migrating fully to StableHLO should be relatively trivial. Migrating the entire system to MLIR is desirable, but initial estimates for it are long, as it requires many changes across PyTorchXLA.

Some of the potential work related to https://github.com/pytorch/xla/issues/9019 will likely lead PyTorchXLA to use IFRT eventually, but the timing is not set in stone at the moment.

Depending on when GSPMD is deprecated, the "GSPMD -> Shardy conversion pass from OpenXLA" may be what needs to be implemented first.

pgmoka, Jun 12 '25 18:06

Awesome, thank you for the quick response! This is a good insight, we're going to investigate what it takes to invoke the GSPMD -> Shardy conversion pass.

nsmithtt, Jun 12 '25 18:06

That's probably the best path. I'm curious what your PTXLA ingestion pipeline looks like. AFAIK, torch_xla today passes HLO to PJRT plugins; are you using the XLA_STABLEHLO_COMPILE flag? If so, adding the SDY support you want shouldn't be too difficult!

cc @tomnatan30 @bartchr808

GleasonK, Jun 12 '25 22:06

Yes, we are using that flag:

import os
os.environ["PJRT_DEVICE"] = "TT"
os.environ["XLA_STABLEHLO_COMPILE"] = "1"

I think we also tried, without luck:

from jax import config
config.update("jax_use_shardy_partitioner", True)

nsmithtt, Jun 13 '25 03:06

Chatted with the Shardy folks; there are going to be two work items for getting these passes to work:

  1. The conversion should happen in the PT/XLA layer, right around the XLA_STABLEHLO_COMPILE callsite, before passing the IR to the PJRT plugin. This will save you from needing to take a dependency on XLA, since the passes require an XLA dependency for the sharding proto.
  2. We'll need PTXLA to build OpShardingV2 annotations (V1: devices=[4,2]0,1,2,3,4,5,6,7 vs. V2: devices=[4,2]<=[8]). In theory this isn't a hard change, since PT/XLA already marks sharding based on a mesh, and that's the only requirement of V2.

So it won't be as trivial as I initially thought, but IMO we should do this in the near future. Both changes are improvements independently of each other: (2) is a compile-time improvement even without SDY support, since it greatly reduces program size for large distributed programs.
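To make the size argument in (2) concrete, here is a minimal sketch assuming a tile assignment that is a plain iota over device ids 0..N-1 with no transpose. It renders the same tile assignment in the V1 form, which lists every device id, and in the V2 iota form. The helper names are hypothetical and are not part of torch_xla or XLA.

```python
import numpy as np

def v1_devices(tile_dims):
    """Hypothetical helper: V1 text, which spells out every device id."""
    ids = np.arange(int(np.prod(tile_dims)))
    return "devices=[{}]{}".format(",".join(map(str, tile_dims)), ",".join(map(str, ids)))

def v2_iota_devices(tile_dims):
    """Hypothetical helper: V2 iota text; "<=[N]" means device ids 0..N-1 in order."""
    n = int(np.prod(tile_dims))
    return "devices=[{}]<=[{}]".format(",".join(map(str, tile_dims)), n)

print(v1_devices([4, 2]))       # devices=[4,2]0,1,2,3,4,5,6,7
print(v2_iota_devices([4, 2]))  # devices=[4,2]<=[8]

# For a 4096-device tile the V1 string grows with the device count; V2 does not.
print(len(v1_devices([64, 64])), len(v2_iota_devices([64, 64])))
```

The V1 annotation is repeated on every sharded op, so its length scales with both device count and program size, while the V2 form stays a few characters long.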

GleasonK, Jun 13 '25 16:06

@GleasonK Thanks so much for looking into this! Regarding (2), could you give me a code pointer to where PTXLA is currently building the V1 annotations?

Also, regarding (1), we found this pipeline within the OpenXLA repo that converts modules with mhlo.shardings into the Shardy dialect: https://github.com/openxla/xla/blob/d6beb849cf6ac93f419d6be27a9d1eac1b72420a/xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.cc#L678

We tested it with the SHLO output of some TorchXLA programs and found that (2) needs to happen before the pass can convert the GSPMD module to Shardy (otherwise it just errors out). Is this also what you had in mind for the GSPMD -> Shardy conversion path?

hshahTT, Jun 13 '25 17:06

The code pointer for (2) is probably around here (I'd assume contributions are welcome; it's functionally the same, and JAX has already migrated to OpShardingV2): https://github.com/pytorch/xla/blob/3a1ed628d0e5fcc2a6d2c1b37fb2493020d28670/torch_xla/distributed/spmd/xla_sharding.py#L652

It generates an OpSharding of the old format. I think @sdasgup3 may have ideas on how this would look at a V2.

And re (1), yeah, I think that's exactly the pass. But tt-xla shouldn't need a dependency on XLA just for that pass: PTXLA already has one, so we can call that pass there instead once (2) is addressed.

GleasonK, Jun 13 '25 18:06

config.update("jax_use_shardy_partitioner", True)

fwiw, setting the option above only changes the IR that JAX emits, not what torch_xla emits. However, given that we recently launched the call_jax feature, we can probably get Shardy support by composing call_jax with jax.lax.with_sharding_constraint.

This is akin to how shard_as is implemented in torch_xla: https://github.com/pytorch/xla/blob/bea86eec296f407cf8b689a8ef605914fe9b9797/torch_xla/distributed/spmd/xla_sharding.py#L726

cc. @tengyifei

qihqi, Jun 13 '25 23:06

Furthermore, if you wrap a PyTorch function with https://docs.pytorch.org/xla/master/perf/assume_pure.html, and enable the jax Shardy config, then any xs.mark_sharding call within that PyTorch function will be staged out with Shardy annotations.

That's because @assume_pure lowers your PyTorch function using torchax to JAX then to HLO.

tengyifei, Jun 14 '25 02:06

Hey everyone, sorry for the long message.

I'm working on adding OpShardingV2 annotations to torch/xla (mentioned by @GleasonK as work item 2 here).

I found the following documentation that gives some examples and explanations of how to create the V2 sharding annotations:

  • https://github.com/jax-ml/jax/blob/84af8a8e74c05ce4196079e145d50f0c9504ff16/jax/_src/named_sharding.py#L414
  • https://github.com/openxla/xla/blob/bdfcd696c1c4311b40ccd081f5286011199d931b/xla/hlo/ir/tile_assignment.h#L50

To make sure that I understand the format correctly, could someone with more knowledge please confirm whether the dims, reshape_dims, and transpose_perm lists for the following examples are correct (so I know if I'm on the right track)?

In all cases we define a 2D device mesh with x and y axes in a [2, 4] shape (the x axis has 2 devices and y has 4):

from torch_xla.distributed.spmd import Mesh
mesh = Mesh(list(range(8)), (2, 4), ("x", "y"))

Also, in all cases a 2D tensor is being sharded. The reshape_dims will also be [2, 4] everywhere.

  1. PartitionSpec = ('x', 'y')
     a. dims = [2, 4]
     b. transpose_perm = [0, 1]

  2. PartitionSpec = ('y', 'x')
     a. dims = [4, 2]
     b. transpose_perm = [1, 0]

  3. PartitionSpec = ('x', None)
     a. dims = [2, 1, 4]
     b. transpose_perm = [0, 1]

  4. PartitionSpec = ('y', None)
     a. dims = [4, 1, 2]
     b. transpose_perm = [1, 0]

  5. PartitionSpec = (None, 'x')
     a. dims = [1, 2, 4]
     b. transpose_perm = [0, 1]

  6. PartitionSpec = (None, 'y')
     a. dims = [1, 4, 2]
     b. transpose_perm = [1, 0]

  7. PartitionSpec = (('x', 'y'), None)
     a. dims = [8, 1]
     b. transpose_perm = [0, 1]

  8. PartitionSpec = (('y', 'x'), None)
     a. dims = [8, 1]
     b. transpose_perm = [1, 0]

  9. PartitionSpec = (None, ('x', 'y'))
     a. dims = [1, 8]
     b. transpose_perm = [0, 1]

  10. PartitionSpec = (None, ('y', 'x'))
     a. dims = [1, 8]
     b. transpose_perm = [1, 0]
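As a cross-check of the ten cases above, here is a minimal sketch of the mapping under stated assumptions: device ids are 0..N-1 laid out row-major over the mesh, reshape_dims is always the mesh shape, and mesh axes the PartitionSpec leaves out are folded into a trailing last_tile_dim_replicate dimension. The function names are hypothetical (not torch_xla or XLA APIs), and the text output is not fully canonicalized. Under these assumptions it reproduces the dims and transpose_perm values of all ten cases, and the ('y', 'x') case expands to the same V1 device order that torch-xla is reported to emit further down the thread.

```python
import numpy as np

def partition_spec_to_v2(mesh_shape, axis_names, spec):
    """Hypothetical helper: V2 (dims, reshape_dims, transpose_perm, replicated)
    for a PartitionSpec, assuming device ids 0..N-1 row-major over the mesh."""
    sizes = dict(zip(axis_names, mesh_shape))
    dims, used = [], []
    for entry in spec:
        # None -> unsharded dim, "x" -> one axis, ("x", "y") -> axes combined major-to-minor.
        axes = [] if entry is None else [entry] if isinstance(entry, str) else list(entry)
        dims.append(int(np.prod([sizes[a] for a in axes], dtype=int)))
        used.extend(axes)
    # Mesh axes not used by the spec become a trailing replicated tile dim.
    unused = [a for a in axis_names if a not in used]
    replicated = int(np.prod([sizes[a] for a in unused], dtype=int))
    if replicated > 1:
        dims.append(replicated)
    # Used axes first (in spec order), then unused axes (in mesh order).
    perm = [list(axis_names).index(a) for a in used + unused]
    return dims, list(mesh_shape), perm, replicated > 1

def _join(values):
    return ",".join(str(v) for v in values)

def format_v2(dims, reshape_dims, perm, replicated):
    """Render in roughly XLA's text form; identity perms collapse to <=[N]."""
    n = int(np.prod(dims))
    if perm == list(range(len(perm))):
        iota = f"<=[{n}]"
    else:
        iota = f"<=[{_join(reshape_dims)}]T({_join(perm)})"
    suffix = " last_tile_dim_replicate" if replicated else ""
    return f"devices=[{_join(dims)}]{iota}{suffix}"

def expand_to_v1(reshape_dims, perm):
    """Explicit (V1) device order implied by the V2 iota form."""
    n = int(np.prod(reshape_dims))
    return np.arange(n).reshape(reshape_dims).transpose(perm).reshape(-1).tolist()

mesh_shape, names = (2, 4), ("x", "y")
dims, reshape_dims, perm, replicated = partition_spec_to_v2(mesh_shape, names, ("y", "x"))
print(dims, perm)                                       # [4, 2] [1, 0]
print(format_v2(dims, reshape_dims, perm, replicated))  # devices=[4,2]<=[2,4]T(1,0)
print(expand_to_v1(reshape_dims, perm))                 # [0, 4, 1, 5, 2, 6, 3, 7]
print(format_v2(*partition_spec_to_v2(mesh_shape, names, ("x", None))))
# devices=[2,1,4]<=[8] last_tile_dim_replicate
```

Keeping reshape_dims equal to the mesh shape makes the transpose carry all the ordering information; XLA may print a more compact but equivalent form.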

hshahTT, Jun 25 '25 13:06

I think dims should always be [2,4], and the rest is reshape and transpose: ('y', 'x') -> [2,4]T(1,0). But I'm not fully sure how the translation from named axes into the numbered device mesh works, or what None does to the axes dims; cc @tomnatan30 for that.

GleasonK, Jun 26 '25 21:06

Hey!

The transpose perm is used to reorder the axes in the mesh so they can be used in any order.

So in the above example 4 you actually need transpose perm [1, 0] because y is used before x:

PartitionSpec = ('y', None)

a. dims = [4, 1, 2] b. transpose_perm = [0, 1]

There are also the reshape dims (which must be specified), which can merge axes that are consecutive in the mesh if they are kept consecutive in the sharding, for example:

Mesh - [x:2, y:2, z:2]

Pspec = (z, (x, y))
  a. dims = [2, 4]
  b. reshape = [4, 2]
  c. transpose = [1, 0]

Finally, this allows you to use sub-axes, e.g. for axis x of size 4 you can use a sub-axis x:(1)2 (offset 1, size 2) or x:(2)2 (offset 2, size 2). Shardy also supports this with the notation "x":(offset)size.

Hope this helps!
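The axis merging described above can be sketched as a small canonicalization step: split the transpose permutation into runs of mesh axes that stay adjacent and in increasing order, then fuse each run into one reshape dim. This is a minimal illustration under that reading, with a hypothetical helper name; it is not XLA's actual canonicalization code. On the (z, (x, y)) example, the fully expanded form reshape [2, 2, 2] with transpose [2, 0, 1] collapses to reshape [4, 2] with transpose [1, 0].

```python
def merge_consecutive_axes(reshape_dims, perm):
    """Hypothetical helper: fuse mesh axes that remain adjacent and in order
    after the transpose into a single reshape dim."""
    # Split perm into runs like [0, 1] where each axis is followed by the next one.
    runs = [[perm[0]]]
    for axis in perm[1:]:
        if axis == runs[-1][-1] + 1:
            runs[-1].append(axis)
        else:
            runs.append([axis])
    # Merged reshape dims follow the original mesh-axis order of the runs.
    ordered = sorted(runs, key=lambda run: run[0])
    new_reshape = []
    for run in ordered:
        size = 1
        for axis in run:
            size *= reshape_dims[axis]
        new_reshape.append(size)
    # The new permutation lists the merged axes in the original transpose order.
    new_perm = [ordered.index(run) for run in runs]
    return new_reshape, new_perm

print(merge_consecutive_axes([2, 2, 2], [2, 0, 1]))  # ([4, 2], [1, 0])
print(merge_consecutive_axes([2, 4], [0, 1]))        # ([8], [0])
print(merge_consecutive_axes([2, 2, 4], [2, 1, 0]))  # ([2, 2, 4], [2, 1, 0])
```

An identity permutation collapses to a single reshape dim, which corresponds to the plain `<=[N]` iota form.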

tomnatan30, Jun 27 '25 06:06

That said, if you already have a mesh with named axes, how hard would it be to go directly to shardy shardings like JAX does? Or is this just for understanding v1 vs v2 hlo sharding?

tomnatan30, Jun 27 '25 06:06

Thanks @tomnatan30!

nit:

So in the above example 4 you actually need transpose perm [1, 0] because y is used before x

In the example provided

4. PartitionSpec = ('y', None)
a. dims = [4, 1, 2]
b. transpose_perm = [1, 0]

already had the suggested [1,0] transpose perm. Do you expect that to be reversed?

sdasgup3, Jun 27 '25 18:06

This is how I see example 4 in this thread, maybe it’s outdated?

PartitionSpec = ('y', None) a. dims = [4, 1, 2] b. transpose_perm = [0, 1]

tomnatan30, Jun 27 '25 18:06

Ah, it seems outdated! nvm then :)

sdasgup3, Jun 27 '25 18:06

Thanks for the responses everyone! Sorry for the confusion regarding example 4. I had edited the comment shortly after I posted it once I realized the typo, but it seems like that update didn't get sent via email.

@GleasonK Re:

I think dims should always be [2,4], and the rest is reshape and transpose ('y', 'x') -> [2,4]T(1,0)

When I run the ('y', 'x') sharding using the latest stable release of torch-xla I get devices=[4,2]0,4,1,5,2,6,3,7 (not [2,4] as I'd expect from your comment). Could you confirm whether this is expected? With my V2 code it becomes devices=[4,2]<=[2,4]T(1,0).

@tomnatan30 Regarding your example:

Mesh - [x:2, y:2, z:2]

Pspec = (z, (x, y)) a. dims = [2, 4] b. reshape = [4, 2] c. Transpose = [1, 0]

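Sorry, could you please explain how you got the values for reshape and transpose? From the documentation mentioned here (https://github.com/jax-ml/jax/blob/84af8a8e74c05ce4196079e145d50f0c9504ff16/jax/_src/named_sharding.py#L414), I would've thought that the correct dims, reshape, and transpose for the given Pspec would be: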

  • dims = [2, 4]
  • reshape = [2, 2, 2]
  • transpose = [2, 0, 1]

but I'm probably missing something. I don't quite understand what you mean by "[the reshape dims] can merge consecutive axes in the mesh if they are kept consecutive"

Also, re:

how hard would it be to go directly to shardy shardings like JAX does?

The approach we're taking to go to shardy shardings is:

  1. Allow the option of using HloShardingV2 from within torch-xla (like how JAX uses V2)
  2. Use this OpenXLA pass to convert the SHLO with V2 shardings to Shardy.

For (1) I have this implementation that gets me a V2 sharding without any errors for all the above examples, except 8 and 10: https://github.com/pytorch/xla/commit/af5cd2dca33bea5395a675d0a062e69453c38e26#diff-76bd84e4abe22701ee8697bf77e9fc97e19b6d6ff05175f2dc87f938f3a88837

The branch is: https://github.com/tenstorrent/pytorch-xla/tree/hshah/opsharding-v2

For (2) I have this implementation that runs the aforementioned pass to convert the SHLO to Shardy: https://github.com/pytorch/xla/commit/ee0f1fbd4095631281f4e6cf8e570bbe498a3ad8

hshahTT, Jun 30 '25 12:06

Hey!

First of all, there are multiple ways of representing certain shardings in V2, and sometimes you can canonicalize into a more compact version: when you have consecutive dims in the transpose (0, 1), those can be merged into one.

Your version can be canonicalized into:

  • dims = [2, 4]
  • reshape = [2, 4]
  • transpose = [1, 0]

But you are right that this is correct and mine had a mistake; the transpose tells you the order in which mesh axes are placed for the sharding (dims).
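Because several (reshape, transpose) pairs can describe the same tile assignment, it helps to expand each candidate and compare. The sketch below assumes row-major device ids over the mesh [x:2, y:2, z:2] (device id = 4*x + 2*y + z), builds the (z, (x, y)) tile assignment directly from mesh coordinates, and expands three candidate pairs with arange().reshape().transpose(). Under these assumptions, reshape [2, 2, 2] with transpose [2, 0, 1] and reshape [4, 2] with transpose [1, 0] both match the direct construction, while reshape [2, 4] with transpose [1, 0] produces a different device order, so any canonical form is worth double-checking this way.

```python
import numpy as np

def expand(dims, reshape_dims, perm):
    """Device ids laid out per a V2 iota tile assignment."""
    n = int(np.prod(dims))
    return np.arange(n).reshape(reshape_dims).transpose(perm).reshape(dims)

# Mesh [x:2, y:2, z:2]; assumed device id = 4*x + 2*y + z (row-major).
mesh = np.arange(8).reshape(2, 2, 2)

# Direct construction for PartitionSpec (z, (x, y)):
# tensor dim 0 is split by z, dim 1 by x (major) then y (minor).
direct = np.empty((2, 4), dtype=int)
for x in range(2):
    for y in range(2):
        for z in range(2):
            direct[z, 2 * x + y] = mesh[x, y, z]

candidates = {
    "reshape=[2,2,2] T(2,0,1)": expand([2, 4], [2, 2, 2], [2, 0, 1]),
    "reshape=[4,2] T(1,0)": expand([2, 4], [4, 2], [1, 0]),
    "reshape=[2,4] T(1,0)": expand([2, 4], [2, 4], [1, 0]),
}
for name, tiles in candidates.items():
    print(name, tiles.ravel().tolist(), np.array_equal(tiles, direct))
# reshape=[2,2,2] T(2,0,1) [0, 2, 4, 6, 1, 3, 5, 7] True
# reshape=[4,2] T(1,0) [0, 2, 4, 6, 1, 3, 5, 7] True
# reshape=[2,4] T(1,0) [0, 4, 1, 5, 2, 6, 3, 7] False
```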

tomnatan30, Jun 30 '25 12:06

Hey everyone, I added a PR for shardy support: https://github.com/pytorch/xla/pull/9541. Could folks please look at it for correctness when you get a chance? Want to make sure I haven't missed any obvious cases. So far it has worked with the sharding specs I tested with on Tenstorrent hardware, and we were able to get a tensor parallel llama3.1-8B model working with it in SPMD mode.

hshahTT, Aug 05 '25 14:08

Hey all, could I get some eyes on the above PR please: https://github.com/pytorch/xla/pull/9541?

hshahTT, Aug 14 '25 02:08

V2 HloSharding represents the device list compactly. The relation between V1 and V2 is:

numpy.arange(num_devices).reshape(reshape_dims).transpose(transpose_perm).reshape(num_devices)

For example, the following two are equivalent.

V1: devices=[2,2,4]0,8,4,12,1,9,5,13,2,10,6,14,3,11,7,15
V2: devices=[2,2,4]<=[2,2,4]T(2,1,0)
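Here is a minimal sketch of that relation applied to the textual form. It assumes only the `devices=[dims]<=[reshape_dims]T(perm)` syntax shown above (the `T(...)` part optional). It is a hypothetical parser for this subset, not XLA's HloSharding parser. It expands the V2 example back into the V1 list, and does the same for the ('y', 'x') sharding discussed earlier in the thread.

```python
import re
import numpy as np

# devices=[dims]<=[reshape_dims] with an optional T(perm) suffix.
V2_RE = re.compile(r"devices=\[([\d,]+)\]<=\[([\d,]+)\](?:T\(([\d,]+)\))?")

def v2_to_v1(text):
    """Hypothetical helper: expand a V2 iota tile assignment into V1 text."""
    m = V2_RE.fullmatch(text)
    if m is None:
        raise ValueError(f"not a V2 tile assignment: {text!r}")
    dims = [int(d) for d in m.group(1).split(",")]
    reshape_dims = [int(d) for d in m.group(2).split(",")]
    if m.group(3):
        perm = [int(p) for p in m.group(3).split(",")]
    else:
        perm = list(range(len(reshape_dims)))  # no T(...) means identity order
    n = int(np.prod(dims))
    ids = np.arange(n).reshape(reshape_dims).transpose(perm).reshape(-1)
    return "devices=[{}]{}".format(m.group(1), ",".join(str(i) for i in ids))

print(v2_to_v1("devices=[2,2,4]<=[2,2,4]T(2,1,0)"))
# devices=[2,2,4]0,8,4,12,1,9,5,13,2,10,6,14,3,11,7,15
print(v2_to_v1("devices=[4,2]<=[2,4]T(1,0)"))
# devices=[4,2]0,4,1,5,2,6,3,7
```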

ZixuanJiang, Aug 15 '25 22:08