Add universal device support with flash attention fallback
Summary
Extends cross-platform compatibility by adding a conditional fallback for when flash attention is unavailable, such as on Apple Silicon (MPS) and other non-CUDA systems.
Problem
The HRM codebase crashes on systems without flash_attn (primarily Apple Silicon) because of a hard dependency on it, preventing broader adoption and contributions from non-CUDA users.
Solution
Added a conditional import and runtime dispatch that preserve the original CUDA performance while falling back to PyTorch's built-in attention on systems without flash_attn.
Changes
Core Compatibility (models/layers.py)
- Conditional import with graceful flash_attn fallback
- Runtime detection: uses flash_attn when available, PyTorch attention otherwise
- Zero performance impact on CUDA systems
- Full backward compatibility
Technical Implementation
```python
# Conditional import: prefer flash_attn_interface, then flash_attn,
# and fall back to None when neither is installed.
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        flash_attn_func = None
```
Runtime conditional execution:

```python
import torch.nn.functional as F  # needed for the fallback path

if flash_attn_func is not None:
    # Fast path: flash_attn kernels on CUDA
    attn_output = flash_attn_func(q=query, k=key, v=value,
                                  causal=self.causal)
else:
    # Fallback: PyTorch scaled dot-product attention (MPS, CPU, or CUDA without flash_attn)
    attn_output = F.scaled_dot_product_attention(query, key, value,
                                                 is_causal=self.causal)
```
Compatibility Matrix
- CUDA with flash_attn: ✅ Original performance (no changes)
- CUDA without flash_attn: ✅ PyTorch fallback
- Apple Silicon (MPS): ✅ PyTorch fallback
- CPU: ✅ Universal compatibility
Testing
- Device detection: MPS/CUDA/CPU verified (see the sketch after this list)
- Model loading: 2.2GB checkpoint compatibility confirmed
- Inference: Cross-platform execution validated
- Performance: Zero impact on CUDA workflows
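A rough sketch of the kind of checks behind the items above; the checkpoint path "checkpoint.pt" and the tensor shapes are placeholders, not files or values from the repo:

```python
import os
import torch
import torch.nn.functional as F

# Device detection: prefer CUDA, then Apple Silicon's MPS backend, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"running on: {device}")

# Model loading: map_location moves checkpoint tensors onto the selected device.
if os.path.exists("checkpoint.pt"):  # placeholder path
    state_dict = torch.load("checkpoint.pt", map_location=device)

# Inference: exercise the fallback attention path with dummy (batch, heads, seq, head_dim) tensors.
q = k = v = torch.randn(1, 8, 128, 64, device=device)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape, out.device)
```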
Impact
- Broader hardware support for contributors
- Maintained optimal performance on CUDA
- Reduced onboarding friction
- No breaking changes
This follows the established conditional device pattern while extending compatibility to the attention layer.
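For contributors who want to confirm which attention path their machine will take, a quick check along these lines works (a sketch that reuses the same two-step conditional import shown earlier):

```python
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        flash_attn_func = None

backend = "flash_attn" if flash_attn_func is not None else "PyTorch scaled_dot_product_attention"
print(f"attention backend: {backend}")
```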
@jryanhaber Great work, thank you! A README change would be good to make clear what the hardware support is and how to run it.