exojax
[opart] treeverse for layer scan?
Using treeverse for the computation of each layer might ease the memory vs. recomputation-time tradeoff of differentiating through the layer scan, so we will investigate it. It seems that equinox has implemented a scan with checkpointing functionality:
https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/checkpointed.py
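As a baseline for comparison, here is a minimal sketch of plain per-layer checkpointing with `jax.checkpoint` inside `jax.lax.scan`. The names `layer_step` and `total_flux` are hypothetical, and the linear recurrence is only a toy stand-in for a real layer computation, not exojax's. Wrapping the step means intermediates inside each layer are recomputed on the backward pass. However, the scan still saves one carry per layer, so memory remains O(N) in the number of layers. That linear growth is the part a treeverse-style scheme would aim to reduce.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer update: a toy linear recurrence standing in for one
# layer of the scan (not exojax's actual layer computation).
def layer_step(flux, layer):
    trans, source = layer
    return flux * trans + source * (1.0 - trans), None

def total_flux(trans, source, remat=False):
    # With remat=True, values computed inside each step are recomputed on the
    # backward pass instead of stored; the per-step carry is still saved.
    # prevent_cse=False is the setting the JAX docs suggest inside scan.
    step = jax.checkpoint(layer_step, prevent_cse=False) if remat else layer_step
    flux0 = jnp.zeros_like(trans[0])
    flux, _ = jax.lax.scan(step, flux0, (trans, source))
    return flux

trans = jnp.linspace(0.1, 0.9, 16)
source = jnp.linspace(1.0, 2.0, 16)
g_plain = jax.grad(total_flux)(trans, source)               # d flux / d trans
g_remat = jax.grad(total_flux)(trans, source, remat=True)   # same gradient, less stored
```

Checkpointing changes only what is stored versus recomputed, so both gradients should agree.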
- jax checkpoint: https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html#jax.checkpoint
- An intuitive explanation of treeverse: https://github.com/GiggleLiu/TreeverseAlgorithm.jl
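To get a feel for the treeverse idea before adopting a library implementation, here is a much simpler relative: recursive bisection checkpointing built from nested `jax.checkpoint` calls. This is not treeverse itself, which places checkpoints optimally for a given memory budget. It is a sketch, assuming a sequence length that is static at trace time. The names `layer_step`, `bisect_scan` and the `leaf` parameter are hypothetical. The intent is to store only the carries at the split points along the current recursion path, roughly O(log N) of them, at the cost of roughly O(N log N) recomputation. Whether XLA realizes exactly that memory profile should be checked by profiling. Note that the recursion is unrolled at trace time, so compile time grows with N/leaf.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer update: the same toy recurrence, standing in for a real layer.
def layer_step(flux, layer):
    trans, source = layer
    return flux * trans + source * (1.0 - trans), None

def bisect_scan(step, carry, xs, n, leaf=4):
    """Run `step` over the leading axis of `xs` (static length n).

    Sequences longer than `leaf` are split in half; each half is wrapped in
    jax.checkpoint, so the backward pass keeps only the carry entering each
    half and recomputes the interior, recursively (bisection checkpointing,
    a simplified cousin of treeverse/revolve).
    """
    if n <= leaf:
        carry, _ = jax.lax.scan(step, carry, xs)
        return carry
    half = n // 2
    first = jax.tree_util.tree_map(lambda a: a[:half], xs)
    second = jax.tree_util.tree_map(lambda a: a[half:], xs)
    run_first = jax.checkpoint(lambda c, x: bisect_scan(step, c, x, half, leaf))
    run_second = jax.checkpoint(lambda c, x: bisect_scan(step, c, x, n - half, leaf))
    return run_second(run_first(carry, first), second)

def flux_plain(trans, source):
    # Reference: ordinary scan, all per-step residuals stored.
    flux, _ = jax.lax.scan(layer_step, jnp.zeros_like(trans[0]), (trans, source))
    return flux

def flux_bisect(trans, source):
    # trans.shape[0] is a static Python int even under jax.grad.
    return bisect_scan(layer_step, jnp.zeros_like(trans[0]), (trans, source),
                       trans.shape[0])
```

An odd length such as 37 exercises uneven splits. The forward value and the gradient with respect to `trans` should match the plain scan.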