Enzyme
How to pass a pointer to a class member function to Enzyme?
Consider the following pointer to a class member function:
class MyFunctions{
...
void add_mul_const(double * a, double * b, double * out);
};
void __enzyme_autodiff(
void (MyFunctions::* )(double * , double * , double *),
int , double *, double *,
int, double *, double *,
int, double *, double *);
int main(){
...
void (MyFunctions::* f)(double * , double * , double *) = &MyFunctions::add_mul_const;
// FWD
(mf.*f)(a,b,out);
//Derivative
//__enzyme_autodiff(mf.*f, enzyme_dup, a, da, enzyme_dup, b, db, enzyme_dup, out, dout); ?
// __enzyme_autodiff(f, enzyme_dup, a, da, enzyme_dup, b, db, enzyme_dup, out, dout); ?
}
What is the correct way to pass a pointer to a class member function to Enzyme? Both of my attempts produced errors:
warning: Cannot cast __enzyme_autodiff primal argument 1, found i64 0, type i64 - to arg 0 %class.MyFunctions* [-Wpass-failed=enzyme]
int main(){
^
error: <unknown>:0:0: EnzymeFailure when replacing __enzyme_autodiff calls in main
1 warning and 1 error generated.
and
error: reference to non-static member function must be called
__enzyme_autodiff(mf.*f, enzyme_const, a, enzyme_dup, b, db, enzyme_dup, out, dout);
Also, is there any way I can compute the derivatives of these member functions without having to instantiate the class first?
Taking hints from here, I bypassed the pointer-to-class requirement by writing a wrapper function:
void wrapper(double * a, double * b, double * out, MyFunctions * mf){
mf->add_mul_const(a, b,out);
}
followed by this Enzyme call:
__enzyme_autodiff(wrapper, enzyme_dup, a, da, enzyme_dup, b, db, enzyme_dup, out, dout, enzyme_const, &mf);
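To make the workaround concrete, here is a compilable sketch of the pieces around it. The thread does not show the body of add_mul_const, so out[0] = (a[0] + b[0]) * scale, with a member scale, is an assumption. wrapper_grad_reference is a hypothetical hand-written reverse pass for that assumed body, giving the gradient Enzyme should accumulate into da and db, which is handy as a cross-check. The Enzyme call itself stays in a comment because it only compiles with the Enzyme plugin.

```cpp
// Hypothetical stand-in for the class in the question. The thread does not
// show the body of add_mul_const, so out[0] = (a[0] + b[0]) * scale is assumed.
class MyFunctions {
public:
    double scale = 2.0;  // assumed member state
    void add_mul_const(double* a, double* b, double* out) {
        out[0] = (a[0] + b[0]) * scale;
    }
};

// The workaround from the thread: a free function that takes the object as an
// ordinary trailing pointer argument. Its address is a plain function pointer,
// so it can be the first argument of __enzyme_autodiff (Enzyme plugin required):
//   __enzyme_autodiff(wrapper, enzyme_dup, a, da, enzyme_dup, b, db,
//                     enzyme_dup, out, dout, enzyme_const, &mf);
void wrapper(double* a, double* b, double* out, MyFunctions* mf) {
    mf->add_mul_const(a, b, out);
}

// Hypothetical hand-written reverse pass for the assumed body, for checking the
// Enzyme result: d out[0]/d a[0] = d out[0]/d b[0] = scale, so each input's
// shadow accumulates scale * dout[0].
void wrapper_grad_reference(double* da, double* db, const double* dout,
                            const MyFunctions* mf) {
    da[0] += mf->scale * dout[0];
    db[0] += mf->scale * dout[0];
}
```

With the default scale of 2.0, calling wrapper on a = 1, b = 3 writes 8 to out, and a seed dout = 1 gives da = db = 2 in the reference pass.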
So this issue can be closed if needed, but I am curious how such pointers to member functions can be passed directly.
This remains something we should add syntactic sugar for, or handle more nicely internally. The issue is that member functions introduce an additional argument (whether or not they are virtual), which gets confused with the function you want to differentiate.
For now, you have indeed found the correct workaround, but let's leave this open as something to clean up moving forward.
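Until such sugar exists, the wrapper can be generated generically in user code. Below is a minimal C++17 sketch. member_thunk is a hypothetical helper, not part of Enzyme, and the body of add_mul_const is again assumed. One reading of the first warning, consistent with it but not confirmed in the thread: on the Itanium C++ ABI used by Clang on Linux and macOS, a pointer to member function is a pair of a function pointer (or vtable offset) and a this-adjustment, so passing f by value hands Enzyme integers instead of a single function pointer, and the "found i64 0" matches a zero adjustment being taken as the object argument. member_thunk<&C::m>::call avoids this because it is an ordinary static function whose first parameter is the object pointer.

```cpp
// Hypothetical class mirroring the question; the body of add_mul_const is
// assumed for illustration.
class MyFunctions {
public:
    double scale = 2.0;
    void add_mul_const(double* a, double* b, double* out) {
        out[0] = (a[0] + b[0]) * scale;
    }
};

// Hypothetical helper, not part of Enzyme (requires C++17 for template <auto>).
// member_thunk<&C::m>::call is an ordinary static function: its first
// parameter is the object pointer, followed by m's own parameters.
template <auto MemFn>
struct member_thunk;

template <typename C, typename R, typename... Args, R (C::*MemFn)(Args...)>
struct member_thunk<MemFn> {
    static R call(C* self, Args... args) {
        return (self->*MemFn)(args...);  // forwards to the member function
    }
};

// With the Enzyme plugin, the object is then passed first and marked constant:
//   __enzyme_autodiff(member_thunk<&MyFunctions::add_mul_const>::call,
//                     enzyme_const, &mf, enzyme_dup, a, da,
//                     enzyme_dup, b, db, enzyme_dup, out, dout);
```

Because self comes first here, enzyme_const, &mf moves to the front of the argument list, unlike the thread's wrapper. On the second question: if add_mul_const reads no member state, declaring it static is a simpler option, since a static member function has an ordinary function type and its address can be passed without creating an object.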