Sergei Lebedev
Note that mypy is unhappy in CI.

> FWIW, `KeyboardFlags` could be out-sourced to a keyboards.py or something alike, if it feels too misplaced in screens.py.

Yeah, this sgtm.
An unfortunate issue with mypy is that, by default, it does not type-check the bodies of functions that have no type annotations at all (unless you pass `--check-untyped-defs`). I guess you could use `object`...
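A minimal sketch of that default behavior (the function names are hypothetical): mypy skips the body of a completely unannotated function, whereas annotating the parameter as `object` makes mypy check the body, so the code has to narrow the type before using it.

```python
from collections.abc import Sized


def size_untyped(value):
    # No annotations at all: with default settings mypy does not check
    # this body, so an unsafe call like len(value) goes unreported.
    return len(value)


def size_typed(value: object) -> int:
    # Annotated with `object`: mypy checks this body, and calling
    # len(value) directly would be an error (object is not Sized),
    # so narrow with isinstance first.
    if isinstance(value, Sized):
        return len(value)
    return 0
```

Since every value is an `object`, the annotation doesn't restrict callers; it only opts the body into checking.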
Closing. The fix should be available in the latest jaxlib nightly: https://jax.readthedocs.io/en/latest/installation.html#jax-nightly-installation.
It looks like `pl.dot` uses `preferred_element_type=jnp.float32`. However, that doesn't explain the error, and AFAICT there are no casts to fp16 anywhere in the TTIR generated by Pallas.
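A minimal illustration of what `preferred_element_type=jnp.float32` means, using `jnp.dot` rather than `pl.dot` (an assumption: this shows the argument's semantics, not how Pallas lowers it to TTIR). With fp16 operands, the result comes back as float32 instead of the default fp16.

```python
import jax.numpy as jnp

# fp16 operands; shapes are arbitrary and chosen only for illustration.
a = jnp.ones((4, 8), dtype=jnp.float16)
b = jnp.ones((8, 4), dtype=jnp.float16)

# Ask for float32 results instead of the default fp16 output dtype.
out = jnp.dot(a, b, preferred_element_type=jnp.float32)

# Without the argument, the output dtype follows the fp16 inputs.
default_out = jnp.dot(a, b)
```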
I'm not familiar with Slurm, but it looks like this is an environment issue. Did you follow the installation instructions in https://jax.readthedocs.io/en/latest/installation.html?
@Pierre-Sassoulas our experience with running pylint as a daemon suggests that leaks will definitely be a problem, given that any leaking node retains the full AST, so memory consumption get...
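A minimal sketch of why a single leaked node pins the whole tree, assuming (as in astroid) that each node keeps a reference to its parent and parents keep their children. The `Node` class and `build_tree` helper are hypothetical, not pylint or astroid code.

```python
import gc
import weakref


class Node:
    """Hypothetical AST node with a parent back-reference (astroid-style)."""

    def __init__(self, parent=None):
        self.parent = parent
        self.children = []
        if parent is not None:
            parent.children.append(self)


def build_tree(depth, fanout):
    # Build a complete tree and return its root plus one leaf.
    root = Node()
    frontier = [root]
    for _ in range(depth):
        frontier = [Node(p) for p in frontier for _ in range(fanout)]
    return root, frontier[0]


def count_nodes(node):
    return 1 + sum(count_nodes(child) for child in node.children)


root, leaf = build_tree(depth=3, fanout=4)
root_ref = weakref.ref(root)
del root
gc.collect()

# Only the leaf is still referenced, yet the parent chain keeps the root,
# and through `children` every other node, alive.
still_alive = root_ref() is not None
reachable = count_nodes(root_ref())  # 1 + 4 + 16 + 64 nodes

# Dropping the last external reference lets the cyclic GC free the tree.
del leaf
gc.collect()
freed = root_ref() is None
```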
I think the minimum cuDNN version JAX supports is 9.0. Does it make sense to remove mask entirely from the JAX API, given that? CC @hawkinsp
Please fix the type checker errors.
Thanks, you just need to squash the commits and the PR is good to go.