lightning-thunder
Further improve `test_grad.py::check_vjp_correctness`
PR https://github.com/Lightning-AI/lightning-thunder/pull/2618 is an imperfect, temporary patch to speed up the CI jobs that run the `check_vjp_correctness` tests. It kept the existing mathematical identity for the correctness checks, but instead of relying on a finite difference method (fdm) to compute the Jacobian-vector product, it uses `torch.func.jvp`, which turned out to be much faster.
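For context, here is a minimal sketch of that kind of check, assuming the identity in question is the usual inner-product one, `<u, J v> == <J^T u, v>`. The function `f`, the helper `fd_jvp`, and the step size are hypothetical stand-ins for illustration, not the code in `test_grad.py`.

```python
import torch
from torch.func import jvp, vjp


def f(x):
    # Hypothetical function under test.
    return torch.sin(x) * x


def fd_jvp(fn, x, v, eps=1e-6):
    # Central finite-difference approximation of J(x) @ v (illustrative helper).
    return (fn(x + eps * v) - fn(x - eps * v)) / (2 * eps)


torch.manual_seed(0)
x = torch.randn(5, dtype=torch.float64)  # primal input
v = torch.randn(5, dtype=torch.float64)  # tangent
u = torch.randn(5, dtype=torch.float64)  # cotangent

# Forward mode: J v, computed analytically by torch.func.jvp.
_, jv = jvp(f, (x,), (v,))

# Reverse mode: J^T u, the quantity whose correctness is being checked.
_, vjp_fn = vjp(f, x)
(jtu,) = vjp_fn(u)

# The identity: both inner products must agree.
lhs = torch.dot(u, jv)
rhs = torch.dot(jtu, v)

# The slower alternative: approximate J v with finite differences.
jv_fd = fd_jvp(f, x, v)
```

In this sketch the finite-difference path needs two extra evaluations of `f` per tangent and is only accurate up to the step size, while `torch.func.jvp` gives the product in a single forward-mode pass.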
This is a valid approach, but we could further reduce the total test time by asserting a different mathematical property. For example, this with this.
Some issues persist after this work:
- Most of the tests are run with the faster `torch.func.jvp`, but in some cases we discovered relevant numerical discrepancies between the function that Thunder runs and the one Torch runs (eager), especially when executors are involved. Those cases are whitelisted and run using fdm.
- To make `torch.func.jvp` work on some test cases, we had to clone the input tensors inside the sample generators, but this produces contiguous tensors, changing the original strides.
- Complex types are not differentiable within Torch, but they are for Thunder. If the input has a complex type, fdm is employed instead of the Torch API.
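The last two points can be illustrated with a small sketch. The tensors and the helper `choose_jvp_backend` below are hypothetical and only mirror the behavior described above; they are not the actual sample generators or test code.

```python
import torch

# A non-contiguous view: every other column of a 3x4 tensor.
base = torch.arange(12.0).reshape(3, 4)
view = base[:, ::2]

# Cloning a non-dense view like this one yields a contiguous tensor,
# so the original strides are lost.
cloned = view.clone()

view_strides = view.stride()      # strides of the strided view
cloned_strides = cloned.stride()  # strides after clone()


def choose_jvp_backend(*tensors):
    # Hypothetical helper: fall back to finite differences (fdm)
    # whenever any input has a complex dtype.
    if any(t.is_complex() for t in tensors):
        return "fdm"
    return "torch.func.jvp"


real_input = torch.randn(3)
complex_input = torch.zeros(3, dtype=torch.complex64)
```

Here `view` is non-contiguous while `cloned` is contiguous, so a test that was meant to exercise the strided layout would no longer do so after cloning.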
cc @borda @mruberry