mamba-jax

[Tracking] Ensuring correctness of implementation wrt original implementation

Open • vvvm23 opened this issue 2 years ago • 0 comments

Although this repository can already be used to sample from pretrained models, there are no test scripts verifying that its outputs match those of the original implementation. Particularly risky areas include parameter initialisation, defaults that differ between the JAX and torch versions of the same functions, and correct handling of mixed precision.
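A test script of this kind could be built around a tolerance-based comparison of outputs. The sketch below is a minimal illustration, assuming the outputs of both implementations for the same inputs and weights have already been exported as NumPy arrays. `check_parity` and the `TOLERANCES` values are hypothetical and are not part of mamba-jax or the original code.

```python
import numpy as np

# Hypothetical per-dtype tolerances: illustrative values only, not thresholds
# taken from mamba-jax or the original implementation. Lower-precision dtypes
# get looser bounds because each rounding step discards more bits.
TOLERANCES = {
    np.dtype(np.float32): {"rtol": 1e-5, "atol": 1e-6},
    np.dtype(np.float16): {"rtol": 1e-2, "atol": 1e-3},
}


def check_parity(ours, reference, dtype):
    """Raise AssertionError if `ours` and `reference` differ beyond the
    tolerance chosen for `dtype` (the precision the computation ran in).

    Both arrays are upcast to float32 before comparing, so a float16 output
    can be checked against a float32 reference.
    """
    tol = TOLERANCES[np.dtype(dtype)]
    np.testing.assert_allclose(
        np.asarray(ours, dtype=np.float32),
        np.asarray(reference, dtype=np.float32),
        **tol,
    )


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    reference = rng.standard_normal((4, 8)).astype(np.float32)

    # Stand-in for a half-precision run: same values, rounded to float16.
    ours_fp16 = reference.astype(np.float16)
    check_parity(ours_fp16, reference, np.float16)
    print("float16 output within tolerance")
```

Keying the tolerance on the precision the computation ran in means a mixed-precision path is held to a looser bound than a full float32 path, instead of being compared bit-for-bit.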

I don't expect the outputs to match exactly, since JAX compiles through XLA and torch does not, so the two may order and fuse floating-point operations differently. However, they should be very close.
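To illustrate why an exact match is not expected: floating-point addition is not associative, so summing the same numbers in a different order (as different backends and compilers, XLA included, may do for reductions) changes the low-order bits. This NumPy-only sketch does not involve JAX or torch; it only mimics the effect by summing the same float32 values in two orders and comparing both against a high-precision reference.

```python
import math

import numpy as np

rng = np.random.default_rng(0)
# Non-negative values keep the relative-error comparison well defined.
x = rng.random(10_000).astype(np.float32)

# Order 1: strictly left-to-right float32 accumulation.
sequential = np.cumsum(x, dtype=np.float32)[-1]
# Order 2: np.sum on a contiguous array uses blocked pairwise summation.
pairwise = np.sum(x, dtype=np.float32)
# Reference: correctly rounded sum of the same values, computed with math.fsum.
exact = math.fsum(x.astype(np.float64))

for name, value in [("sequential", sequential), ("pairwise", pairwise)]:
    rel_err = abs(float(value) - exact) / exact
    print(f"{name}: {float(value):.6f}  relative error {rel_err:.2e}")
```

The two float32 results typically differ slightly from each other, yet both stay within a small relative error of the reference, which is why a parity test should compare against a tolerance rather than require bitwise equality.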

This issue tracks progress towards this goal.

vvvm23, Jan 10 '24 11:01