Jake Vanderplas
Hi - JAX does not currently do much with static typing, beyond some scattered uses of `jnp.ndarray` and frequent annotations with aliases like `Array = Any`. We're currently exploring doing...
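To illustrate what an alias like `Array = Any` does and does not provide, here is a minimal sketch using only the standard library. The `scale` helper is hypothetical and not part of JAX; it only shows that such an annotation documents intent without constraining what a type checker or the runtime accepts.

```python
from typing import Any, get_type_hints

# Alias in the style the comment describes; it documents intent only.
Array = Any


def scale(x: Array, factor: float) -> Array:
    """Multiply every element of x by factor (hypothetical helper)."""
    return [value * factor for value in x]


# The annotation resolves to plain `Any`, so no checking is implied.
hints = get_type_hints(scale)
print(hints["x"] is Any)

# A list of floats works as intended...
print(scale([1.0, 2.0], 3.0))

# ...but so does a string, and a static checker would not complain either.
print(scale("ab", 2))
```

Because the alias is just `Any`, the annotation reads well but catches nothing, which matches the "scattered uses" situation described above.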
This is expected behavior; see the description from the CHANGELOG entry when `pickle` support was added: https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0314-june-27-2022 The issue is that pickling and unpickling need not happen in the same...
Unpickling of an array [uses `jax.device_put`](https://github.com/google/jax/blob/3243e23aa528db390fdece3fa32517c28a50318b/jax/_src/device_array.py#L326) with no device argument, so I believe the default device context manager should do the right thing.
Reopening and assigning to @skye – it looks like `device_put` does not respect the default device context. Is this intended?
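A minimal sketch of the scenario under discussion, assuming a JAX version with array `pickle` support and the `jax.default_device` context manager: pickle an array, then unpickle it inside the context. The variable names are illustrative. Which device the restored array ends up on is exactly what is in question here, so the sketch checks only the values.

```python
import pickle

import jax
import jax.numpy as jnp

# Create an array on whatever device JAX picks by default.
x = jnp.arange(4.0)
payload = pickle.dumps(x)

# Unpickle while a default device is set. The CPU backend is used here
# because it is available on every installation.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    restored = pickle.loads(payload)

# The values round-trip regardless of where the array is placed.
print(restored.tolist())
```

On a machine with an accelerator, comparing the placement of `restored` with and without the context manager would show whether the context is respected in a given JAX version.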
Thanks - echoing our conversation from elsewhere: this is a failure we've seen periodically before; it's a strange one that is hard to produce. It almost looks like there's some...
OK, I was able to reproduce it... it has something to do with jit caching, I think:

```python
from functools import partial
from jax import jit

class CustomClass:
  def __init__(self,...
```
(taking over review/merge, since this is related to #12049)
Hmm, since this is an XLA caching thing, I don't think our benchmark framework as currently written can catch it, because after the first execution subsequent executions will be fast.
I'm not following – does `state.__iter__` do some sort of XLA-level cache clear?
I don't see any easy way to benchmark this in our benchmark framework, because (if my understanding is correct) it would require clearing XLA's internal cache on each iteration, and...
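The benchmarking difficulty can be sketched with a standard-library analogy. This is not XLA's cache or the JAX benchmark framework: `compile_once` is a hypothetical stand-in whose `time.sleep` call plays the role of an expensive compilation, and `functools.lru_cache` plays the role of the compilation cache.

```python
import time
from functools import lru_cache


@lru_cache(maxsize=None)
def compile_once(key: str):
    """Stand-in for an expensive compilation step, cached by key."""
    time.sleep(0.05)  # simulated compile cost
    return lambda x: x + 1


def timed_call(key: str, x: int):
    """Return the result and the wall-clock time of one call."""
    start = time.perf_counter()
    result = compile_once(key)(x)
    return result, time.perf_counter() - start


# The first call pays the simulated compile cost.
_, first = timed_call("f", 1)

# Later iterations hit the cache, so a benchmark loop sees only these.
warm = [timed_call("f", 1)[1] for _ in range(5)]

# Measuring the slow path again requires clearing the cache first.
compile_once.cache_clear()
_, after_clear = timed_call("f", 1)

print(f"first={first:.4f}s warm_max={max(warm):.6f}s after_clear={after_clear:.4f}s")
```

In this analogy, a loop that averages many iterations reports the warm timings, so a regression confined to the first call stays invisible unless the cache is cleared on each iteration.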