
Added backend context manager to select SDPA implementation

Open · CharlelieLrt opened this pull request 4 months ago · 1 comment

PhysicsNeMo Pull Request

Description

Context

PR #954 changed the forward pass of the Attention class defined in physicsnemo/models/diffusion/layers.py so that it is now based on torch.nn.functional.scaled_dot_product_attention instead of the former custom Python implementation. This PyTorch API offers better performance, but it is still in beta, its behavior is hardware-dependent, and it is sensitive to numerical errors.
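Both code paths compute softmax(Q Kᵀ / √d) V. The sketch below is not the PhysicsNeMo code: the helper naive_attention is hypothetical and only illustrates what an explicit Python implementation looks like next to the fused PyTorch call. On CPU in float32 the two results match only up to a floating-point tolerance, not bit for bit, which illustrates the kind of numerical difference described above.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    """Hypothetical explicit attention: softmax(Q K^T / sqrt(d)) V."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    # Attention logits with shape (..., L_q, L_k)
    logits = torch.matmul(q, k.transpose(-2, -1)) * scale
    weights = torch.softmax(logits, dim=-1)
    return torch.matmul(weights, v)


torch.manual_seed(0)
# Shapes: (batch, heads, sequence length, head dim)
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)

y_naive = naive_attention(q, k, v)
# The PyTorch default scale is also 1 / sqrt(head dim)
y_sdpa = F.scaled_dot_product_attention(q, k, v)

# Same math, but the fused kernel is not guaranteed to be bit-identical
max_abs_diff = (y_naive - y_sdpa).abs().max().item()
print(f"max |naive - sdpa| = {max_abs_diff:.2e}")
```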

Changes

The present PR re-introduces the former attention computation, based on a custom Python implementation, as an option. Compared to torch.nn.functional.scaled_dot_product_attention, this implementation is slower, but more stable and less sensitive to numerical errors. The default forward pass of the Attention class remains based on torch.nn.functional.scaled_dot_product_attention, but the custom Python implementation can now be selected with a context manager:

```python
from physicsnemo.models.diffusion.layers import Attention
from torch.nn.attention import SDPBackend

# `model` is any module containing Attention layers, `x` a valid input for it

# Default: use torch.nn.functional.scaled_dot_product_attention
# without selecting a specific backend
y = model(x)

# Use the custom Python implementation of attention
with Attention.SDPA_backend("python"):
    y = model(x)

# Use a specific PyTorch backend for
# torch.nn.functional.scaled_dot_product_attention
# (FLASH_ATTENTION requires a supported CUDA GPU)
with Attention.SDPA_backend(SDPBackend.FLASH_ATTENTION):
    y = model(x)
```

Checklist

  • [x] I am familiar with the Contributing Guidelines.
  • [ ] New or existing tests cover these changes.
  • [x] The documentation is up to date with these changes.
  • [x] The CHANGELOG.md is up to date with these changes.
  • [ ] An issue is linked to this pull request.

Dependencies

CharlelieLrt commented on Aug 13 '25, 02:08:

/blossom-ci
