torchtitan
[fused_rmsnorm] Avoid querying device inside forward
Stack from ghstack (oldest at bottom):
- -> #301
- #300
- #161
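Get `sm_count` another way, without querying the input's device inside `forward`, to work around issues with meta-device tracing.
Note: this PR isn't strictly safe, because it burns in device 0's SM count. The value is wrong if the kernel later runs on a GPU with a different number of SMs.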
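A minimal sketch of the pattern, assuming the kernel needs the SM count only to size its launch grid. `get_sm_count`, `num_programs` and `_FALLBACK_SM_COUNT` are hypothetical names, not the actual torchtitan code. The count is read once from device 0 via `torch.cuda.get_device_properties(0).multi_processor_count` and cached, so `forward` never touches `x.device`, which is not a CUDA device under meta-device tracing.

```python
import functools

# Hypothetical placeholder used when torch or a CUDA device is unavailable
# (108 is the A100's SM count; chosen only for illustration).
_FALLBACK_SM_COUNT = 108


@functools.lru_cache(maxsize=None)
def get_sm_count() -> int:
    """Return device 0's SM count, queried once and cached.

    This burns in device 0's value: it is only correct if every GPU the
    kernel runs on has the same number of SMs.
    """
    try:
        import torch
    except ImportError:
        return _FALLBACK_SM_COUNT
    if not torch.cuda.is_available():
        return _FALLBACK_SM_COUNT
    return torch.cuda.get_device_properties(0).multi_processor_count


def num_programs(n_rows: int) -> int:
    # Hypothetical helper: cap the grid size at the SM count without
    # inspecting the input tensor's device inside forward.
    return min(n_rows, get_sm_count())
```

Because the lookup no longer depends on the input tensor, tracing with meta tensors never calls into CUDA device queries from `forward`; the trade-off is exactly the device-0 assumption noted above.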