NATTEN
Support torch.compile
Adds torch.compile support for torch>=2.4.0 by refactoring the ops/functions Python interface and optionally registering libnatten ops with torch.library instead of wrapping them in autograd ops.
More information on why all of these steps are necessary: https://pytorch.org/tutorials/advanced/python_custom_ops.html
Fixes #89.
TODOs:
- [x] verify torch.compile works without graph breaks
- [x] make registering with torch.library optional (forward-mode AD still has no torch.library interface)
- [ ] torch.compile unit tests
- [ ] verify mixed precision still works with torch.library
- [ ] verify training NAT/DiNAT with torch.compile is functional
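The reason for keeping registration optional can be illustrated with a short example. `register_autograd` only attaches a backward (reverse-mode) formula. A `torch.autograd.Function`, by contrast, can also define `jvp`. Keeping a Function-based path is therefore one way to retain forward-mode AD. `ScaledMulFn` and `forward_mode_tangent` are hypothetical names. The pattern follows the `save_for_forward` example in the PyTorch docs and is not NATTEN's implementation.

```python
import torch
import torch.autograd.forward_ad as fwAD
from torch import Tensor


class ScaledMulFn(torch.autograd.Function):
    """Hypothetical Function-based wrapper supporting reverse- and forward-mode AD."""

    @staticmethod
    def forward(ctx, a: Tensor, b: Tensor, scale: float) -> Tensor:
        ctx.save_for_backward(a, b)  # read back in backward()
        ctx.save_for_forward(a, b)   # read back in jvp()
        ctx.scale = scale
        return a * b * scale

    @staticmethod
    def backward(ctx, grad: Tensor):
        a, b = ctx.saved_tensors
        # One gradient per input; the non-tensor `scale` gets None.
        return grad * b * ctx.scale, grad * a * ctx.scale, None

    @staticmethod
    def jvp(ctx, a_t: Tensor, b_t: Tensor, _scale_t):
        # Forward-mode rule: d(a * b * s) = s * (b * da + a * db).
        a, b = ctx.saved_tensors
        return ctx.scale * (b * a_t + a * b_t)


def forward_mode_tangent(a: Tensor, b: Tensor, a_t: Tensor, b_t: Tensor, scale: float = 2.0) -> Tensor:
    # Push tangents (a_t, b_t) through the op with dual tensors.
    with fwAD.dual_level():
        out = ScaledMulFn.apply(fwAD.make_dual(a, a_t), fwAD.make_dual(b, b_t), scale)
        return fwAD.unpack_dual(out).tangent
```

A flag, such as an environment variable, could then select between this Function and the torch.library op at import time. Any particular flag name would be an assumption, since the description does not say how the option is exposed.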
This is on hold pending resolution of https://github.com/pytorch/pytorch/issues/137437.
Needs to be redone due to #226.