Sergei Lebedev

228 comments by Sergei Lebedev

Hey @wenscarl, sorry for the silence. I'm a bit swamped this week, but I will try to look through the PR over the weekend or first thing next week.

Thanks, can you squash the commits please?

Thanks for the contribution! I think #20293 implied a slightly different idea, where profiling is enabled via an environment variable. The version in this PR might work fine for users...

Hi @MoFHeka, could you clarify how your question relates to JAX? JAX does not have a concept of a rank (other than array rank), so `jit` probably cannot do the...

I will let Yash comment on dynamic shapes in `shard_map`, but I suspect that this use-case is not supported (since most of JAX assumes static shapes atm). Re "transform the...

I will defer the answer to Yash, as I'm not yet very familiar with sharding APIs in JAX.

Thanks @Lime-Cakes, this looks like a bug indeed.

I think you probably want to write a backwards kernel anyway, because the automatically derived one (even if we fixed the assertion) is guaranteed to be inefficient.
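For readers unfamiliar with the pattern, here is a minimal sketch of pairing a forward computation with a hand-written backward pass in JAX. It assumes the two kernels are wired together through `jax.custom_vjp`. The function names are hypothetical, and plain `jax.numpy` ops stand in for the real forward and backward kernels, which this comment does not show.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scaled_square(x, scale):
    # Hypothetical forward "kernel": elementwise scale * x**2.
    return scale * x * x


def scaled_square_fwd(x, scale):
    # Forward pass: return the output plus the residuals the backward pass needs.
    return scaled_square(x, scale), (x, scale)


def scaled_square_bwd(residuals, g):
    # Hand-written backward "kernel" (stand-in): computes the cotangents directly
    # instead of letting JAX differentiate through the forward implementation.
    x, scale = residuals
    dx = g * 2.0 * scale * x          # d/dx (scale * x**2) = 2 * scale * x
    dscale = jnp.sum(g * x * x)       # d/dscale = x**2, reduced to scale's shape ()
    return dx, dscale


scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)


def loss(x, scale):
    return scaled_square(x, scale).sum()


x = jnp.array([1.0, 2.0])
scale = jnp.array(3.0)
grad_x, grad_scale = jax.grad(loss, argnums=(0, 1))(x, scale)
```

With a real kernel, the backward pass can be scheduled and fused independently of the forward pass, which automatic differentiation of the forward kernel does not do.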

I don't think we test the lowering logic in x64 and at least on GPU there are known bugs with x64 support. So, I would maybe delay this until we...