Support PyTorch custom ops when they have meta functions (also autograd)
PyTorch allows registering meta functions for custom ops. With that information, we should be able to apply the fallback mechanism that @kiya00 developed for PyTorch operations in thunder.
https://pytorch.org/docs/stable/library.html
There is also support for registering backward rules, which might be of interest as well.
cc @apaz-cli
Here's a better link showing how to use the API: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial.
What are the implications of including PyTorch custom ops in traces, considering semantic-changing transforms (like autocast, forward-backward, etc.) and overall system compatibility?
What should the recommended way of registering arbitrary custom operations be? Should the torch.library way always be preferred, with Thunder then needing to understand PyTorch's format? When should we recommend adding only a Thunder-native registration?
For the transform compatibility:
- Given that the typical use would be custom kernels, I would not worry too much about autocast at first.
- For backward with `op.register_autograd`, we can use @crcrpar's autograd.Function mechanism: the registered backward is identical to `torch.autograd.FunctionSubclass.backward`, and the corresponding `torch.autograd.FunctionSubclass.forward` is `res = op(inps); setup_context(ctx, inps, res); return res`. This follows from the documentation, which says: "Note that the backward must be a composition of PyTorch-understood operators, which is why we wrapped paste into a custom operator instead of directly using PIL’s paste."
Until thunder achieves world domination, we likely want to support torch's mechanism here at least as a fallback. You would still be able to register custom implementations with the usual mechanisms.