
How to pass a pointer to a class member function to Enzyme?

Open · ipcamit opened this issue 2 years ago · 2 comments

Consider the following pointer to a class member function:

class MyFunctions{
...
    void add_mul_const(double * a, double * b, double * out);
};

void __enzyme_autodiff(
    void (MyFunctions::*)(double*, double*, double*),
    int, double*, double*,
    int, double*, double*,
    int, double*, double*);


int main(){
...
    void (MyFunctions::* f)(double * , double * , double *) = &MyFunctions::add_mul_const;
    // FWD
    (mf.*f)(a,b,out);

    // Derivative: neither of these attempts compiles
    // Attempt 1: __enzyme_autodiff(mf.*f, enzyme_dup, a, da, enzyme_dup, b, db, enzyme_dup, out, dout);
    // Attempt 2: __enzyme_autodiff(f, enzyme_dup, a, da, enzyme_dup, b, db, enzyme_dup, out, dout);

}

What is the correct way to pass a pointer to a class member function to Enzyme? Both of my attempts produced errors:

 warning: Cannot cast __enzyme_autodiff primal argument 1, found i64 0, type i64 - to arg 0 %class.MyFunctions* [-Wpass-failed=enzyme]
int main(){
    ^
error: <unknown>:0:0: EnzymeFailure when replacing __enzyme_autodiff calls in main
1 warning and 1 error generated.

and

 error: reference to non-static member function must be called
    __enzyme_autodiff(mf.*f, enzyme_const, a,  enzyme_dup, b, db, enzyme_dup, out, dout);
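The two errors have different causes. `&MyFunctions::add_mul_const` has pointer-to-member-function type, which is not an ordinary function pointer, and `mf.*f` is not a value at all: it may only appear directly in a call such as `(mf.*f)(a, b, out)`, which is what the second error reports. The sketch below checks the first point with standard type traits (C++17). The body of `add_mul_const` and the data member `c` are hypothetical fill-ins, since the issue elides them.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for the class in the issue; the real body is elided.
class MyFunctions {
public:
    double c = 2.0;
    void add_mul_const(double* a, double* b, double* out) { *out = (*a + *b) * c; }
};

using MemFn  = void (MyFunctions::*)(double*, double*, double*);
using FreeFn = void (*)(double*, double*, double*);

// A pointer to member function is its own category of type...
static_assert(std::is_member_function_pointer_v<MemFn>);
// ...it is not a plain pointer, and it converts neither to a function pointer nor to void*.
static_assert(!std::is_pointer_v<MemFn>);
static_assert(!std::is_convertible_v<MemFn, FreeFn>);
static_assert(!std::is_convertible_v<MemFn, void*>);

// The only way to call through it is to bind an object first: (obj.*f)(...).
double call_through_member_pointer(MyFunctions& mf, MemFn f, double a, double b) {
    double out = 0.0;
    (mf.*f)(&a, &b, &out);
    return out;
}
```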
    

Also, is there any way I can compute derivatives of these member functions without having to instantiate the class first?
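One option, assuming `add_mul_const` does not actually need any per-object state: make it `static`. A static member function has no hidden object argument, so `&MyFunctions::add_mul_const` is an ordinary function pointer and no instance is required. In this hypothetical sketch the constant is passed as a parameter instead of being stored in the object. The Enzyme call is guarded by a made-up `USE_ENZYME` macro so the file also builds without the Enzyme plugin; the `enzyme_dup`/`enzyme_const` globals and the variadic `__enzyme_autodiff(void*, ...)` declaration follow the style of Enzyme's C/C++ examples, so check the documentation for your Enzyme version.

```cpp
#include <cassert>

class MyFunctions {
public:
    // Hypothetical static version: the constant is an argument rather than a data
    // member, so no object (and no hidden `this` argument) is involved.
    static void add_mul_const(const double* a, const double* b, double* out, double c) {
        *out = (*a + *b) * c;
    }
};

// The address of a static member function is a plain function pointer.
using FreeFn = void (*)(const double*, const double*, double*, double);
FreeFn add_mul_const_ptr = &MyFunctions::add_mul_const;

#ifdef USE_ENZYME
int enzyme_dup;
int enzyme_const;
void __enzyme_autodiff(void*, ...);

// Reverse mode: da and db must start at zero; dout is seeded with 1.0.
void gradient(double a, double b, double c, double* da, double* db) {
    double out = 0.0, dout = 1.0;
    *da = 0.0;
    *db = 0.0;
    __enzyme_autodiff((void*)&MyFunctions::add_mul_const,
                      enzyme_dup, &a, da,
                      enzyme_dup, &b, db,
                      enzyme_dup, &out, &dout,
                      enzyme_const, c);
}
#endif
```

If the function genuinely depends on data members, an object has to exist somewhere, and the wrapper approach from the follow-up comment applies instead.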

Posted by ipcamit, Apr 25 '22 22:04

Taking hints from here, I bypassed the pointer-to-class requirement by creating a wrapper function:

void wrapper(double * a, double * b, double * out, MyFunctions * mf){
    mf->add_mul_const(a, b,out);
}

followed by this Enzyme call:

__enzyme_autodiff(wrapper, enzyme_dup, a, da,  enzyme_dup, b, db, enzyme_dup, out, dout, enzyme_const, &mf);
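Filled out into a complete file, the wrapper approach might look roughly like this. The body of `add_mul_const` and the data member `c` are hypothetical, because the issue elides them, and the Enzyme call sits behind a made-up `USE_ENZYME` macro so the primal part compiles without the plugin (the `enzyme_*` globals and the variadic declaration follow the style of Enzyme's examples). Passing `&mf` with `enzyme_const` means no shadow object is supplied, which is only appropriate when derivatives with respect to the object's data members are not needed.

```cpp
#include <cassert>

// Hypothetical class: the real member function body is elided in the issue.
class MyFunctions {
public:
    double c = 2.0;
    void add_mul_const(double* a, double* b, double* out) { *out = (*a + *b) * c; }
};

// Free-function wrapper: the object becomes an ordinary (last) argument,
// so Enzyme receives a plain function pointer.
void wrapper(double* a, double* b, double* out, MyFunctions* mf) {
    mf->add_mul_const(a, b, out);
}

#ifdef USE_ENZYME
int enzyme_dup;
int enzyme_const;
void __enzyme_autodiff(void*, ...);

// Reverse mode: da and db start at zero and accumulate d(out)/da and d(out)/db,
// scaled by the seed stored in dout.
void gradient(MyFunctions& mf, double a, double b, double* da, double* db) {
    double out = 0.0, dout = 1.0;
    *da = 0.0;
    *db = 0.0;
    __enzyme_autodiff((void*)wrapper,
                      enzyme_dup, &a, da,
                      enzyme_dup, &b, db,
                      enzyme_dup, &out, &dout,
                      enzyme_const, &mf);
}
#endif
```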

So this issue can be closed if needed, but I am still curious: how can we pass such member function pointers to Enzyme directly?

Posted by ipcamit, Apr 26 '22 22:04

This remains something we should add syntactic sugar for, or handle more nicely internally. The issue is that member functions introduce an additional argument (whether or not the function is virtual), which gets confused with the function you want to differentiate.
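At the language level, the additional argument is the object that a member function always receives implicitly. `std::invoke` (C++17) spells this out: invoking a member function pointer takes the object as an extra leading argument, which is exactly the role the object plays in the wrapper workaround. The class body below is a hypothetical fill-in. Separately, under the Itanium C++ ABI that Clang uses on Linux and macOS, a member function pointer is itself stored as two words (a function address or vtable offset, plus a `this` adjustment); that may explain the stray `i64 0` in the warning, though this is an interpretation rather than something stated in the thread.

```cpp
#include <cassert>
#include <functional>

// Hypothetical class; the real member function body is elided in the issue.
class MyFunctions {
public:
    double c = 2.0;
    void add_mul_const(double* a, double* b, double* out) { *out = (*a + *b) * c; }
};

// Calls the member function through a member pointer, passing the object
// explicitly as the first argument, like any other parameter.
double invoke_with_object(MyFunctions& mf, double a, double b) {
    void (MyFunctions::*f)(double*, double*, double*) = &MyFunctions::add_mul_const;
    double out = 0.0;
    std::invoke(f, mf, &a, &b, &out);  // equivalent to (mf.*f)(&a, &b, &out)
    return out;
}
```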

For now, you have indeed found the correct workaround, but let's leave this issue open as something to clean up going forward.
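Until such sugar exists, the wrapper does not have to be written by hand for every method. The sketch below uses a hypothetical helper, `MemberThunk`, which is not part of Enzyme: given a member function pointer as a C++17 template argument, it produces an ordinary static function that takes the object as its first parameter. Its address can then be passed to `__enzyme_autodiff` like any free function, with the object marked `enzyme_const` as in the workaround. It only matches non-`const`, non-`noexcept` member functions; other qualifiers would need extra specializations. The class body is a hypothetical fill-in, and the Enzyme call is again guarded by a made-up `USE_ENZYME` macro, with declarations in the style of Enzyme's examples.

```cpp
#include <cassert>
#include <utility>

// Hypothetical class; the real member function body is elided in the issue.
class MyFunctions {
public:
    double c = 2.0;
    void add_mul_const(double* a, double* b, double* out) { *out = (*a + *b) * c; }
};

// Primary template, only declared: specialized below for member function pointers.
template <auto MemFn>
struct MemberThunk;

// Matches a compile-time member function pointer of type R (C::*)(Args...).
template <class C, class R, class... Args, R (C::*MemFn)(Args...)>
struct MemberThunk<MemFn> {
    // Ordinary static function: the object is simply the first parameter.
    static R call(C* obj, Args... args) {
        return (obj->*MemFn)(std::forward<Args>(args)...);
    }
};

#ifdef USE_ENZYME
int enzyme_dup;
int enzyme_const;
void __enzyme_autodiff(void*, ...);

// Reverse mode: the caller zero-initializes *da and *db; dout is seeded with 1.0.
void gradient(MyFunctions& mf, double* a, double* da, double* b, double* db) {
    double out = 0.0, dout = 1.0;
    __enzyme_autodiff((void*)&MemberThunk<&MyFunctions::add_mul_const>::call,
                      enzyme_const, &mf,
                      enzyme_dup, a, da,
                      enzyme_dup, b, db,
                      enzyme_dup, &out, &dout);
}
#endif
```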

Posted by wsmoses, Apr 27 '22 04:04