catalyst

[MLIR] Specialize active callbacks to their own function

Open · erick-xanadu opened this issue 1 year ago · 1 comment

Context: Enzyme allows custom gradients to be specified for individual functions. To specify custom gradients for callbacks, each callback must be specialized into its own function. For example, instead of emitting the following code:

llvm.call @callback(%identifier, %argc, %retc, %0, %1, ..., %m, %n)

where every callback goes through the single @callback symbol and therefore cannot receive its own custom gradient, specialize each callback to its identifier like so:

llvm.func @callback_123(%arg0, %arg1, ..., %argm, %argn) {
  %identifier = llvm.mlir.constant ...
  %argc = llvm.mlir.constant ...
  %retc = llvm.mlir.constant ...
  llvm.call @callback(%identifier, %argc, %retc, %arg0, %arg1, ..., %argm, %argn)
  llvm.return
}

  // ...
  llvm.call @callback_123(%0, %1, ..., %m, %n)
  // ...

Now a custom gradient can be registered for @callback_123 independently of any other specialization, such as @callback_456.
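To make the motivation concrete, here is a toy Python model of a custom-gradient registry keyed by function symbol. It is only an illustration under stated assumptions: `custom_gradients`, `register_custom_gradient`, and the `grad_of_*` strings are hypothetical and are not Enzyme's or Catalyst's API. It shows that when every callback shares the `callback` symbol, two different gradient rules collide, while specialized symbols such as `callback_123` and `callback_456` can each carry their own rule.

```python
# Hypothetical model: a custom-gradient registry keyed by function symbol name.
# This is NOT Enzyme's API; it only mimics "one custom gradient per function".
custom_gradients: dict = {}


def register_custom_gradient(symbol: str, gradient: str) -> None:
    """Associate a gradient rule with one function symbol (hypothetical helper)."""
    if symbol in custom_gradients and custom_gradients[symbol] != gradient:
        raise ValueError(f"{symbol} already has a different custom gradient")
    custom_gradients[symbol] = gradient


# Before specialization: both callbacks are calls to the same @callback symbol,
# so registering a separate gradient for each one conflicts.
try:
    register_custom_gradient("callback", "grad_of_callback_123")
    register_custom_gradient("callback", "grad_of_callback_456")
    conflict = False
except ValueError:
    conflict = True

# After specialization: each callback has its own symbol and its own rule.
custom_gradients.clear()
register_custom_gradient("callback_123", "grad_of_callback_123")
register_custom_gradient("callback_456", "grad_of_callback_456")

print(conflict, custom_gradients)
```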

Description of the Change:

  • Iterate over all ActiveCallbackOps and create a specialized function for each one. Each ActiveCallbackOp is annotated with its specialized function.
  • When lowering ActiveCallbackOps to LLVM IR, replace each op with a call to its specialized function.
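The two steps above can be sketched on a toy IR in Python. This is a minimal, hypothetical model, not the Catalyst pass: the `ActiveCallbackOp` and `SpecializedFunc` dataclasses, their fields, and the helpers `specialize_callbacks`, `lower`, and `render_wrapper` are invented for illustration, and the emitted strings only approximate the untyped MLIR shown earlier. It also assumes one wrapper per callback identifier, reused when several ops share an identifier.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ActiveCallbackOp:
    """Toy stand-in for an ActiveCallbackOp (hypothetical fields)."""
    identifier: int
    argc: int
    retc: int
    operands: List[str]
    specialized: Optional[str] = None  # annotation attached by step 1


@dataclass
class SpecializedFunc:
    """Toy stand-in for llvm.func @callback_<id>, which forwards to @callback."""
    name: str
    identifier: int
    argc: int
    retc: int
    nargs: int


def specialize_callbacks(ops: List[ActiveCallbackOp]) -> Dict[str, SpecializedFunc]:
    """Step 1: create one specialized function per identifier and annotate each op."""
    funcs: Dict[str, SpecializedFunc] = {}
    for op in ops:
        name = f"callback_{op.identifier}"
        if name not in funcs:  # assumption: reuse the wrapper for a repeated identifier
            funcs[name] = SpecializedFunc(
                name, op.identifier, op.argc, op.retc, len(op.operands)
            )
        op.specialized = name
    return funcs


def lower(op: ActiveCallbackOp) -> str:
    """Step 2: lower an annotated op to a call of its specialized function."""
    if op.specialized is None:
        raise RuntimeError("op is not annotated; run specialize_callbacks first")
    return f"llvm.call @{op.specialized}({', '.join(op.operands)})"


def render_wrapper(f: SpecializedFunc) -> str:
    """Render the wrapper body: materialize constants, then call the generic @callback."""
    args = [f"%arg{i}" for i in range(f.nargs)]
    return "\n".join([
        f"llvm.func @{f.name}({', '.join(args)}) {{",
        f"  %identifier = llvm.mlir.constant({f.identifier})",
        f"  %argc = llvm.mlir.constant({f.argc})",
        f"  %retc = llvm.mlir.constant({f.retc})",
        f"  llvm.call @callback({', '.join(['%identifier', '%argc', '%retc'] + args)})",
        "  llvm.return",
        "}",
    ])


ops = [
    ActiveCallbackOp(123, 2, 1, ["%0", "%1"]),
    ActiveCallbackOp(456, 1, 1, ["%2"]),
    ActiveCallbackOp(123, 2, 1, ["%3", "%4"]),
]
funcs = specialize_callbacks(ops)
lowered = [lower(op) for op in ops]
print(render_wrapper(funcs["callback_456"]))
print("\n".join(lowered))
```

Keeping the annotation separate from the lowering mirrors the split described in the bullets: the rewrite to a direct call only needs the symbol name recorded on the op.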

Related GitHub Issues: PR #706 must be merged first. The first commit in this PR is the squashed version of #706.

[sc-60494]

erick-xanadu, May 13 '24 19:05

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 98.04%. Comparing base (7248c12) to head (00cf68c). Report is 187 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #735   +/-   ##
=======================================
  Coverage   98.04%   98.04%           
=======================================
  Files          69       69           
  Lines        9536     9538    +2     
  Branches      762      763    +1     
=======================================
+ Hits         9350     9352    +2     
  Misses        151      151           
  Partials       35       35           


codecov[bot], May 13 '24 21:05

@dime10 wrote:

Thanks Erick, looks good! Regarding reusing the transformation from results into arguments and memrefs into struct pointers, is that something that would be applicable to this PR?

Not in this PR, but a future PR might change that. I'm considering whether using memrefs directly in the specialization and applying the pointer-to-struct ABI transform afterwards would be better for readability and invariants. I think it would, but I will implement it in a separate PR.

erick-xanadu, May 30 '24 18:05

Closing in favour of #782

erick-xanadu, Jun 03 '24 14:06