Iurii Kemaev

14 comments by Iurii Kemaev

Hey, thanks for the question - yes, it can be added! Would you like to send a PR?

Hi @fvisin, thanks for opening this issue. And thanks for volunteering @broper2! `static_argnames` kwarg is not supported in the original [pmap](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) because of `in_axes` incompatibility with kwargs (see the docstring)....

That's a good point! We can prohibit users from using *both* `static_argnames` and `static_argnums` at the same time in variants' args, wdyt?
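A minimal sketch of what such a check could look like. The helper names `_as_tuple` and `check_static_args` are hypothetical and are not part of chex or JAX; the only assumption taken from JAX is that each argument may be given as a single value or as an iterable.

```python
from typing import Iterable, Optional, Tuple, Union

IntOrInts = Union[int, Iterable[int]]
StrOrStrs = Union[str, Iterable[str]]


def _as_tuple(value) -> tuple:
    """Normalises None, a single int/str, or an iterable to a tuple."""
    if value is None:
        return ()
    if isinstance(value, (int, str)):
        return (value,)
    return tuple(value)


def check_static_args(
    static_argnums: Optional[IntOrInts] = None,
    static_argnames: Optional[StrOrStrs] = None,
) -> Tuple[tuple, tuple]:
    """Hypothetical validator: rejects using both static_* options together."""
    nums = _as_tuple(static_argnums)
    names = _as_tuple(static_argnames)
    # Normalising first means `static_argnums=0` is not mistaken for "unset".
    if nums and names:
        raise ValueError(
            "Pass either `static_argnums` or `static_argnames`, not both."
        )
    return nums, names
```

Normalising before the check matters because `static_argnums=0` is a legitimate value that a plain truthiness test would treat as missing.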

Ah I see, it all makes sense now, TY! @broper2 are you interested in preparing the fix?

NP, thanks for letting me know @broper2!

Unfortunately, this change breaks multiple internal tests because, as you noticed, pytype doesn't properly support this feature :( When I suppress this warning, pytype generates a bunch of other errors...

Hi @wookayin, thanks for reporting this issue. @thomkeh that would be very helpful! I'm happy to review your PR.

Hi everyone, thanks for flagging it up. I just merged a new version of `optax.MultiSteps`, which should be more memory-friendly. Could you check it, please?
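For readers unfamiliar with gradient accumulation, here is a conceptual sketch in plain Python. It is not the `optax.MultiSteps` implementation and does not describe what the merged change did; the class `GradAccumulator` and its fields are hypothetical. It only illustrates the general idea of keeping a single running-mean buffer across `k` mini-steps rather than storing every mini-step's gradients.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class GradAccumulator:
    """Hypothetical helper: averages gradients over `k` mini-steps."""

    k: int
    _mean: List[float] = field(default_factory=list)
    _count: int = 0

    def update(self, grads: List[float]) -> Optional[List[float]]:
        """Returns the averaged gradients every k-th call, otherwise None."""
        if not self._mean:
            self._mean = [0.0] * len(grads)
        self._count += 1
        # Incremental mean: only one buffer the size of the gradients is kept.
        self._mean = [
            m + (g - m) / self._count for m, g in zip(self._mean, grads)
        ]
        if self._count < self.k:
            return None
        averaged = self._mean
        # Reset the state for the next accumulation window.
        self._mean = []
        self._count = 0
        return averaged
```

With `k=2`, the first call returns `None` and the second returns the element-wise mean of the two gradient lists.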

Awesome work @celiolarcher! `jax.lax.cond` seems to be suboptimal in some use cases, e.g. here, in theory, it should understand that either `_mid_step` or `_final_step` needs to be executed, so it...

@mcleish7 it should be fixed in https://github.com/deepmind/clrs/commit/2b37ff3f6d56b2e2e43806b7c5635282e888f505