Incorrect gradient when a reference variable is initialized with a pointer dereference

Open parth-07 opened this issue 1 year ago • 3 comments

Clad generates incorrect gradient code when a reference variable is initialized with a pointer dereference.

Reproducible example:

#include "clad/Differentiator/Differentiator.h"
#include <iostream>
#define show(x) std::cout << #x << ": " << x << "\n";

double fn(double u, double v) {
    double *p = &v;
    double &ref = *p;
    ref += u;
    double res = u + v;
    return res;
}

int main() {
    auto fn_grad = clad::gradient(fn);
    double u = 3, v = 5;
    double du = 0, dv = 0;
    fn_grad.execute(u, v, &du, &dv);
    show(du);
    show(dv);
}

Actual output: du: 1 dv: 1

Expected output: du: 2 dv: 1
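
The expected values can be checked by hand. Because ref aliases v, the statement ref += u updates v to v + u before res is computed, so:

$$
\text{res} = u + (v + u) = 2u + v, \qquad \frac{\partial\,\text{res}}{\partial u} = 2, \qquad \frac{\partial\,\text{res}}{\partial v} = 1
$$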

Root cause:

double &ref = *p

ref and *p should share the same adjoint variable, because a reference is just an alias for the same value and memory location. Currently, however, double &ref gets its own separate adjoint variable, double _d_ref = 0. This case should be handled the same way as a function returning a reference or pointer. Therefore, double &ref = *p should be transformed as follows in the forward pass:

double *_d_ref = nullptr;
// ...
double &ref = *p;
_d_ref =  _d_p;
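
To illustrate why sharing the adjoint matters, here is a minimal hand-written sketch of the reverse pass for fn, without Clad. It is not the code Clad generates: the Grad struct, the two function names, and the d_-prefixed variables are all hypothetical and exist only for this illustration. The first function lets ref reuse the adjoint of *p, as proposed above; the second gives ref its own zero-initialized adjoint, mimicking the behaviour described in this issue.

```cpp
// Hypothetical result type for this sketch; not part of Clad.
struct Grad {
    double du;
    double dv;
};

// Reverse pass in which the adjoint of `ref` aliases the adjoint of `*p`.
Grad grad_shared_adjoint(double u, double v) {
    double d_u = 0, d_v = 0;

    // Forward pass.
    double *p = &v;
    double *d_p = &d_v;   // adjoint of p points at the adjoint of v
    double &ref = *p;
    double *d_ref = d_p;  // ref shares the adjoint of *p
    ref += u;
    double res = u + v;
    (void)res;            // only the derivatives are needed here

    // Reverse pass, seeded with d_res = 1.
    double d_res = 1;
    d_u += d_res;         // from res = u + v
    d_v += d_res;
    d_u += *d_ref;        // from ref += u; reads the adjoint of v
    return {d_u, d_v};
}

// Reverse pass in which `ref` has its own separate adjoint.
Grad grad_separate_adjoint(double u, double v) {
    double d_u = 0, d_v = 0;

    // Forward pass.
    double *p = &v;
    double &ref = *p;
    double d_ref = 0;     // separate adjoint, never updated through v
    ref += u;
    double res = u + v;
    (void)res;

    // Reverse pass, seeded with d_res = 1.
    double d_res = 1;
    d_u += d_res;         // from res = u + v
    d_v += d_res;
    d_u += d_ref;         // from ref += u; d_ref is still 0
    return {d_u, d_v};
}
```

With u = 3 and v = 5, grad_shared_adjoint yields du = 2, dv = 1 (the expected output), while grad_separate_adjoint yields du = 1, dv = 1 (the actual output reported above). The contribution that flows into v through res = u + v never reaches the separate d_ref, so the ref += u step adds nothing to du.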

parth-07, Jan 23 '24 20:01

I can work on this!

saras1212, Jan 25 '24 19:01

Hi @parth-07 @vgvassilev, I would like to contribute to this issue. Could you please guide me through it?

kaushal-malpure, Mar 09 '24 18:03

Hi @parth-07, as I understand it, the reference variable doesn't share the same adjoint as the pointer, which results in the wrong gradient. To fix this, we could change the gradient function to handle the case where a reference variable is initialized with a pointer dereference: create an adjoint variable for the reference variable and set its value equal to the adjoint of the pointer. I attempted this by changing the gradient function in clad/differentiator/differentiator.h accordingly, but I'm not sure how exactly to access the adjoint of the pointer so that I can assign it to the newly created adjoint variable. Could you please help me with this?

kaushal-malpure, Mar 14 '24 06:03