pytorch3d
Support for higher order derivatives
🚀 Feature
Support for computing higher-order derivatives (second order or higher). Currently several functions are marked @once_differentiable, which prevents autograd from differentiating through their backward pass.
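To illustrate what the decorator does, here is a minimal sketch with a toy custom function (SquareOnce is a hypothetical example, not a PyTorch3D op). The first derivative works, but asking autograd for a second one raises a RuntimeError, whereas the same computation built from plain torch ops can be differentiated twice.

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class SquareOnce(Function):
    """Toy op (hypothetical, for illustration only): y = x * x."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable  # backward runs without grad tracking
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


def second_derivative(fn, value):
    """Return the second derivative of fn at `value`, or None if autograd refuses."""
    x = torch.tensor(value, requires_grad=True)
    y = fn(x)
    # First derivative; create_graph=True asks autograd to record the backward pass.
    (g,) = torch.autograd.grad(y, x, create_graph=True)
    try:
        (h,) = torch.autograd.grad(g, x)
    except RuntimeError:
        return None
    return h.item()


print(second_derivative(lambda x: x * x, 3.0))   # plain torch ops
print(second_derivative(SquareOnce.apply, 3.0))  # once-differentiable op
```

With a recent PyTorch the first call should print 2.0 and the second None, since the decorated backward leaves no graph to differentiate.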
Motivation
Second derivatives would allow the use of optimisers that require curvature information. They would also allow components placed after the renderer to be optimised so that the gradients with respect to the renderer's inputs have desired properties (my particular application).
Pitch
In my application, I would like to be able to execute something like the following:
translationGradients, rotationQuaternionGradients = torch.autograd.grad(
    outputs=pixelLoss, inputs=[translations, rotationQuaternions],
    retain_graph=True, create_graph=True)
loss = ((translationGradients - translationGradientTargets).pow(2).mean()
        + (rotationQuaternionGradients - rotationGradientTargets).pow(2).mean())
loss.backward()
Am I right in thinking that you just want second derivatives through the renderer? If so, only packed_to_padded and interpolate_face_attributes matter, and you want one more derivative, not arbitrarily many?
@bottler Thanks, I just need second derivatives. Essentially what I want to do is place a learned network after the renderer and before a loss function, and then train the network to induce certain characteristics in the gradients of the inputs to the renderer. This requires me to do one differentiation step from the loss through the network and the renderer to the inputs of the renderer, calculate a second loss based on the gradients, and then differentiate the second loss to generate gradients for the network training. It's similar in implementation to a gradient penalty loss term.
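The training step described above can be sketched as follows. This is only an illustration under stated assumptions: toy_renderer is a hypothetical stand-in built from ordinary torch ops (so it is twice differentiable), not the PyTorch3D renderer, and the network, shapes and targets are made up.

```python
import torch
from torch import nn

torch.manual_seed(0)


def toy_renderer(translations):
    # Hypothetical stand-in for a renderer: maps (N, 2) inputs to (N, 16) "pixels"
    # using only standard, twice-differentiable torch ops.
    grid = torch.linspace(-1.0, 1.0, 16)
    return torch.sin(grid[None, :] * 3.0 + translations[:, :1]) * torch.cos(translations[:, 1:2])


# Learned network placed after the renderer and before the loss.
net = nn.Sequential(nn.Linear(16, 16), nn.Tanh(), nn.Linear(16, 16))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

translations = torch.randn(4, 2, requires_grad=True)
target_image = torch.zeros(4, 16)   # made-up pixel targets
gradient_targets = torch.zeros(4, 2)  # made-up targets for the input gradients

# Step 1: first loss, differentiated back to the renderer inputs.
image = toy_renderer(translations)
pixel_loss = (net(image) - target_image).pow(2).mean()
(translation_grads,) = torch.autograd.grad(pixel_loss, translations, create_graph=True)

# Step 2: second loss defined on those gradients.
gradient_loss = (translation_grads - gradient_targets).pow(2).mean()

# Step 3: differentiate the second loss to obtain gradients for the network.
optimizer.zero_grad()
gradient_loss.backward()
optimizer.step()
```

Step 3 is the double backward pass: it only works if every op between the renderer inputs and the loss supports a second derivative, which is what the once-differentiable functions currently block.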
I agree this is a useful feature, but I think it would be very nontrivial to implement. Adding second derivatives for packed_to_padded and interpolate_face_attributes probably wouldn't be too difficult, but I think adding second derivatives for rasterization would be a pretty big undertaking.
I also realized that rasterization isn't currently marked as once_differentiable, but it probably should be as well:
https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/rasterize_meshes.py#L214
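On the earlier point that second derivatives for packing-style ops might not be too hard: one common pattern is to write the backward pass of a custom function using only differentiable torch ops and leave off @once_differentiable, so that autograd can differentiate the backward pass itself. The sketch below uses a hypothetical gather op (GatherEntries) as a stand-in. It is not PyTorch3D's packed_to_padded, whose real implementation relies on custom kernels.

```python
import torch
from torch.autograd import Function


class GatherEntries(Function):
    """Hypothetical stand-in for a packing-style op: out[i] = values[index[i]]."""

    @staticmethod
    def forward(ctx, values, index):
        ctx.save_for_backward(index)
        ctx.num_values = values.shape[0]
        return values[index]

    @staticmethod
    def backward(ctx, grad_out):
        (index,) = ctx.saved_tensors
        # Scatter-add the incoming gradient back to the source entries.
        # index_add is differentiable w.r.t. grad_out and there is no
        # @once_differentiable here, so autograd can differentiate this pass.
        grad_values = grad_out.new_zeros(ctx.num_values).index_add(0, index, grad_out)
        return grad_values, None  # no gradient for the integer index


values = torch.randn(5, dtype=torch.double, requires_grad=True)
index = torch.tensor([0, 2, 2, 4])

# Numerically checks gradients of gradients for the custom function.
ok = torch.autograd.gradgradcheck(GatherEntries.apply, (values, index))
print(ok)
```

For an op backed by a hand-written kernel, the equivalent step would be to wrap the backward kernel in its own autograd function with a backward of its own, which is where most of the work for rasterization would lie.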