PyTorch Backend: Missing autograd wrapper for se_t descriptor's second-order derivatives
Summary
The se_t descriptor (TabulateFusionSeTOp) does not seem to support second-order derivative calculations (required for virial/stress), even though the underlying C++/CUDA kernel (deepmd::tabulate_fusion_se_t_grad_grad_gpu) and its wrapper function (TabulateFusionSeTGradGradForward) are present in the code.
DeePMD-kit Version
devel latest
Backend and its version
pytorch
Python Version, CUDA Version, GCC Version, LAMMPS Version, etc
No response
Details
- Working Implementation (se_a): The se_a descriptor has a complete autograd chain. Its first-derivative calculation is wrapped in TabulateFusionSeAGradOp, whose backward method correctly calls TabulateFusionSeAGradGradForward. This makes the operation fully second-order differentiable.
- Incomplete Implementation (se_t): For the se_t descriptor, the TabulateFusionSeTOp's backward method calls TabulateFusionSeTGradForward directly. There is no corresponding TabulateFusionSeTGradOp class to wrap this first-derivative calculation.
- The Consequence: Because the autograd wrapper (GradOp) is missing, the TabulateFusionSeTGradGradForward function is defined but never actually called by the PyTorch autograd engine. This effectively makes it dead code in the context of PyTorch's automatic differentiation and prevents training models with virial/stress labels when using the se_t descriptor. (A minimal sketch of the missing wrapper pattern follows this list.)
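For illustration only, here is a minimal, self-contained sketch of the wrapping pattern that se_a uses and se_t lacks. It uses a toy y = x^3 kernel instead of the real tabulate_fusion kernels, and all names below (CubeOp, CubeGradOp, etc.) are hypothetical; the real signatures in source/op/pt/tabulate_multi_device.cc differ:

```cpp
// Toy sketch of the double-backward wrapper pattern (NOT the actual
// DeePMD-kit code): a simple y = x^3 "kernel" stands in for the real
// tabulate_fusion_se_t kernels.
#include <torch/torch.h>

using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

// Stand-ins for the forward, grad, and grad-grad kernels.
Tensor cube_forward(const Tensor& x) { return x.pow(3); }
Tensor cube_grad(const Tensor& x, const Tensor& dy) { return 3 * x.pow(2) * dy; }
std::vector<Tensor> cube_grad_grad(const Tensor& x, const Tensor& dy,
                                   const Tensor& ddy) {
  // Gradients of cube_grad(x, dy) = 3*x^2*dy with respect to x and dy.
  return {ddy * 6 * x * dy, ddy * 3 * x.pow(2)};
}

// Wrapper around the first-derivative kernel. Because this is itself an
// autograd::Function whose backward() calls the grad-grad kernel, the outer
// op becomes second-order differentiable. This is the analogue of the
// missing TabulateFusionSeTGradOp.
class CubeGradOp : public torch::autograd::Function<CubeGradOp> {
 public:
  static Tensor forward(AutogradContext* ctx, Tensor x, Tensor dy) {
    ctx->save_for_backward({x, dy});
    return cube_grad(x, dy);
  }
  static variable_list backward(AutogradContext* ctx, variable_list grad_out) {
    auto saved = ctx->get_saved_variables();
    auto gg = cube_grad_grad(saved[0], saved[1], grad_out[0]);
    return {gg[0], gg[1]};  // gradients w.r.t. x and dy
  }
};

// Outer op, analogue of TabulateFusionSeTOp: its backward delegates to
// CubeGradOp::apply instead of calling the first-derivative kernel directly,
// which is exactly the piece the issue says is missing for se_t.
class CubeOp : public torch::autograd::Function<CubeOp> {
 public:
  static Tensor forward(AutogradContext* ctx, Tensor x) {
    ctx->save_for_backward({x});
    return cube_forward(x);
  }
  static variable_list backward(AutogradContext* ctx, variable_list grad_out) {
    auto saved = ctx->get_saved_variables();
    return {CubeGradOp::apply(saved[0], grad_out[0])};
  }
};
```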
@OutisLi The virial/stress only needs the first-order derivative, the same as the force, while the second-order derivative calculations are needed for training.
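In other words, force and virial are first-order derivatives of the energy, while a force (or virial) term in the training loss brings in a second-order derivative once it is differentiated with respect to the network parameters $\theta$ (a rough sketch):

$$
F_i = -\frac{\partial E}{\partial \boldsymbol{r}_i}, \qquad
\frac{\partial}{\partial \theta}\left\lVert F_i - F_i^{\mathrm{label}} \right\rVert^2 \;\propto\; \frac{\partial^2 E}{\partial \boldsymbol{r}_i\,\partial \theta},
$$

which is why the GradGrad kernels are only exercised during training.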
@njzjz I'm also confused: it seems that even for se_a, the TabulateFusionSeAOp in source/op/pt/tabulate_multi_device.cc only supports the first-order derivative, since it does not link to TabulateFusionSeAGradOp?
Yes, it's an incorrect implementation that doesn't support higher-order derivatives. For a correct backward function, refer to the code below:
https://github.com/deepmodeling/deepmd-kit/blob/41e62d659ebfa47d9d7af7ba831bb7739b7dd824/source/op/pt/thsilu.cc#L230-L241
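As a rough check of what "correct" means here, second-order differentiation through such an op can be exercised like this (again just a sketch using the toy CubeOp from the Details section above, not DeePMD-kit code):

```cpp
#include <iostream>
#include <torch/torch.h>
// Uses CubeOp / CubeGradOp from the toy sketch above.

int main() {
  auto x = torch::randn({4}, torch::requires_grad());
  auto y = CubeOp::apply(x).sum();

  // First derivative, keeping the graph so it can be differentiated again --
  // this is what happens when force/virial terms appear in the training loss.
  auto dydx = torch::autograd::grad(/*outputs=*/{y}, /*inputs=*/{x},
                                    /*grad_outputs=*/{},
                                    /*retain_graph=*/true,
                                    /*create_graph=*/true)[0];

  // Second derivative: this only succeeds because CubeOp's backward goes
  // through CubeGradOp (the analogue of the missing TabulateFusionSeTGradOp).
  auto d2ydx2 = torch::autograd::grad({dydx.sum()}, {x})[0];
  std::cout << d2ydx2 << "\n";  // expect 6 * x
}
```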