George Necula

Results 8 issues of George Necula

This implementation covers the case `jax2tf.convert(pjit(f_jax))`, that is, where the `pjit` appears at the top level of the function to be lowered.

pull ready

Several forms of loops in JAX support reverse-mode AD: `scan`, and `fori_loop` with constant bounds, which is syntactic sugar for `scan`. I think it could be useful to have another syntactic...

enhancement
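
To illustrate the loop issue above, here is a minimal sketch of a `fori_loop` with constant bounds rewritten as a `scan`, with reverse-mode AD applied to both. The names `body`, `with_fori`, and `with_scan` are hypothetical and do not come from the issue.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(i, x):
    # One loop step; cast the index so the carry keeps its dtype.
    return x * 1.5 + i.astype(x.dtype)


def with_fori(x):
    # Constant (Python int) bounds, so reverse-mode AD is supported.
    return lax.fori_loop(0, 4, body, x)


def with_scan(x):
    # The same loop written directly as a scan over the index range.
    def step(carry, i):
        return body(i, carry), None

    out, _ = lax.scan(step, x, jnp.arange(4))
    return out


x0 = jnp.float32(2.0)
grad_fori = jax.grad(with_fori)(x0)
grad_scan = jax.grad(with_scan)(x0)
```

Each step multiplies the carry by 1.5, so both gradients should equal 1.5 raised to the fourth power.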

`jax.experimental.export` exports the VJP with a synthetic mesh built from the first N devices of the export platform. This seemed reasonable because all that is captured in the...

bug

We make the following improvements:
* pytree structural disequality messages attempt to localize the mismatch
* we check that the rank of the block_shape matches the rank of the overall...

pull ready

This allows one to run most of export_test even if flatbuffers is not installed. Only the serialization and deserialization are skipped.

pull ready
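
As a rough sketch of the idea in the pull request above, tests that need an optional dependency can be skipped while the rest of the suite still runs. This is not the code from the pull request: the class `ExportLikeTest`, its test methods, and `run_suite` are hypothetical, and the skip is done here with the standard `unittest` decorators.

```python
import importlib.util
import unittest

# True only when the optional dependency can be imported.
HAS_FLATBUFFERS = importlib.util.find_spec("flatbuffers") is not None


class ExportLikeTest(unittest.TestCase):
    def test_without_serialization(self):
        # Stands in for tests that do not need flatbuffers.
        self.assertEqual(sum([1, 2, 3]), 6)

    @unittest.skipUnless(HAS_FLATBUFFERS, "flatbuffers is not installed")
    def test_serialization_roundtrip(self):
        # Stands in for the serialization and deserialization tests.
        import flatbuffers

        self.assertTrue(hasattr(flatbuffers, "Builder"))


def run_suite():
    # Load and run the test case, returning the unittest result object.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExportLikeTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Calling `run_suite()` reports the serialization test as skipped, not failed, when `flatbuffers` is missing.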

It turns out that a limited form of shape polymorphism is already supported for Pallas calls on TPU: the block sizes must be static, but the input and the grid...

pull ready

We marked the `host_callback` APIs as deprecated on March 21, 2024. They will be removed in October 2024. Users should instead use the new JAX [external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). # Quick temporary...

enhancement
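
For the `host_callback` deprecation above, a minimal sketch of one of the external callback APIs, `jax.pure_callback`, might look like the following. The function names `host_sin` and `f` are hypothetical, and this is not presented as the migration recipe from the issue.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Runs on the host with NumPy; must be free of side effects.
    return np.sin(np.asarray(x))


@jax.jit
def f(x):
    # Declare the shape and dtype of the callback result up front.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_spec, x)


y = f(jnp.arange(3, dtype=jnp.float32))
```

The result specification is needed because JAX cannot trace into the host function to infer the output type.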

See the added documentation for `jax._src.export.register_pytree_node_serialization` and `jax._src.export.register_namedtuple_serialization`. Serialization of PyTree nodes is needed to serialize the `in_tree` and `out_tree` of an `Exported` function (not to serialize actual instances of the...

pull ready