
Add primitive support of reverse-mode constructor custom derivatives

Open infinite-void-16 opened this issue 6 months ago • 4 comments

This commit adds primitive support for reverse-mode custom derivatives of constructors.

Constructors can be thought of as special functions: functions that also construct an object. Therefore, differentiating a constructor is very similar to differentiating a function call. Note that this includes innocent-looking initializations such as:

SomeClass C1 = C2;

This involves differentiating the copy constructor of SomeClass.
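To make this concrete, here is a small plain-C++ sketch (no clad involved; the copy counter is purely illustrative) showing that copy-initializing from an lvalue of the same type runs the copy constructor, which is the call that would need to be differentiated.

```cpp
#include <cassert>

// Minimal class that counts how often its copy constructor runs.
struct SomeClass {
  static int copies;  // number of copy-constructor calls so far
  double x = 0;
  SomeClass() = default;
  SomeClass(const SomeClass& other) : x(other.x) { ++copies; }
};
int SomeClass::copies = 0;

// Returns how many copy-constructor calls `SomeClass C1 = C2;` performs.
int copies_during_copy_init() {
  SomeClass C2;
  C2.x = 3.0;
  int before = SomeClass::copies;
  SomeClass C1 = C2;  // copy-initialization: calls SomeClass(const SomeClass&)
  (void)C1;
  return SomeClass::copies - before;
}
```

Because `C2` is an lvalue, the copy cannot be elided, so exactly one copy-constructor call happens.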

Here is a concrete example of what the differentiation of a constructor call looks like:

Original code:

SomeClass c(u, v, w);

Derivative code:

// forward-pass
clad::ValueAndAdjoint<SomeClass, SomeClass> _t0 =
  constructor_reverse_forw(clad::ConstructorReverseForwTag<SomeClass>{},
    u, v, w, _d_u, _d_v, _d_w);
SomeClass _d_c(_t0.adjoint);
SomeClass c(_t0.value);

// reverse-pass
{
  double _r0 = 0;
  double _r1 = 0;
  double _r2 = 0;
  constructor_pullback(&c, u, v, w, &_d_c, &_r0, &_r1, &_r2);
  _d_u += _r0;
  _d_v += _r1;
  _d_w += _r2;
}
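The following is a self-contained model of the derivative code above, not actual clad output. To compile without clad, `ValueAndAdjoint` and `ConstructorReverseForwTag` are replaced by hypothetical stand-ins in a namespace `sketch`. `SomeClass` is a toy class that stores `a = u * v + w`, and `f_grad` is a hand-written counterpart of what a gradient of `f(u, v, w) = c.a` could look like.

```cpp
#include <cassert>

// Hypothetical stand-ins for clad::ValueAndAdjoint and
// clad::ConstructorReverseForwTag, so the sketch compiles without clad.
namespace sketch {
template <typename T, typename U> struct ValueAndAdjoint {
  T value;
  U adjoint;
};
template <typename T> struct ConstructorReverseForwTag {};
}  // namespace sketch

// Toy class: stores a = u * v + w.
struct SomeClass {
  double a = 0;
  SomeClass() = default;
  SomeClass(double u, double v, double w) : a(u * v + w) {}
};

// Forward pass: build the object and a zero-initialized adjoint.
// The d_* parameters are unused because SomeClass keeps no references
// into its arguments.
sketch::ValueAndAdjoint<SomeClass, SomeClass>
constructor_reverse_forw(sketch::ConstructorReverseForwTag<SomeClass>,
                         double u, double v, double w,
                         double /*d_u*/, double /*d_v*/, double /*d_w*/) {
  return {SomeClass(u, v, w), SomeClass()};
}

// Reverse pass: propagate the adjoint of c.a back to u, v and w.
void constructor_pullback(SomeClass* /*c*/, double u, double v, double /*w*/,
                          SomeClass* d_c, double* d_u, double* d_v,
                          double* d_w) {
  *d_u += d_c->a * v;  // d(a)/du = v
  *d_v += d_c->a * u;  // d(a)/dv = u
  *d_w += d_c->a;      // d(a)/dw = 1
}

// Hand-written gradient of f(u, v, w) { SomeClass c(u, v, w); return c.a; }
void f_grad(double u, double v, double w,
            double* _d_u, double* _d_v, double* _d_w) {
  // forward-pass
  auto _t0 = constructor_reverse_forw(
      sketch::ConstructorReverseForwTag<SomeClass>{}, u, v, w,
      *_d_u, *_d_v, *_d_w);
  SomeClass _d_c(_t0.adjoint);
  SomeClass c(_t0.value);
  // reverse pass of `return c.a;`: seed the adjoint of c.a
  _d_c.a += 1;
  // reverse-pass of the constructor call
  {
    double _r0 = 0, _r1 = 0, _r2 = 0;
    constructor_pullback(&c, u, v, w, &_d_c, &_r0, &_r1, &_r2);
    *_d_u += _r0;
    *_d_v += _r1;
    *_d_w += _r2;
  }
}
```

For `u = 2, v = 3, w = 4` this yields the gradient `(v, u, 1) = (3, 2, 1)`.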

Note the use of clad::ConstructorReverseForwTag<T> as the first argument of the constructor_reverse_forw function. The reasoning and motivation for this are the same as for clad::ConstructorPushforwardTag<T>.
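One plausible reading of that motivation (our interpretation, not spelled out here): the object does not exist yet when constructor_reverse_forw runs, so there is no object argument to select an overload, and constructors of different classes with identical parameter lists (for example, two default constructors) would yield overloads that differ only in return type, which C++ does not allow. An empty tag type carries the class into overload resolution. The sketch below uses a hypothetical stand-in tag, hypothetical classes `A` and `B`, and `std::string` return values purely for illustration.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for clad::ConstructorReverseForwTag<T>.
template <typename T> struct ConstructorReverseForwTag {};

struct A { A() = default; };
struct B { B() = default; };

// Both default constructors take no arguments. Without the tag parameter,
// these two overloads would have identical parameter lists and could not
// coexist; the tag type is what tells them apart.
std::string constructor_reverse_forw(ConstructorReverseForwTag<A>) { return "A"; }
std::string constructor_reverse_forw(ConstructorReverseForwTag<B>) { return "B"; }
```

Calling `constructor_reverse_forw(ConstructorReverseForwTag<B>{})` selects the overload for `B`.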

Reverse-mode custom derivatives for a constructor can be specified as:

namespace clad {
namespace custom_derivatives {
namespace class_functions {
clad::ValueAndAdjoint<SomeClass, SomeClass>
constructor_reverse_forw(clad::ConstructorReverseForwTag<SomeClass>,
  ...) {
  // ...
  // ...
}

void constructor_pullback(SomeClass *c, ..., SomeClass *d_c, ...) {
  // ...
}
} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
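As an illustration, here is a sketch that fills in this pattern for the copy constructor from the `SomeClass C1 = C2;` example. The parameter lists are elided above, so the ones used here (the source object plus its adjoint) are an assumption modeled on the derivative code shown earlier. The namespaces live under a hypothetical `sketch` namespace with stand-in types so the code compiles without clad.

```cpp
#include <cassert>

// Hypothetical stand-ins for clad::ValueAndAdjoint and
// clad::ConstructorReverseForwTag.
namespace sketch {
template <typename T, typename U> struct ValueAndAdjoint {
  T value;
  U adjoint;
};
template <typename T> struct ConstructorReverseForwTag {};
}  // namespace sketch

// Toy class whose implicit copy constructor we differentiate.
struct SomeClass {
  double x = 0, y = 0;
};

namespace sketch {
namespace custom_derivatives {
namespace class_functions {
// Forward pass of SomeClass(const SomeClass&): the value is a copy of the
// source and the adjoint of the new object starts at zero.
ValueAndAdjoint<SomeClass, SomeClass>
constructor_reverse_forw(ConstructorReverseForwTag<SomeClass>,
                         const SomeClass& other, const SomeClass& /*d_other*/) {
  return {other, SomeClass{}};
}

// Reverse pass: each member of the copy equals the corresponding member of
// the source, so the adjoint flows back unchanged and is accumulated.
void constructor_pullback(SomeClass* /*c*/, const SomeClass& /*other*/,
                          SomeClass* d_c, SomeClass* d_other) {
  d_other->x += d_c->x;
  d_other->y += d_c->y;
}
}  // namespace class_functions
}  // namespace custom_derivatives
}  // namespace sketch
```

Since the tag lives in namespace `sketch` while the functions live in `class_functions`, a caller needs a using-directive (or qualified names) to find them.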

infinite-void-16, Aug 18 '24 06:08