axlearn
ssm_enhancement
Pull Request: Enhancements to Mamba and Jamba State-space Models
Summary
This pull request introduces several enhancements to the Mamba and Jamba state-space model (SSM) implementation, including new recurrence methods, a hybrid approach that combines them, and tests covering the new code paths.
Changes
1. New Recurrence Methods
- HybridMambaRecurrence: Combines different recurrence methods to leverage their strengths.
- AlternativeMambaRecurrence: Implements an alternative recurrence method for potentially better performance or accuracy.
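To make "different recurrence methods" concrete, here is a minimal JAX sketch of one linear recurrence, h_t = a_t * h_{t-1} + b_t, computed two ways: sequentially and with a parallel associative scan. It is an illustration under stated assumptions, not the code in this PR. The function names, the `short_seq_threshold` parameter, and the length-based dispatch are hypothetical; the PR description does not say which methods HybridMambaRecurrence combines or how it chooses between them.

```python
import jax
import jax.numpy as jnp


def linear_scan(a, b):
    """Sequential scan for h_t = a_t * h_{t-1} + b_t with h_0 = 0.

    a, b: arrays of shape [seq_len, dim].
    """
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h

    _, hs = jax.lax.scan(step, jnp.zeros_like(b[0]), (a, b))
    return hs


def associative_scan(a, b):
    """Parallel scan for the same recurrence.

    Each element (a_t, b_t) stands for the map h -> a_t * h + b_t;
    composing two such maps is associative, so a prefix scan applies.
    """
    def combine(left, right):
        a_l, b_l = left
        a_r, b_r = right
        return a_l * a_r, a_r * b_l + b_r

    _, hs = jax.lax.associative_scan(combine, (a, b))
    return hs


def hybrid_scan(a, b, *, short_seq_threshold=64):
    """Hypothetical dispatch rule (an assumption, not the PR's logic):
    use the sequential scan for short sequences, the parallel scan otherwise."""
    if a.shape[0] <= short_seq_threshold:
        return linear_scan(a, b)
    return associative_scan(a, b)
```

Both paths compute the same values up to floating-point rounding, so a hybrid of this kind would change only how the work is scheduled, not the result.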
2. Enhancements to ssm.py
- Added HybridMambaRecurrence and AlternativeMambaRecurrence classes.
- Updated MambaMixerLayer and JambaMixerLayer to integrate the new recurrence methods.
3. Comprehensive Testing in ssm_test.py
- Added tests for HybridMambaRecurrence and AlternativeMambaRecurrence in MambaMixerLayerTest.
- Added tests for hybrid and alternative recurrence methods in StackedMambaTest.
- Added tests for hybrid and alternative recurrence methods in StackedMixedSSMTransformerTest.
(Documentation and Examples: Updated docstrings and comments to reflect the new features and changes.)
Testing
All new features are tested with the following configurations:
- Different data types (jnp.float32, jnp.bfloat16).
- Various model dimensions, state dimensions, and hidden dimensions.
- Integration within MambaBlock, JambaMambaBlock, and StackedMixedSSMTransformerLayer.
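A sweep over data types and dimensions like the one listed above could be sketched as follows. This is not the code in ssm_test.py: `recurrence_under_test`, `reference`, `run_sweep`, the chosen shapes, and the tolerances are all assumptions made for illustration. The idea is to compare each configuration against a step-by-step float32 reference and to loosen the tolerance for jnp.bfloat16, which carries far fewer mantissa bits.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def recurrence_under_test(a, b):
    """Stand-in for a recurrence implementation (hypothetical):
    h_t = a_t * h_{t-1} + b_t with h_0 = 0, via an associative scan."""
    def combine(left, right):
        return left[0] * right[0], right[0] * left[1] + right[1]

    return jax.lax.associative_scan(combine, (a, b))[1]


def reference(a, b):
    """Step-by-step float32 NumPy reference for the same recurrence."""
    a = np.asarray(a.astype(jnp.float32))
    b = np.asarray(b.astype(jnp.float32))
    h = np.zeros_like(b[0])
    out = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.stack(out)


def run_sweep():
    """Checks every (dtype, seq_len, state_dim) combination; returns the count."""
    # Assumed tolerances: tight for float32, loose for bfloat16.
    dtype_cases = [(jnp.float32, 1e-4), (jnp.bfloat16, 1e-1)]
    checked = 0
    for (dtype, tol), seq_len, state_dim in itertools.product(
        dtype_cases, (8, 32), (4, 16)
    ):
        key_a, key_b = jax.random.split(jax.random.PRNGKey(seq_len + state_dim))
        shape = (seq_len, state_dim)
        a = jax.random.uniform(key_a, shape, minval=0.1, maxval=0.8).astype(dtype)
        b = jax.random.normal(key_b, shape).astype(dtype)
        got = np.asarray(recurrence_under_test(a, b).astype(jnp.float32))
        np.testing.assert_allclose(got, reference(a, b), atol=tol, rtol=tol)
        checked += 1
    return checked
```

Calling `run_sweep()` raises an AssertionError on the first mismatching configuration and otherwise returns the number of configurations checked.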
Conclusion
These enhancements provide additional flexibility and options for implementing and experimenting with different recurrence methods in the Mamba and Jamba models, potentially improving performance and accuracy for various tasks.