Linear-Autoregressive-Similarity-Index

LASI-PyTorch

Open Na-moe opened this issue 8 months ago • 0 comments

For those who need a PyTorch version of LASI, I provide an unofficial implementation here.

In short, I re-implemented LASI in PyTorch with the following differences:

  • Hand-written vectorized alternatives to vmap
    • I wrote vectorized versions of the affected tensor ops, because PyTorch's vmap is still in preview;
    • If your PyTorch version has vmap, you can use it by initializing the LASI class with LASI(use_vmap=True).
  • No JIT
  • Numerical errors: the implemented LASI.compute_distance is NOT NUMERICALLY IDENTICAL to the JAX version, for the following reasons:
    1. The numerical accuracy of JAX seems to be lower, though I'm not entirely sure about this.

      1.1 80 * jnp.eye(3) / 127.5 != 80 / 127.5 * jnp.eye(3). (The right-hand side is more accurate and is equal to 80 * torch.eye(3) / 127.5.)

      1.2 Accumulated error introduced by sum(axis=0).

    2. pinv is not numerically stable, but the error is negligibly small (1e-16).
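
The accumulated-error point (1.2) can be illustrated without either framework. The sketch below is not taken from the LASI code; it uses only the Python standard library, and the helper names `sum_forward` and `sum_pairwise` are hypothetical. It shows that floating-point addition is not associative, so two backends that reduce the same values in a different order may disagree in the last few bits. Whether this is what actually happens inside JAX's or PyTorch's `sum` is an assumption, not something verified here.

```python
import math
import random


def sum_forward(values):
    """Add the values left to right, one at a time."""
    total = 0.0
    for v in values:
        total += v
    return total


def sum_pairwise(values):
    """Add the values by recursively splitting the list in half."""
    if not values:
        return 0.0
    if len(values) == 1:
        return values[0]
    mid = len(values) // 2
    return sum_pairwise(values[:mid]) + sum_pairwise(values[mid:])


# Smallest possible case: the same three numbers, grouped differently.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print("grouping changes the result:", left != right)

# Larger case: compare two reduction orders against a correctly rounded sum.
random.seed(0)
values = [random.uniform(-1.0, 1.0) for _ in range(10_000)]
exact = math.fsum(values)  # correctly rounded reference
print("forward  error:", abs(sum_forward(values) - exact))
print("pairwise error:", abs(sum_pairwise(values) - exact))
```

Differences of this kind are expected to be tiny in double precision, which is why a bitwise comparison between two implementations can fail even when both are correct to within rounding.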

Na-moe · Jun 17 '24 12:06