
[opart] treeverse for layer scan?

Status: Open. Opened by HajimeKawahara.

Using treeverse (binomial checkpointing) for the layer-by-layer computation might ease the memory-time tradeoff of reverse-mode differentiation through the layer scan, so this will be investigated. Equinox appears to implement a scan (loop) with checkpointing:

https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/checkpointed.py
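For comparison, the simplest checkpointed scan that can be built from plain JAX is two-level (square-root) checkpointing. The sketch below is neither treeverse nor the equinox implementation, and it assumes a hypothetical toy layer update; `layer_step`, `scan_two_level`, `dtau`, and `source` are illustrative names, not part of exojax. Only the carry entering each segment is stored, and each segment is re-run during the backward pass, so with about sqrt(N) segments the stored state grows like sqrt(N) instead of N, at the cost of roughly one extra forward pass. Treeverse instead places checkpoints binomially, which gives logarithmic growth in both memory and recomputation.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy layer update (illustrative only, not the exojax opart API):
# attenuate the incoming flux by the layer optical depth and add its emission.
def layer_step(flux, layer):
    dtau, source = layer
    trans = jnp.exp(-dtau)
    return flux * trans + source * (1.0 - trans), None


def scan_two_level(step, carry0, xs, nseg):
    """Two-level (sqrt-N) checkpointed scan; a simpler scheme than treeverse."""
    nlayer = jax.tree_util.tree_leaves(xs)[0].shape[0]
    if nlayer % nseg != 0:
        raise ValueError("number of layers must be divisible by nseg")
    # Reshape the leading layer axis into (nseg, nlayer // nseg).
    xs_seg = jax.tree_util.tree_map(
        lambda a: a.reshape((nseg, nlayer // nseg) + a.shape[1:]), xs
    )

    # Checkpointing a whole segment means the backward pass keeps only the
    # carry entering each segment and re-runs the segment to rebuild the
    # per-layer residuals it needs.
    @jax.checkpoint
    def segment(carry, xs_one_seg):
        carry, _ = jax.lax.scan(step, carry, xs_one_seg)
        return carry, None

    carry, _ = jax.lax.scan(segment, carry0, xs_seg)
    return carry


nlayer, nwav = 64, 32
key_tau, key_src = jax.random.split(jax.random.PRNGKey(1))
dtau = jax.random.uniform(key_tau, (nlayer, nwav), minval=0.01, maxval=0.1)
source = jax.random.uniform(key_src, (nlayer, nwav))
flux0 = jnp.zeros(nwav)


def loss_plain(d):
    # Ordinary scan: the backward pass stores residuals for every layer.
    flux, _ = jax.lax.scan(layer_step, flux0, (d, source))
    return jnp.sum(flux)


def loss_two_level(d):
    # nseg = 8 = sqrt(64) segments of 8 layers each.
    return jnp.sum(scan_two_level(layer_step, flux0, (d, source), nseg=8))


grad_plain = jax.grad(loss_plain)(dtau)
grad_two_level = jax.grad(loss_two_level)(dtau)
```

Both losses and their gradients agree; only the memory layout of the backward pass differs.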

  • jax.checkpoint documentation: https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html#jax.checkpoint

  • An intuitive explanation of treeverse: https://github.com/GiggleLiu/TreeverseAlgorithm.jl
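A minimal sketch of per-layer rematerialization using only `jax.lax.scan` and `jax.checkpoint`, again with a hypothetical toy layer update rather than the actual opart code (`layer_step`, `run_layers`, and the `remat` flag are illustrative names). With `remat=True`, the backward pass keeps only each layer's inputs (the carried flux and that layer's slice of the scanned arrays) and recomputes intermediates such as the transmission; the gradients are unchanged.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy layer update (illustrative only, not the exojax opart API):
# attenuate the incoming flux by the layer optical depth and add its emission.
def layer_step(flux, layer):
    dtau, source = layer
    trans = jnp.exp(-dtau)
    return flux * trans + source * (1.0 - trans), None


def run_layers(flux0, dtau, source, remat=False):
    # With remat=True, jax.checkpoint makes the backward pass keep only the
    # inputs of each layer step and recompute `trans` and the products.
    step = jax.checkpoint(layer_step) if remat else layer_step
    flux, _ = jax.lax.scan(step, flux0, (dtau, source))
    return flux


nlayer, nwav = 64, 128
key_tau, key_src = jax.random.split(jax.random.PRNGKey(0))
dtau = jax.random.uniform(key_tau, (nlayer, nwav), minval=0.01, maxval=0.1)
source = jax.random.uniform(key_src, (nlayer, nwav))
flux0 = jnp.zeros(nwav)

# Gradient of the summed output flux with respect to every layer's dtau.
grad_plain = jax.grad(lambda d: jnp.sum(run_layers(flux0, d, source)))(dtau)
grad_remat = jax.grad(
    lambda d: jnp.sum(run_layers(flux0, d, source, remat=True))
)(dtau)
```

This still stores one carry per layer, so memory remains linear in the number of layers; going below that requires nested (two-level) or treeverse-style checkpointing.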

HajimeKawahara, Dec 18 '24 01:12