ssm_enhancement

Open vishesh9131 opened this issue 5 months ago • 3 comments

Pull Request: Enhancements to Mamba and Jamba State-space Models

Summary

This pull request enhances the implementation of the Mamba and Jamba state-space models (SSMs) with new recurrence methods, including a hybrid approach, and adds tests that cover them.

Changes

1. New Recurrence Methods

  • HybridMambaRecurrence: Combines different recurrence methods to leverage their strengths.
  • AlternativeMambaRecurrence: Implements an alternative recurrence method that may improve performance or accuracy.
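
For readers unfamiliar with what a "recurrence method" means here, the following is a minimal, dependency-free sketch and not the code in this PR. It assumes the scalar linear recurrence h_t = a_t * h_{t-1} + b_t and computes it three ways: step by step, with an associative prefix scan, and with a chunked mix of the two. The function names (`sequential_scan`, `associative_scan`, `hybrid_scan`) and the chunking strategy are hypothetical illustrations; how HybridMambaRecurrence actually combines methods is not specified in this description.

```python
import math


def combine(left, right):
    """Compose two affine steps h -> a*h + b; `left` is applied first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a, b, h0=0.0):
    """Reference: h_t = a_t * h_{t-1} + b_t, one step at a time."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def associative_scan(a, b, h0=0.0):
    """Same states via an inclusive prefix scan over `combine`.

    Each round doubles the span covered by every element, so a parallel
    backend would need only O(log n) rounds. Here it runs serially.
    """
    elems = list(zip(a, b))
    n, step = len(elems), 1
    while step < n:
        elems = [
            combine(elems[i - step], elems[i]) if i >= step else elems[i]
            for i in range(n)
        ]
        step *= 2
    # Each prefix (A, B) maps the initial state h0 to A * h0 + B.
    return [a_p * h0 + b_p for a_p, b_p in elems]


def hybrid_scan(a, b, chunk=2, h0=0.0):
    """Hypothetical hybrid: associative scan inside each chunk,
    sequential carry of the last state from one chunk to the next."""
    out, h = [], h0
    for start in range(0, len(a), chunk):
        states = associative_scan(a[start:start + chunk], b[start:start + chunk], h)
        out.extend(states)
        h = states[-1]
    return out


a = [0.5, 0.9, 0.8, 0.7, 0.6]
b = [1.0, 2.0, 3.0, 4.0, 5.0]
reference = sequential_scan(a, b)
for method in (associative_scan, hybrid_scan):
    got = method(a, b)
    assert all(math.isclose(x, y, rel_tol=1e-9) for x, y in zip(reference, got))
```

All three functions return the same states up to floating-point rounding. They differ only in how the work is ordered, which is the kind of trade-off a recurrence-method option exposes.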

2. Enhancements to ssm.py

  • Added HybridMambaRecurrence and AlternativeMambaRecurrence classes.
  • Updated MambaMixerLayer and JambaMixerLayer to integrate the new recurrence methods.

3. Comprehensive Testing in ssm_test.py

  • Added tests for HybridMambaRecurrence and AlternativeMambaRecurrence in MambaMixerLayerTest.
  • Added tests for hybrid and alternative recurrence methods in StackedMambaTest.
  • Added tests for hybrid and alternative recurrence methods in StackedMixedSSMTransformerTest.

4. Documentation and Examples

  • Updated docstrings and comments to reflect the new features and changes.

Testing

All new features have been thoroughly tested with the following configurations:

  • Different data types (jnp.float32, jnp.bfloat16).
  • Various model dimensions, state dimensions, and hidden dimensions.
  • Integration within MambaBlock, JambaMambaBlock, and StackedMixedSSMTransformerLayer.
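
A configuration sweep like the one above can be sketched as a grid that compares a method under test against a step-by-step reference. This is a stdlib-only illustration, not the tests in ssm_test.py. The names `reference_scan`, `candidate_scan`, and `run_grid` are hypothetical, and `candidate_scan` is only a stand-in that evaluates the same recurrence in a different arithmetic order. Data types are not modelled here; in a JAX test the grid would presumably also include jnp.float32 and jnp.bfloat16, with a looser tolerance for the latter.

```python
import itertools
import math
import random


def reference_scan(a, b):
    """Step-by-step h_t = a_t * h_{t-1} + b_t per state channel, starting from 0."""
    h = [0.0] * len(a[0]) if a else []
    out = []
    for a_t, b_t in zip(a, b):
        h = [x * y + z for x, y, z in zip(a_t, h, b_t)]
        out.append(h)
    return out


def candidate_scan(a, b):
    """Stand-in for a method under test: the unrolled closed form
    h_t = sum over s <= t of (a_{s+1} * ... * a_t) * b_s."""
    out = []
    for t in range(len(a)):
        h_t = []
        for d in range(len(a[0])):
            total, decay = 0.0, 1.0
            for s in range(t, -1, -1):
                total += decay * b[s][d]
                decay *= a[s][d]
            h_t.append(total)
        out.append(h_t)
    return out


def run_grid(seq_lens=(1, 7, 16), state_dims=(1, 4), tol=1e-9, seed=0):
    """Return the (seq_len, state_dim) configurations where the two methods disagree."""
    rng = random.Random(seed)
    failures = []
    for seq_len, state_dim in itertools.product(seq_lens, state_dims):
        a = [[rng.uniform(0.0, 1.0) for _ in range(state_dim)] for _ in range(seq_len)]
        b = [[rng.uniform(-1.0, 1.0) for _ in range(state_dim)] for _ in range(seq_len)]
        ref, got = reference_scan(a, b), candidate_scan(a, b)
        ok = all(
            math.isclose(x, y, rel_tol=tol, abs_tol=tol)
            for r, g in zip(ref, got)
            for x, y in zip(r, g)
        )
        if not ok:
            failures.append((seq_len, state_dim))
    return failures


assert run_grid() == []
```

Returning the failing configurations rather than asserting inside the loop makes it easier to see which dimensions break when a new method is added.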

Conclusion

These enhancements make it easier to experiment with different recurrence methods in the Mamba and Jamba models, and may improve performance and accuracy on some tasks.

vishesh9131 · Sep 06 '24 12:09