Sergei Lebedev
Hey @wenscarl, sorry for the silence. I'm a bit swamped this week, but I will try to look through the PR over the weekend or first thing next week.
Thanks! Could you squash the commits, please?
Thanks for the contribution! I think #20293 implied a slightly different idea, where profiling is enabled via an environment variable. The version in this PR might work fine for users...
Yep, something like this.
Hi @MoFHeka, could you clarify how your question relates to JAX? JAX does not have a concept of a rank (other than array rank), so `jit` probably cannot do the...
I will let Yash comment on dynamic shapes in `shard_map`, but I suspect that this use case is not supported (since most of JAX assumes static shapes atm). Re "transform the...
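The static-shape assumption is easy to see with plain `jax.jit`. The sketch below does not use `shard_map` and is not a statement about its behavior. It shows an op whose output shape depends on input *values* failing to trace, and the common workaround of giving it a static upper bound via `size=`. The function names are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp

def positive_indices(x):
    # Output length depends on the *values* of x, i.e. a dynamic shape.
    return jnp.nonzero(x > 0)[0]

@jax.jit
def positive_indices_static(x):
    # Static upper bound on the output size (x.shape is static under jit);
    # unused slots are padded with fill_value.
    return jnp.nonzero(x > 0, size=x.shape[0], fill_value=-1)[0]

x = jnp.array([1.0, -2.0, 3.0, -4.0])
print(positive_indices(x))           # eager mode: fine, result has shape (2,)

try:
    jax.jit(positive_indices)(x)     # tracing cannot determine the output shape
except jax.errors.ConcretizationTypeError as e:
    print("jit failed:", type(e).__name__)

print(positive_indices_static(x))    # padded, fixed-shape result
```

The padded version trades some wasted slots for a shape the compiler knows ahead of time.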
I will defer the answer to Yash, as I'm not yet very familiar with sharding APIs in JAX.
Thanks @Lime-Cakes, this looks like a bug indeed.
I think you probably want to write a backwards kernel anyway, because the automatically derived one (even if we fixed the assertion) is guaranteed to be inefficient.
I don't think we test the lowering logic in x64 and at least on GPU there are known bugs with x64 support. So, I would maybe delay this until we...