Jake Vanderplas

527 comments by Jake Vanderplas

Hi - JAX does not currently do much with static typing, beyond some scattered uses of `jnp.ndarray` and frequent annotations with aliases like `Array = Any`. We're currently exploring doing...
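Here is a minimal sketch of the `Array = Any` alias pattern mentioned above. The function `scale` is hypothetical. Because the alias is literally `Any`, a static type checker accepts whatever is passed, so the annotation documents intent but enforces nothing.

```python
from typing import Any

# Alias pattern: to a static type checker, `Array` is just `Any`,
# so these annotations document intent but enforce nothing.
Array = Any

def scale(x: Array, factor: float) -> Array:
    """Hypothetical helper: multiply an array-like value by a scalar."""
    return x * factor

# A type checker accepts both calls, even though only the first
# passes something numeric.
print(scale(3.0, 2.0))
print(scale("ab", 2))  # string repetition; still type-checks under Any
```

Later JAX releases added `jax.Array`, which can serve as a real annotation in place of such an alias.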

This is expected behavior; see the description from the CHANGELOG entry when `pickle` support was added: https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0314-june-27-2022 The issue is that pickling and unpickling need not happen in the same...

Unpickling of an array [uses `jax.device_put`](https://github.com/google/jax/blob/3243e23aa528db390fdece3fa32517c28a50318b/jax/_src/device_array.py#L326) with no device argument, so I believe the default device context manager should do the right thing.
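Here is a minimal sketch of the pattern described above. It assumes a JAX version that can pickle `jax.Array` values and that provides the `jax.default_device` context manager. The array is pickled and then unpickled inside the context, so the device-less `device_put` that runs during unpickling has a chance to pick up the requested default. The CPU device is used only because it always exists. Whether the context is actually honored is what the next comment questions, so the sketch inspects the placement rather than assuming it.

```python
import pickle

import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
payload = pickle.dumps(x)  # serialize the array

target = jax.devices("cpu")[0]  # any device would do; CPU always exists

# Unpickle inside the default-device context.
with jax.default_device(target):
    y = pickle.loads(payload)

print(y)            # same values as x
print(y.devices())  # inspect where the array actually landed
```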

Reopening and assigning to @skye – it looks like `device_put` does not respect the default device context. Is this intended?
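One way to check the question directly is to call `jax.device_put` with no device argument inside the context and inspect the result. The helper `lands_on` is hypothetical. On a CPU-only machine the check is trivially satisfied, so it is only informative when more than one device is visible. Passing the device explicitly avoids depending on the context at all.

```python
import jax
import numpy as np

def lands_on(device):
    """Hypothetical check: does device_put with no device honor the context?"""
    with jax.default_device(device):
        arr = jax.device_put(np.arange(3))
    return arr.devices() == {device}

# Try every visible device (only informative with 2+ devices).
for dev in jax.devices():
    print(dev, lands_on(dev))

# Explicit placement does not depend on the default-device context.
explicit = jax.device_put(np.arange(3), jax.devices()[-1])
print(explicit.devices())
```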

Thanks - echoing our conversation from elsewhere: this is a failure we've seen periodically before; it's a strange one that is hard to reproduce. It almost looks like there's some...

OK, I was able to reproduce it... it has something to do with jit caching, I think:

```python
from functools import partial
from jax import jit

class CustomClass:
    def __init__(self,...
```
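The reproducer above is truncated, so the following is not a reconstruction of it. It is a minimal sketch of how `jit` caches on static arguments, using a hypothetical class `Config` and function `apply`. The compiled function is reused whenever a static argument hashes and compares equal to one seen before. This is why a custom class used this way needs consistent `__hash__` and `__eq__`.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

trace_count = 0

class Config:
    """Hypothetical static argument (not the CustomClass from the snippet)."""

    def __init__(self, scale):
        self.scale = scale

    def __hash__(self):
        return hash(self.scale)

    def __eq__(self, other):
        return isinstance(other, Config) and self.scale == other.scale

@partial(jit, static_argnums=1)
def apply(x, cfg):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return x * cfg.scale

x = jnp.arange(3.0)
apply(x, Config(2.0))  # first call: trace and compile
apply(x, Config(2.0))  # equal and same hash: cache hit, no retrace
apply(x, Config(3.0))  # new static value: retrace
print(trace_count)     # 2
```

If `__hash__` or `__eq__` were inconsistent (for example, equal objects with different hashes), the cache would miss or, worse, reuse a stale compiled function.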

(taking over review/merge, since this is related to #12049)

Hmm, since this is an XLA caching issue, I don't think our benchmark framework, as currently written, can catch it: after the first execution, subsequent executions will be fast.
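Here is a minimal sketch of why a repeated-call benchmark misses this. The function `f` and the array sizes are arbitrary choices for illustration. The first call to a jitted function includes tracing and compilation. Later calls reuse the cached executable, so a loop that times many calls mostly measures the fast path.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((200, 200))

def timed_call():
    start = time.perf_counter()
    f(x).block_until_ready()  # wait for asynchronous dispatch to finish
    return time.perf_counter() - start

first = timed_call()                      # includes tracing + compilation
later = [timed_call() for _ in range(5)]  # reuse the cached executable

print(f"first call: {first:.4f}s, later calls: {min(later):.4f}s")
```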

I'm not following – does `state.__iter__` do some sort of XLA-level cache clear?

I don't see any easy way to benchmark this in our benchmark framework, because (if my understanding is correct) it would require clearing XLA's internal cache on each iteration, and...
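To time compilation rather than cached execution, one option is to clear JAX's caches before every timed call. This assumes a newer JAX release than this thread, since `jax.clear_caches` postdates it. That function clears JAX's own compilation and staging caches. Whether that matches the "XLA's internal cache" referred to above is not something this sketch establishes. The name `compile_and_run_times` is hypothetical.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.tanh(x) * 2.0

x = jnp.ones((128, 128))

def compile_and_run_times(n_iters=3):
    """Hypothetical benchmark loop that pays compilation on every iteration."""
    times = []
    for _ in range(n_iters):
        jax.clear_caches()  # drop JAX's compilation/staging caches
        start = time.perf_counter()
        f(x).block_until_ready()  # include the actual computation in the timing
        times.append(time.perf_counter() - start)
    return times

print(compile_and_run_times())
```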