Linear-Autoregressive-Similarity-Index
LASI-PyTorch
For those that need a PyTorch version of LASI, I provide an unofficial implementation here.
In short, I re-implement LASI in PyTorch with the following differences:
- Hand-written vectorized alternatives for `vmap`
  - I wrote vectorized versions of the affected `tensor` ops, because PyTorch's `vmap` is still in preview;
  - If your PyTorch version has `vmap`, you can enable it by initializing the LASI class with `LASI(use_vmap=True)`.
- No JIT
- Numerical errors:
  The implemented `LASI.compute_distance` is NOT numerically identical to the JAX version, for the following reasons:
  1. The numerical accuracy of JAX seems to be lower, although I am not entirely sure about this.
     1.1 `80 * jnp.eye(3) / 127.5 != 80 / 127.5 * jnp.eye(3)`. (The right-hand term is more accurate and is equal to `80 * torch.eye(3) / 127.5`.)
     1.2 Accumulated error introduced by `sum(axis=0)`.
  2. `pinv` is not numerically stable, but the error is negligibly small (`1e-16`).
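As a rough illustration of why a reduction such as `sum(axis=0)` can give slightly different results across frameworks, here is a minimal sketch in plain Python (standard library only, not PyTorch or JAX). It shows that floating-point addition depends on the order of operations, and one way to compare two outputs with a tolerance instead of exact equality. The helper names `max_abs_diff` and `roughly_equal` and the tolerance defaults are assumptions made for this example; they are not part of LASI.

```python
import math

# Floating-point addition is not associative, so the order in which a
# reduction adds its terms can change the last bits of the result.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)
print(left_to_right == right_to_left)  # False

# A plain running sum accumulates rounding error at every step, while
# math.fsum keeps track of the lost low-order bits.
values = [0.1] * 10
naive_total = 0.0
for v in values:
    naive_total += v
exact_total = math.fsum(values)
print(naive_total == exact_total)  # False


def max_abs_diff(xs, ys):
    """Largest element-wise absolute difference (hypothetical helper)."""
    return max(abs(x - y) for x, y in zip(xs, ys))


def roughly_equal(xs, ys, rel_tol=1e-6, abs_tol=1e-8):
    """True if every pair agrees within the tolerances (assumed defaults)."""
    return all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol)
        for x, y in zip(xs, ys)
    )


# The two totals differ in the last bit but agree within tolerance.
print(max_abs_diff([naive_total], [exact_total]))
print(roughly_equal([naive_total], [exact_total]))  # True
```

When checking this port against the JAX version, a tolerance-based comparison of this kind is likely more meaningful than testing for bitwise equality.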