How to differentiate StableHLO with Enzyme-JAX from C++?
I'm trying to differentiate a StableHLO mlir::ModuleOp, and I'm lost. I'm starting by trying to interface with Enzyme-JAX, but I've noticed that the only source file exported by bazel is enzymexlamlir-opt.cpp. I did try to use registerStableHLODialectAutoDiffInterface but that's not exported. What's the recommended usage of Enzyme-JAX from C++?
closing as I think I've missed something; may reopen
could you reopen please? All good
It depends on setup, but I think my recommendation would be to emit an enzyme autodiff or forward-diff op; then you can run the enzyme pass, which will replace the op with a call to the derivative
OK thanks. I'll need to do some reading to understand (I'm unfamiliar with "emit", and "pass")
Is this all in Enzyme-JAX repo or is some in Enzyme repo?
partially,
so just like stablehlo has ops for things, there are also enzyme autodiff ops
In text form you can see it, for example, here: https://github.com/EnzymeAD/Enzyme-JAX/blob/fb483c06f697990c60cc3c0bda7fb1d730fca3de/test/lit_tests/grad_sum1d.mlir#L11
You can then run an optimization pass which generates the derivative [creating the code at the bottom in the comment (which is what the test compares against)].
thanks, and are there C++ APIs for this? Looking at the Julia implementation, "emit" seems to defer to the MLIR C++ API.
I'm not sure what running an optimization pass looks like. Is that with PassManager? Which passes are needed for enzyme.autodiff?
btw if you don't have time for all these questions, just say
Yeah there are. For an example of running passes you can look here: https://github.com/EnzymeAD/Enzyme-JAX/blob/dea63960da134128b152c1624d1425048cd9fb3a/src/enzyme_ad/jax/compile_with_xla.cc#L99, which takes a module op and a string naming the passes to run, constructs a pass pipeline, and runs it. You can also just add the passes using the C++ API.
Here you should just need to run the Enzyme pass and the Enzyme remove-unnecessary-ops pass.
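Roughly, that would look something like this (untested sketch, assuming an mlir::ModuleOp `module_op` that already contains the enzyme.autodiff op; createDifferentiatePass appears elsewhere in this thread, but the create-function name for the cleanup pass is a best guess, so check Enzyme/MLIR/Passes/Passes.h):
#include "Enzyme/MLIR/Passes/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/raw_ostream.h"

// Differentiate the enzyme.autodiff ops, then clean up the leftover Enzyme
// helper ops. The cleanup pass's create-function name is a best guess.
mlir::PassManager pm(module_op.getContext());
pm.addPass(mlir::enzyme::createDifferentiatePass());
pm.addPass(mlir::enzyme::createRemoveUnusedEnzymeOpsPass());
if (mlir::failed(pm.run(module_op)))
  llvm::errs() << "Enzyme pipeline failed\n";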
thanks. I did look at that function, but it's not available in the public API of Enzyme-JAX. I might just copy its intention.
I more meant APIs for building enzyme.autodiff
This is very helpful. I can get a lot further with this
Where can I find the relevant dialects and passes? I tried
registry.insert<mlir::enzyme::EnzymeDialect>()
mlir::stablehlo::registerAllDialects(registry)
and
mlir::enzyme::createDifferentiatePass();
but that's not working, and I don't know if that's where the problem lies.
Much of the relevant machinery seems to be private, and there are many, many different functions for dialects and passes, so I'm lost as to what's needed.
BTW once I work out how to do this, I might make a small C++ library for differentiating StableHLO, one that's not specific to any particular frontend
isn't that just the enzymexla-interpreter?
what's that?
There’s a binary enzymexlamlir-opt that takes stablehlo (and other general MLIR) files as input and a list of optimizations (including differentiation) as args, and prints the transformed code to stdout or a file of your choice.
Or @mofeing did you mean the interpreter (which works similarly, but assumes the code is composed of constants and essentially does the transformation and generates the final constant result)?
i mean that the interpreter does what @joelberkeley wants but in a binary manner (take MLIR files and return result of transformations or perform actual interpretation and return values).
like what he wants is something similar to the interpreter, but "librarized"
I'm differentiating a C++ mlir::ModuleOp. The StableHLO is generated at runtime
In any case, can you post your full code and error message? We can try to help see if there's something missing (likely registering one of the MLIR interfaces).
I can paste it, though it's not even a first draft yet, and it's a combination of two different languages, so I'm not sure it will help. I'm just trying to differentiate a tensor<f64> -> tensor<f64> function for now.
computation <- compile xlaBuilder f
stablehlo <- hloModuleProtoToStableHLO !(proto computation)
reg <- mkDialectRegistry
insertEnzymeDialect reg
StableHLO.Dialect.Register.registerAllDialects reg
ctx <- getContext stablehlo
appendDialectRegistry ctx reg
mgr <- mkPassManager ctx
addPass mgr !createDifferentiatePass
enzymeOp <- emitEnzymeADOp stablehlo reg
_ <- run mgr enzymeOp
hloProto <- convertStablehloToHlo stablehlo
computation <- mkXlaComputation hloProto
and
mlir::ModuleOp* emitEnzymeADOp(mlir::ModuleOp& module_op, mlir::DialectRegistry& registry) {
mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);
auto ctx = module_op.getContext();
auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff");
auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx));
state.addTypes({scalarf64});
auto operands = module_op.getOperation()->getOperands(); // complete guess
state.addOperands(mlir::ValueRange(operands));
auto operation = module_op.getOperation(); // complete guess
state.addAttribute("fn", operation->getAttr("sym_name"));
auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active);
state.addAttribute("activity", {activity});
auto ret_activity = mlir::enzyme::ActivityAttr::get(
ctx, mlir::enzyme::Activity::enzyme_activenoneed
);
state.addAttribute("ret_activity", {ret_activity});
auto res = mlir::Operation::create(state);
return new mlir::ModuleOp(res);
}
Error
LLVM ERROR: can't create Attribute 'mlir::enzyme::ActivityAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.
hm yeah I think you're not initializing the Enzyme dialect from the looks of the error message.
This function here is extremely overkill but it should definitely add it: https://github.com/EnzymeAD/Enzyme-JAX/blob/fdcf4018e13d6fb17ffa290672ef1224cb739e7f/src/enzyme_ad/jax/RegistryUtils.cpp#L53 .
Also, considering you're explicitly adding the op itself (and it's not being generated by a different pass or the parser), you may need to explicitly load the dialect in the context.
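For example (minimal sketch; this is the one call that matters here):
// Dialect loading in MLIR is lazy: registering a dialect only makes it
// available, and since nothing here parses enzyme ops, nothing triggers the
// load, so building the attribute by hand hits an unloaded dialect.
module_op.getContext()->loadDialect<mlir::enzyme::EnzymeDialect>();
(mlir::enzyme::EnzymeDialect comes from Enzyme/MLIR/Dialect/Dialect.h.)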
yeah I saw that function, but it's not public. I did try to reproduce it but its contents aren't public either
Feel free to make a PR to make anything public that you need
ok thanks. I'm going to head off now (it's late here). I'll come back to this in a few days. Season's greetings!
I did quite a bit more digging today, but didn't make progress. Here are some notes: I added prepareRegistry, and also calls to registerenzymePasses and regsiterenzymeXLAPasses. The same error persists.
I noticed the call to addAttributes is commented out for EnzymeXLADialect
void EnzymeXLADialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc"
>();
// addAttributes<
// #define GET_ATTRDEF_LIST
// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAAttributes.cpp.inc"
// >();
// addTypes<
// #define GET_TYPEDEF_LIST
// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAOpsTypes.cpp.inc"
// >();
}
Might that be the cause? I noticed mlir::enzyme::ActivityAttr is in the Enzyme repo not Enzyme-JAX. I uncommented those lines but they refer to files that no longer exist.
I'm still keen to make a StableHLO autodiff library that's decoupled from JAX, but I might not have time.
No that’s unrelated (it’s that we recently added an xla dialect for optimizing kernel calls but haven’t yet added custom attributes; if we do, we'll uncomment that).
The code here for autodiff is independent of jax (but is in this repo for ease of building the jax plugin). Perhaps it should be renamed enzymexla; there’s also an ongoing discussion on moving it into stablehlo proper.
This is all setup for MLIR (which is unfortunately not the clearest).
My recommendation: take the enzymexlamlir-opt.cpp binary, copy it to be a library file, and call the pass manager with your op instead of parsing a new one in from a file. That way you’ll have something with all the setup properly done (and it’s easier to remove excess registration once it’s running, imo).
To make things more concrete: take this https://github.com/EnzymeAD/Enzyme-JAX/blob/b6d6563aa3a3050474a4250bf18322f7ebf0b486/src/enzyme_ad/jax/enzymexlamlir-opt.cpp#L124 line out and run the pass on your favorite module, first importing the dialect registry into the module’s context.
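i.e. something along these lines (sketch, assuming `registry` is the fully populated mlir::DialectRegistry from the opt tool's setup and `module_op` is your mlir::ModuleOp):
// Merge everything registered on `registry` (dialects and the autodiff
// interfaces) into the context that owns the module, then run the Enzyme
// pass on the module directly instead of parsing one from a file.
mlir::MLIRContext *ctx = module_op.getContext();
ctx->appendDialectRegistry(registry);
mlir::PassManager pm(ctx);
pm.addPass(mlir::enzyme::createDifferentiatePass());
if (mlir::failed(pm.run(module_op)))
  llvm::errs() << "pass pipeline failed\n";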
Alternatively, if you can share a repo with your whole setup, we can try to take a look and fiddle with MLIR’s setup to make sure things are registered.
ok, I've copy-pasted that function, and put its contents before my code. I still get the same error. I will compose all my stuff into a function that's as self-contained as possible and paste it here
I've edited this code to be more self-contained; it now produces an executable binary instead of a library. I might get round to making it into a git repo.
#include "stablehlo/dialect/Register.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/translate/stablehlo.h"
#include "xla/hlo/builder/lib/math.h"
#include "xla/mlir_hlo/mhlo/IR/register.h"
#include "Enzyme/MLIR/Dialect/Dialect.h"
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Passes/Passes.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "src/enzyme_ad/jax/TransformOps/TransformOps.h"
#include "src/enzyme_ad/jax/RegistryUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "llvm/Support/TargetSelect.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/tests/CheckOps.h"
class MemRefInsider
: public mlir::MemRefElementTypeInterface::FallbackModel<MemRefInsider> {};
template <typename T>
struct PtrElementModel
: public mlir::LLVM::PointerElementTypeInterface::ExternalModel<
PtrElementModel<T>, T> {};
int main() {
// create the stablehlo computation
xla::XlaBuilder builder("root");
  auto xlaScalarf64 = xla::ShapeUtil::MakeScalarShape((xla::PrimitiveType) 12);  // 12 == xla::F64
auto arg = xla::Parameter(&builder, 0, xlaScalarf64, "arg");
auto proto = builder.Build(xla::Square(arg))->proto();
mlir::MLIRContext ctx;
mlir::DialectRegistry registry_;
ctx.appendDialectRegistry(registry_);
mlir::mhlo::registerAllMhloDialects(registry_);
mlir::stablehlo::registerAllDialects(registry_);
auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release();
// stuff copied from enzyme mlir main function
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
registry_.insert<mlir::stablehlo::check::CheckDialect>();
prepareRegistry(registry_);
mlir::registerenzymePasses();
regsiterenzymeXLAPasses();
mlir::registerCSEPass();
mlir::registerConvertAffineToStandardPass();
mlir::registerSCCPPass();
mlir::registerInlinerPass();
mlir::registerCanonicalizerPass();
mlir::registerSymbolDCEPass();
mlir::registerLoopInvariantCodeMotionPass();
mlir::registerConvertSCFToOpenMPPass();
mlir::affine::registerAffinePasses();
mlir::registerReconcileUnrealizedCasts();
registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) {
mlir::LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
mlir::LLVM::LLVMArrayType::attachInterface<MemRefInsider>(*ctx);
mlir::LLVM::LLVMPointerType::attachInterface<MemRefInsider>(*ctx);
mlir::LLVM::LLVMStructType::attachInterface<MemRefInsider>(*ctx);
mlir::MemRefType::attachInterface<PtrElementModel<mlir::MemRefType>>(*ctx);
mlir::LLVM::LLVMStructType::attachInterface<
PtrElementModel<mlir::LLVM::LLVMStructType>>(*ctx);
mlir::LLVM::LLVMPointerType::attachInterface<
PtrElementModel<mlir::LLVM::LLVMPointerType>>(*ctx);
mlir::LLVM::LLVMArrayType::attachInterface<PtrElementModel<mlir::LLVM::LLVMArrayType>>(*ctx);
});
mlir::transform::registerInterpreterPass();
mlir::enzyme::registerGenerateApplyPatternsPass();
mlir::enzyme::registerRemoveTransformPass();
// attempt to create an `enzyme.autodiff` op
auto state = mlir::OperationState(mlir::UnknownLoc::get(&ctx), "enzyme.autodiff");
auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(&ctx));
state.addTypes({scalarf64});
auto operands = module_op_.getOperation()->getOperands(); // complete guess
state.addOperands(mlir::ValueRange(operands));
auto operation = module_op_.getOperation(); // complete guess
state.addAttribute("fn", operation->getAttr("sym_name"));
auto activity = mlir::enzyme::ActivityAttr::get(&ctx, mlir::enzyme::Activity::enzyme_active);
state.addAttribute("activity", {activity});
auto ret_activity = mlir::enzyme::ActivityAttr::get(
&ctx, mlir::enzyme::Activity::enzyme_activenoneed
);
state.addAttribute("ret_activity", {ret_activity});
auto res = mlir::Operation::create(state);
return 0;
}
Here is the bazel target for the above
cc_binary(
name = "example",
linkstatic = True,
srcs = glob(["*.cpp"]),
deps = [
"@xla//xla/hlo/builder:xla_builder",
"@xla//xla/hlo/translate:stablehlo",
"@xla//xla/hlo/builder/lib:math",
"@xla//xla/mlir_hlo:hlo_dialect_registration",
"@enzyme-jax//:everything",
],
visibility = ["//visibility:public"],
)
where I've added this target to Enzyme-JAX
cc_library(
name = "everything",
srcs = [
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"//src/enzyme_ad/jax:RegistryUtils.cpp",
],
hdrs = [
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"//src/enzyme_ad/jax:RegistryUtils.h",
],
visibility = ["//visibility:public"],
deps = [
"@enzyme//:EnzymeMLIR",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:AsyncDialect",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:DLTIDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:NVGPUDialect",
"@llvm-project//mlir:OpenMPDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"@stablehlo//:chlo_ops",
"@stablehlo//stablehlo/tests:check_ops",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:ComplexToLLVM",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:GPUToLLVMIRTranslation",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//llvm:X86AsmParser",
"@llvm-project//llvm:X86CodeGen",
],
)
I'm going to count my chickens and say you can test this using this change, which is basically the above but inserted into this repo
https://github.com/EnzymeAD/Enzyme-JAX/compare/main...joelberkeley:Enzyme-JAX:example?expand=1
with
bazel build //:example
./bazel-bin/example
ok that has linker errors. I may come back to it. The code itself "works", in that it fails at runtime as described.
I've made progress: I added ctx->loadDialect<mlir::enzyme::EnzymeDialect>(), which curiously I can't find any mention of in either the Enzyme or Enzyme-JAX repos. I got it from the MLIR tutorials
Hi, do you support f64?
I think I've got the MLIR right: from this
module @root.3 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<f64>) -> tensor<f64> {
%0 = stablehlo.multiply %arg0, %arg0 : tensor<f64>
return %0 : tensor<f64>
}
}
to this
module @root.3 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @tmp(%arg0: tensor<f64>) -> tensor<f64> {
%0 = stablehlo.multiply %arg0, %arg0 : tensor<f64>
return %0 : tensor<f64>
}
func.func @main(%arg0: tensor<f64>) -> tensor<f64> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f64>
%0 = enzyme.autodiff @tmp(%arg0, %cst) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (tensor<f64>, tensor<f64>) -> tensor<f64>
return %0 : tensor<f64>
}
}
but when I run
mlir::PassManager pm(ctx);
pm.addPass(mlir::enzyme::createDifferentiatePass());
pm.run(module_op_);
I see
type does not have autodifftypeinterface: tensor<f64>
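One plausible cause (a guess, not confirmed in this thread): the AutoDiffTypeInterface for builtin tensor types is attached via the interface registrations that prepareRegistry / registerCoreDialectAutodiffInterfaces add to the DialectRegistry, but in the main() posted above the registry is appended to the context before those calls run, and MLIRContext::appendDialectRegistry copies the registry's contents at call time. A reordered setup would look like:
// Sketch: populate the registry completely *before* attaching it to the
// context; registrations made after appendDialectRegistry (such as the
// autodiff interfaces added by prepareRegistry) never reach the context.
mlir::DialectRegistry registry_;
prepareRegistry(registry_);                      // autodiff interfaces et al.
mlir::mhlo::registerAllMhloDialects(registry_);
mlir::stablehlo::registerAllDialects(registry_);
mlir::MLIRContext ctx;
ctx.appendDialectRegistry(registry_);            // context now sees all of it
ctx.loadDialect<mlir::enzyme::EnzymeDialect>();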