🔪 Remaining Sharp Bit TODOs 🔪
We could do with sprucing up the Sharp Bits doc with common problems we've encountered in user code since it was first written.
Top of the list is documenting matmul / conv op precision issues:
- bf16 multiplication defaults! bad for simulation / classic numerics.
- context manager for precision
We should add some other ideas here.
Just adding context: the precision context manager is defined here, and there are some words about it in #6143.
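For the doc, a minimal sketch of both knobs might look like this (assuming the per-op `precision=` argument, and that the context manager is exposed as `jax.default_matmul_precision`):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.float32)

# On TPU, float32 matmuls default to fast bfloat16 multiplication unless
# higher precision is requested explicitly per-op...
y = jnp.dot(x, x, precision=jax.lax.Precision.HIGHEST)

# ...or for a whole region of code via the context manager:
with jax.default_matmul_precision("float32"):
    y = jnp.dot(x, x)
```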
Other sharp bits:
- OOB accesses don't raise errors by default and instead silently clip or drop! This is already in there actually, but extend it a bit and mention the `mode` argument for the `.at[]` syntax (and add a link to `checkify`) - see the sketch after this list.
- Can't cover it in any detail, and it's not a JAX issue per se, but probably worth mentioning the general dangers of half-precision types: e.g. ease of `float16` overflow/underflow and the danger of accumulating into `bf16` (also sketched below).
- accidental recompilation issues:
  - hashability of arguments / jit caching behavior
  - log-compiles feature for catching accidental recompiles.
  - worth a line: the danger of `weak_type=True` for triggering recompilation (sketched below)
- Perhaps too esoteric / TPU-centric: we should mention the RngBitGenerator system for performance, maybe `xla_tpu_spmd_rng_bit_generator_unsafe=true` for OSS users of SPMD+RNG.
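On the OOB bullet, a sketch of the current behavior and the `mode` argument, with `checkify` turning the silent clamp into a real error (function names here are just for illustration):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

x = jnp.arange(4.0)

# With a traced index, OOB reads don't raise: they clamp to the
# nearest valid index.
@jax.jit
def read(x, i):
    return x[i]

print(read(x, 10))  # 3.0, silently clipped to the last element

# The .at[] syntax takes a mode argument controlling OOB behavior:
@jax.jit
def read_fill(x, i):
    return x.at[i].get(mode="fill", fill_value=jnp.nan)

print(read_fill(x, 10))  # nan instead of a clamped read

# OOB writes are silently dropped by default; mode makes that explicit:
@jax.jit
def write_drop(x, i):
    return x.at[i].set(99.0, mode="drop")

print(write_drop(x, 10))  # x unchanged

# checkify can surface the OOB access as a checked error instead:
checked_read = checkify.checkify(lambda x, i: x[i],
                                 errors=checkify.index_checks)
err, _ = checked_read(x, 10)
err.throw()  # raises, pointing at the out-of-bounds access
```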
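For the half-precision bullet, the kind of thing worth showing (float16's narrow range, and requesting a wider accumulator for reductions):

```python
import jax.numpy as jnp

# float16 overflows easily: its max finite value is ~65504.
a = jnp.array(300.0, dtype=jnp.float16)
print(a * a)  # inf

# bfloat16 has float32's range but only ~8 bits of significand, so long
# reductions lose precision; one mitigation is an explicitly wider
# accumulator dtype:
x = jnp.ones(100_000, dtype=jnp.bfloat16)
print(jnp.sum(x, dtype=jnp.float32))  # accumulate in float32
```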
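And for the recompilation sub-bullets, a small sketch of `jax_log_compiles` catching a weak-type-induced retrace:

```python
import jax
import jax.numpy as jnp

# Log every compilation so accidental recompiles are visible:
jax.config.update("jax_log_compiles", True)

@jax.jit
def f(x):
    return x * 2

f(1.0)               # traced with a weakly-typed float32 scalar
f(jnp.float32(1.0))  # strongly-typed float32: a different abstract
                     # value, so this logs a second compilation
```

The same mechanism catches the hashability/caching cases, e.g. each distinct value of a `static_argnums` argument compiling a new specialization.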
Fantastic!
On that first bullet, we could also mention `checkify`
cc @LenaMartens
Nice, +1 on "the danger of `weak_type=True` for triggering recompilation" - we've had people ask for more documentation on that. I might try and add that.
Also, making sure that the async JAX dispatch in a training loop isn't interrupted by blocking host-side calls that kill dispatch pipelining efficiency (e.g. a trivial host-side metrics fn or similar) - one of the most common performance mistakes I see (maybe belongs in a separate performance-gotchas doc... not sure)
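Something like this, with a stand-in `train_step`, is the shape of the mistake and the easy fix:

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # stand-in for a real update rule, returning new params and a loss
    loss = jnp.mean((params - batch) ** 2)
    return params - 0.1 * (params - batch), loss

params = jnp.zeros(1000)
batches = [jnp.full(1000, float(i)) for i in range(300)]

losses = []
for step, batch in enumerate(batches):
    params, loss = train_step(params, batch)  # dispatched asynchronously

    # Anti-pattern: float(loss) blocks the host until this step finishes,
    # serializing dispatch with execution on every iteration:
    #   print(step, float(loss))

    # Better: keep values on-device and only force a sync occasionally.
    losses.append(loss)
    if step % 100 == 0:
        print(step, float(jnp.stack(losses).mean()))  # one sync per 100 steps
        losses.clear()
```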
I like the idea of having a new dedicated doc for performance tips and pitfalls
Regarding reworking the Sharp Bits doc, I recently added a section on miscellaneous divergences between numpy and JAX. It might be nice to rearrange things so all the differences between numpy and JAX are listed briefly under a single heading, perhaps with links to deeper discussion later in the doc.
Regarding the "jit caching behavior", is there any chance you could cache the compiled result to the file system so that it can persist across runs? In my development cycle, I typically change some hyperparameters and re-run the experiment. It's a little frustrating that each time I have to wait for the JIT compilation, even if I have compiled the exact same code multiple times.
I am under the impression that this won't be too hard to implement, since we already have a hashing/caching mechanism. All it takes is writing the emitted XLA program to disk. Should I open a new issue for this?
@nalzok - there is currently an implementation of this, but only for TPU. See https://github.com/google/jax/tree/main/jax/experimental/compilation_cache for details, and https://github.com/google/jax/issues/2490 where this kind of request is tracked.
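For completeness, a minimal sketch of the experimental usage as I understand it (the cache directory path is just a placeholder, and again this is TPU-only for now):

```python
# Assuming the experimental API in the directory linked above; compiled
# programs are written to / read from the given directory.
from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("/path/to/persistent/cache")  # placeholder path

# jit-compiled programs from this point on persist across runs (TPU only).
```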
I have a fairly RNG-generation-heavy workload that I am running on Cloud TPU and was googling around to try and understand the `xla_tpu_spmd_rng_bit_generator_unsafe` flag, but only found this thread and a brief mention in the JAX documentation. The quality of randomness is not critical for me. Am I right in assuming this flag improves performance, but at the cost of using a less well-understood algorithm underneath?
@JeppeKlitgaard - yeah, it uses an ad hoc method of splitting keys that we don't have theoretical justification for (and in fact we don't really have well-established statistical tests for split-chain decorrelation when it comes to splittable PRNG systems). That said, it compiles and runs fast, and it's almost certainly good enough for e.g. dropout masks in the context of SGD training of NNs (we've used it for that for some time with no observed ill effects). I'd be a bit more careful if I were doing classic MCMC or something.
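For concreteness, a hedged sketch of how to opt in - the `LIBTPU_INIT_ARGS` mechanism and the `unsafe_rbg` config value reflect recent JAX versions on Cloud TPU, not details from this thread:

```python
import os

# Pass the XLA flag to the TPU runtime before importing jax:
os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_spmd_rng_bit_generator_unsafe=true"

import jax

# Switch the default PRNG to the RngBitGenerator-backed implementation;
# "unsafe_rbg" is the variant meant to pair with the flag above:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")

key = jax.random.PRNGKey(0)
samples = jax.random.uniform(key, (1024,))
```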