purejaxrl
S5: Longer compilation times
Hey, thanks for providing purejaxrl; it's pretty awesome.
I have used the experimental S5 code that you provide for part of my research. Starting with jaxlib 0.4.27 (and the same with 0.4.28), compilation takes 5 times longer when I increase n_layers of the S5. Any ideas why this might happen?
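One way to pin down the regression is to time compilation alone, separately from execution, at several depths and rerun it under each jaxlib version. The sketch below is a hypothetical stand-in, not the S5 code from purejaxrl: it assumes a plain stack of tanh-dense layers applied in a Python loop, which JAX traces (unrolls) once per layer, so the compiled program grows with n_layers. The names `make_forward` and `compile_seconds` are illustrative only.

```python
import time

import jax
import jax.numpy as jnp


def make_forward(n_layers: int):
    """Hypothetical stand-in for a deep model: n_layers dense+tanh layers
    applied in a Python loop, so each layer is traced separately."""

    def forward(params, x):
        for i in range(n_layers):
            x = jnp.tanh(x @ params[i])
        return x

    return forward


def compile_seconds(n_layers: int, width: int = 64) -> float:
    """Wall-clock seconds to lower and compile (but not run) the forward pass."""
    params = jnp.zeros((n_layers, width, width))
    x = jnp.zeros((8, width))
    forward = make_forward(n_layers)
    start = time.perf_counter()
    # Ahead-of-time lowering isolates compile cost from execution cost.
    jax.jit(forward).lower(params, x).compile()
    return time.perf_counter() - start


if __name__ == "__main__":
    print("jax", jax.__version__)
    for n in (1, 2, 4, 8):
        print(f"n_layers={n}: compile {compile_seconds(n):.3f}s")
```

Running the same script with jaxlib 0.4.26 and 0.4.27 installed shows whether compile time scales differently with depth between versions. Swapping the stand-in `forward` for the actual S5 forward pass would then confirm whether the slowdown comes from the model itself.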