
🔪 Remaining Sharp Bit TODOs 🔪

Open levskaya opened this issue 2 years ago • 9 comments

We could do with sprucing up the Sharp Bits doc with common problems we've encountered in user code since it was first written.

Top of the list is documenting matmul / conv op precision issues:

  • bf16 multiplication defaults! Bad for simulation / classic numerics.
  • the context manager for precision

We should add some other ideas here.
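For context, a minimal sketch of the precision controls mentioned above, assuming the `jax.default_matmul_precision` context manager and the per-op `precision=` argument:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 4), dtype=jnp.float32)

# On TPU, float32 matmuls may be lowered to bfloat16 multiplications by
# default -- fine for NN training, too coarse for simulation / classic
# numerics.
default = jnp.dot(x, x)

# Scoped override via the precision context manager:
with jax.default_matmul_precision("float32"):
    precise = jnp.dot(x, x)

# Or per-op, via the precision argument:
per_op = jax.lax.dot(x, x, precision=jax.lax.Precision.HIGHEST)
```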

levskaya avatar Mar 18 '22 05:03 levskaya

Just adding context: the context manager for precision is defined here, and there are some words about it in #6143.

mattjj avatar Mar 18 '22 05:03 mattjj

Other sharp bits:

  • OOB accesses don't raise errors by default; they silently clip or drop! This is already in there, actually, but we should extend it a bit and mention the mode argument for the at syntax (and add a link to checkify).
  • We can't cover it in any detail, and it's not a JAX issue per se, but it's probably worth mentioning the general dangers of half-precision types: e.g. the ease of float16 overflow/underflow and the danger of accumulating into bf16.
  • accidental recompilation issues:
    • hashability of arguments / jit caching behavior
    • the log-compiles feature for catching accidental recompiles
    • worth a line: the danger of weak_type=True for triggering recompilation
  • Perhaps too esoteric / TPU-centric: we should mention the RngBitGenerator system for performance, and maybe xla_tpu_spmd_rng_bit_generator_unsafe=true for OSS users of SPMD + RNG.
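A quick illustration of the first two bullets (OOB behavior and a float16 overflow), using the documented `mode=` options of the `.at` property:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Out-of-bounds reads clamp to the nearest valid index instead of raising:
oob_read = x[10]            # returns x[2] == 3, no IndexError

# Out-of-bounds updates are silently dropped:
oob_set = x.at[10].set(99)  # still [1, 2, 3]

# The mode argument makes OOB handling explicit, e.g. returning a fill
# value for OOB reads:
filled = x.at[10].get(mode="fill", fill_value=-1)   # -1

# Half-precision gotcha: float16 overflows early (max finite value ~65504):
overflow = jnp.float16(60000) + jnp.float16(60000)  # inf
```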

levskaya avatar Mar 18 '22 07:03 levskaya

Fantastic!

On that first bullet, we could also mention checkify cc @LenaMartens

mattjj avatar Mar 18 '22 07:03 mattjj

Nice! +1 on the danger of weak_type=True for triggering recompilation; we've had people ask for more documentation on that. I might try to add that.
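A small sketch of the weak_type pitfall (behavior as currently documented; the caching internals may change):

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    return x + 1

# Python scalars enter jit as weakly-typed float32, while a concrete
# float32 scalar is strongly typed. Same shape and dtype, different
# weak_type -- so these two calls compile separately:
g(1.0)
g(jnp.float32(1.0))

# weak_type is visible on the arrays themselves:
weak = jnp.asarray(1.0).weak_type   # True
strong = jnp.zeros(()).weak_type    # False
```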

LenaMartens avatar Mar 18 '22 09:03 LenaMartens

Also: making sure that the asynchronous set of JAX calls used in a training loop doesn't introduce blocking calls that kill dispatch pipelining efficiency (e.g. a trivial host-side metrics fn or similar). This is one of the most common performance mistakes I see (maybe it belongs in a separate performance-gotchas doc... not sure).
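A sketch of the pattern: `block_until_ready` is the documented sync point, and the commented-out host read stands in for a hypothetical metrics call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(x):
    return x * 2.0

x = jnp.ones((1024,))

# Dispatch is asynchronous: each call returns a "future-like" array and
# the host races ahead, keeping the device pipeline full.
for _ in range(3):
    x = train_step(x)
    # float(x[0])  # <- a host-side read here would sync every step and
    #                 kill dispatch pipelining

# Synchronize once, only when you actually need the values:
x.block_until_ready()
```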

levskaya avatar Apr 01 '22 02:04 levskaya

I like the idea of having a new dedicated doc for performance tips and pitfalls

jakevdp avatar Apr 01 '22 02:04 jakevdp

Regarding reworking the Sharp Bits doc, I recently added a section on miscellaneous divergences between numpy and JAX. It might be nice to rearrange things so all the differences between numpy and JAX are listed briefly under a single heading, perhaps with links to deeper discussion later in the doc.

jakevdp avatar Apr 01 '22 02:04 jakevdp

Regarding the "jit caching behavior": is there any chance you could cache the compiled result to the file system so that it persists across runs? In my development cycle, I typically change some hyperparameters and re-run the experiment. It's a little frustrating to wait for JIT compilation each time, even when I've compiled the exact same code multiple times before.

I am under the impression that this wouldn't be too hard to implement, since we already have a hashing/caching mechanism. All it takes is writing the emitted XLA program to disk. Should I open a new issue for this?

nalzok avatar Jul 18 '22 15:07 nalzok

@nalzok - there is currently an implementation of this, but only for TPU. See https://github.com/google/jax/tree/main/jax/experimental/compilation_cache for details, and https://github.com/google/jax/issues/2490 where this kind of request is tracked.
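For reference, a hedged sketch of enabling that cache, using the entry point from the experimental module linked above (the API location may have moved since):

```python
# TPU-only at the time of writing; see jax/experimental/compilation_cache.
from jax.experimental.compilation_cache import compilation_cache as cc

# Point the cache at a persistent directory; subsequent runs then reuse
# compiled executables from disk instead of recompiling.
cc.initialize_cache("/tmp/jax_cache")
```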

jakevdp avatar Jul 18 '22 16:07 jakevdp

I have a fairly RNG-heavy workload that I am running on Cloud TPU, and I was googling around to try to understand the xla_tpu_spmd_rng_bit_generator_unsafe flag but only found this thread and a brief mention in the JAX documentation. The quality of randomness is not critical for me. Am I right in assuming this flag improves performance, but at the cost of using a less well-understood algorithm underneath?

JeppeKlitgaard avatar Apr 15 '23 12:04 JeppeKlitgaard

@JeppeKlitgaard - yeah, it uses an ad hoc method of splitting keys that we don't have theoretical justification for (in fact, we don't really have well-established statistical tests for split-chain decorrelation when it comes to splittable PRNG systems). That said, it compiles and runs fast, and it's almost certainly good enough for e.g. dropout masks in the context of SGD training of NNs (we've used it for that with no observed ill effects for some time). I'd be a bit more careful if I were doing classic MCMC or something.
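To make the moving parts concrete, here is a hedged sketch: the `jax_default_prng_impl` config option selects the RngBitGenerator-backed PRNG, and the XLA flag is shown commented out since it is TPU-specific and must be set before XLA initializes.

```python
import os

# TPU-only XLA flag; must be set before jax/XLA initialize. Only enable
# it if you've accepted the weaker key-splitting guarantees discussed
# above.
# os.environ["XLA_FLAGS"] = (
#     os.environ.get("XLA_FLAGS", "")
#     + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
# )

import jax
from jax import random

# Opt into the RngBitGenerator-backed PRNG implementation:
jax.config.update("jax_default_prng_impl", "rbg")

key = random.PRNGKey(0)
samples = random.uniform(key, (4,))
```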

levskaya avatar Apr 16 '23 08:04 levskaya