
Fix auto-registered torch.special operators

Open kiya00 opened this issue 1 year ago • 1 comment

Before submitting
  • [ ] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Did you make sure to update the docs?
  • [ ] Did you write any new necessary tests?

What does this PR do?

Background: This PR addresses a bug related to the handling of torch.special operators, discovered during the development of PR #976.

torch.special operators have a __name__ of the form special_opname (for example, special_i0e for torch.special.i0e), so the actual opname has to be extracted from it. The same issue affects torch.linalg and torch.fft operators.

In this PR:

  • Add the function _get_torch_function_name to infer the Python call name from the torch module and function
  • Add support for auto-registration of torch.linalg and torch.fft operators, along with tests

PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

kiya00 avatar Aug 16 '24 13:08 kiya00

Hi @t-vi @IvanYashchuk, we can discuss further whether https://github.com/Lightning-AI/lightning-thunder/pull/976 is necessary. This is a bug I found along the way, so I split it out and we can review this first. Testing results:

[2024-08-16 13:59:44] thunder/tests/test_auto_register_torchops.py::test_torch_ops_trace[cuda-train-special.i0e] PASSED
[2024-08-16 13:59:45] thunder/tests/test_auto_register_torchops.py::test_torch_ops_trace[cuda-train-special.i1] PASSED
[2024-08-16 13:59:45] thunder/tests/test_auto_register_torchops.py::test_torch_ops_trace[cuda-train-special.i1e] PASSED
[2024-08-16 13:59:45] thunder/tests/test_auto_register_torchops.py::test_torch_ops_trace[cuda-train-special.ndtr] PASSED
...

kiya00 avatar Aug 16 '24 14:08 kiya00