jax
Tracking issue: Documentation update
Started with #18269
The plan is to create an entirely new, cohesive documentation structure in docs/tutorials, move sections of the existing docs there, and then remove the old versions, leaving HTTP redirects in their place.
Existing drafts (on the main branch) can be seen at https://jax.readthedocs.io/en/latest/tutorials/. Note that this page is currently unlisted in the website's table of contents.
Introductory
- [x] Quickstart @jakevdp #18613
- [x] Key Concepts @jakevdp #20347
- [x] ~Thinking in JAX @8bitmp3 #18581~ removed in #20802
- [x] Just-in-time compilation @jakevdp #18627
- [x] Automatic vectorization @jakevdp #18626
- [x] Automatic differentiation @8bitmp3 #18676
- [x] Debugging @8bitmp3 #18582
- [x] Pseudorandom numbers @jakevdp #18631
- [x] Working with PyTrees @mtthss #18864
- [x] Introduction to sharded computation @jakevdp #20779
- [x] Stateful computations @jakevdp #20732
- [ ] Example: Writing a simple neural network
Intermediate
- [x] Parallel computation
- [x] Advanced automatic differentiation @8bitmp3
- [x] Advanced vectorization (collectives, batched map)
- [x] Gradient checkpointing
- [x] Advanced debugging
- [x] External callbacks @jakevdp #18616
- [ ] ~Profiling and performance @dfm~ punting on this for now
Advanced
- [ ] How JAX primitives work
- [ ] Understanding Jaxprs
- [ ] Writing a custom Jaxpr interpreter
- [ ] Ahead-of-time lowering and computation
- [ ] Custom operations for GPUs with C++ and CUDA
- [ ] Pallas: a JAX kernel language
Reference
- [x] Installation @8bitmp3 #18580, moved to the main page in #20784
I'm happy to start drafting the "Profiling and performance" tutorial, since it's something I'd like to learn more about, unless someone else is already planning to work on it!
Awesome, consider it assigned 😁