
Add universal device support with flash attention fallback

Open jryanhaber opened this issue 5 months ago • 1 comment

Summary

Extends cross-platform compatibility by adding a conditional fallback for flash attention on systems where it is unavailable, such as Apple Silicon (MPS) and other non-CUDA platforms.

Problem

The HRM codebase crashes on systems without flash_attn (primarily Apple Silicon) because of a hard import dependency, which blocks broader adoption and contributions from non-CUDA users.

Solution

Added a conditional fallback that preserves the original CUDA code path while routing attention through PyTorch's built-in implementation on systems without flash_attn.

Changes

Core Compatibility (models/layers.py)

  • Conditional import with graceful flash_attn fallback
  • Runtime detection: uses flash_attn when available, PyTorch attention otherwise
  • Zero performance impact on CUDA systems
  • Full backward compatibility

Technical Implementation

# Prefer the flash_attn_interface package, then the flash_attn package; if neither
# is installed, leave flash_attn_func as None so the layer can fall back to
# PyTorch's scaled_dot_product_attention.
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        flash_attn_func = None

Runtime conditional execution

if flash_attn_func is not None:
    # Fast path: FlashAttention kernel (CUDA with flash_attn installed).
    attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
else:
    # Fallback path: PyTorch scaled dot-product attention (MPS, CPU, or CUDA without flash_attn).
    attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=self.causal)

Compatibility Matrix

  • CUDA with flash_attn: ✅ Original performance (no changes)
  • CUDA without flash_attn: ✅ PyTorch fallback
  • Apple Silicon (MPS): ✅ PyTorch fallback
  • CPU: ✅ Universal compatibility (device selection sketched below)
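The matrix above maps onto a simple device-selection pattern. The helper below is a hypothetical sketch for illustration only (select_device is not an existing HRM function and is not part of this PR):

import torch

def select_device() -> torch.device:
    # Prefer CUDA, then Apple Silicon MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")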

Testing

  • Device detection: MPS/CUDA/CPU verified
  • Model loading: 2.2GB checkpoint compatibility confirmed
  • Inference: Cross-platform execution validated (see the sanity-check sketch after this list)
  • Performance: Zero impact on CUDA workflows
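The inference check above can be reproduced with a minimal sanity script along the following lines. This is a sketch that exercises only the SDPA fallback path; the tensor shapes are illustrative and assume a recent PyTorch build on any of the three backends:

import torch
import torch.nn.functional as F

# Pick whatever backend is available, then run a tiny causal attention call.
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
q = torch.randn(1, 4, 8, 16, device=device)  # (batch, heads, seq_len, head_dim) layout expected by SDPA
out = F.scaled_dot_product_attention(q, q, q, is_causal=True)
print(device, tuple(out.shape))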

Impact

  • Broader hardware support for contributors
  • Maintained optimal performance on CUDA
  • Reduced onboarding friction
  • No breaking changes

This follows the established conditional device pattern while extending compatibility to the attention layer.

jryanhaber • Aug 12 '25

@jryanhaber Great work, thank you! A README change would be good to make clear what the hardware support is and how to run it.

benman1 • Aug 16 '25