Enzyme.jl
Implementation in DynamicExpressions.jl
I'm close to getting a working implementation of Enzyme.jl-generated derivatives for use in DynamicExpressions.jl (and by extension SymbolicRegression.jl and PySR). I was wondering if you could help me with a couple of things related to this?
1. Disabling warnings
How do I turn off the Recursive type warnings?
┌ Warning: Recursive type
│ T = Node{Float64}
└ @ Enzyme /dev/shm/.julia/packages/GPUCompiler/BxfIW/src/utils.jl:56
I get about 100 of these on the first compile. It hasn't seemed to cause issues (though maybe see 3. below), so I'm comfortable just turning it off for downstream users so that people using the Python frontend aren't confused.
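In the meantime, one blunt workaround (assuming these warnings go through Julia's standard logging machinery, which the ┌ Warning: format suggests) would be to raise the global log level, though that hides all warnings rather than just this one:
using Logging
Logging.disable_logging(Logging.Warn)  # silences Warn-level (and lower) log messages globally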
2. Improving speed
Currently, the hand-rolled autodiff I wrote internally seems to have slightly better performance. The code I wrote isn't bad per se, but because the forward evaluation is so much more optimized, my hope is that Enzyme.jl could beat the internal autodiff by taking advantage of it.
Here is what I have so far. eval_tree_array takes in a matrix of shape (nfeatures, nrows) and returns a vector of shape (nrows,).
function _eval_tree_array!(
    result::AbstractVector{T},
    tree::Node{T},
    X::AbstractMatrix{T},
    operators::OperatorEnum,
) where {T}
    # Mutating wrapper: writes the evaluation into `result` (NaN-filled if evaluation
    # failed), so Enzyme can propagate derivatives through the output buffer.
    out, completed = eval_tree_array(tree, X, operators; turbo=false)
    if !completed
        result .= T(NaN)
    else
        result .= out
    end
    return nothing
end
function enzyme_forward_diff_variables(
    i::Int, tree::Node{T}, X::AbstractMatrix{T}, operators::OperatorEnum
) where {T}
    nfeatures, nrows = size(X)
    result = zeros(T, nrows)
    dresult = ones(T, nrows)
    # Seed a one-hot tangent in the i-th feature row, so forward mode
    # leaves d(result)/d(X[i, :]) in dresult.
    dX = zeros(T, nfeatures, nrows)
    dX[i, :] .= T(1)
    ###############################
    # Enzyme code: ################
    ###############################
    Enzyme.autodiff(
        Forward,
        _eval_tree_array!,
        Duplicated(result, dresult),
        Const(tree),
        Duplicated(X, dX),
        Const(operators),
    )
    ###############################
    ###############################
    return dresult
end
function enzyme_forward_gradient(
    tree::Node{T}, X::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false
) where {T}
    # One forward-mode pass per input feature; row i of the output holds
    # the derivative of the evaluation with respect to feature i.
    output = similar(X)
    nfeatures = size(X, 1)
    for i in 1:nfeatures
        output[i, :] .= enzyme_forward_diff_variables(i, tree, X, operators)
    end
    return output
end
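For reference, with the setup below this is called as:
dX_jac = enzyme_forward_gradient(tree, X, operators; variable=true)  # shape (nfeatures, nrows), here (3, 5000)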
Is there anything I'm doing here that is obviously wrong performance-wise? You can see the full implementation here: https://github.com/SymbolicML/DynamicExpressions.jl/blob/1dcd09d0851ea782a57e79b78b7af7e42aa2b5d2/src/EnzymeInterface.jl
Here are the current benchmarks:
julia> @btime eval_grad_tree_array(ctree, X, operators; variable=true) setup=(ctree=copy_node(tree) * randn());
548.352 μs (118 allocations: 1.99 MiB)
julia> @btime enzyme_forward_gradient(ctree, X, operators; variable=true) setup=(ctree=copy_node(tree) * randn());
692.428 μs (173 allocations: 1.38 MiB)
So the custom gradient code is about 26% faster. Here is the setup code:
using DynamicExpressions
using Enzyme  # needed for autodiff, Forward, Duplicated, Const above
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin], enable_autodiff=true);
# ^ Creates per-operator derivatives with Zygote.jl
x1, x2, x3 = (i->Node(Float64; feature=i)).(1:3);
tree = cos(x1 * 3.2 - 5.8) * 0.2 - 0.5 * x2 * x3 * x3 + 0.9 / (x1 * x1 + 1);
X = randn(3, 5000);
~~3. Gradients with respect to constants in binary trees~~
Edit: Fixed!
~~I'm also trying to implement a gradient with respect to constants in trees, but it seems like it's returning all zeros. (Perhaps this is what the warning was about...)~~
~~Here is what I have so far:~~
function enzyme_forward_diff_constants(
    i::Int,
    tree::Node{T},
    constants::AbstractVector{T},
    X::AbstractMatrix{T},
    operators::OperatorEnum,
) where {T}
    nrows = size(X, 2)
    result = zeros(T, nrows)
    dresult = ones(T, nrows)
    # Seed a one-hot tangent in the i-th constant.
    dconstants = zeros(T, length(constants))
    dconstants[i] = T(1)
    autodiff(
        Forward,
        _eval_tree_array!,
        Duplicated(result, dresult),
        Const(tree),
        Duplicated(constants, dconstants),
        Const(X),
        Const(operators),
    )
    return dresult
end
with the modified evaluation code as:
function _eval_tree_array!(
    result::AbstractVector{T},
    tree::Node{T},
    constants::C,
    X::AbstractMatrix{T},
    operators::OperatorEnum,
) where {T,C}
    # Write the constants into the tree's leaves before evaluating
    # (skipped when `constants === nothing`).
    if !(C <: Nothing)
        set_constants(tree, constants)
    end
    out, completed = eval_tree_array(tree, X, operators; turbo=false)
    if !completed
        result .= T(NaN)
    else
        result .= out
    end
    return nothing
end
This set_constants(tree, constants) might be causing issues for it. If this is the case, I could look at passing a vector of constants directly to eval_tree_array and indexing it internally; however, the mutation of tree is a bit cleaner.
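(For context, the constants vector here is just the flattened vector of constant leaf values; assuming the get_constants counterpart to set_constants in DynamicExpressions.jl, the call would look something like:)
constants = get_constants(tree)  # flattened constant leaf values
dresult = enzyme_forward_diff_constants(1, tree, constants, X, operators)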
Okay, I got 3. to work!! I just had to use Duplicated(tree, copy_node(tree)) instead of Const(tree). I had thought that I only needed to duplicate the constants array, but I suppose since the tree was being changed, it was important to duplicate it as well.
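For reference, the working autodiff call now looks like:
autodiff(
    Forward,
    _eval_tree_array!,
    Duplicated(result, dresult),
    Duplicated(tree, copy_node(tree)),  # was Const(tree)
    Duplicated(constants, dconstants),
    Const(X),
    Const(operators),
)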
So now just 1 and 2 remain. Excited to get this ready for implementation!
Cheers, Miles
Regarding performance, two things to look for:
- Type instabilities. Does anything come up via @code_warntype or similar?
- If you add Enzyme.API.printperf!(true), it will show some performance warnings.
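For example, something along these lines (a minimal sketch, reusing _eval_tree_array!, tree, X, and operators from the snippets above):
using Enzyme, InteractiveUtils
Enzyme.API.printperf!(true)  # enable Enzyme's performance warnings before the autodiff call
result = zeros(Float64, size(X, 2))
@code_warntype _eval_tree_array!(result, tree, X, operators)  # look for Any / non-concrete types in the output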
I think I forgot to send my original reply, sorry for the late response.
- I did a deep dive with Cthulhu and didn't see any type instability issues. There is that "Recursive type" warning but I'm not sure if this is an issue for Enzyme or not.
- I see the following messages with printperf!(true):
SE could not compute loop limit of L93 of fwddiffejulia_deg0_eval_1730lim: ***COULDNOTCOMPUTE*** maxlim: (-1 + %30)<nsw>
failed to deduce type of load %204 = load i64, i64 addrspace(11)* %203, align 8, !dbg !200, !tbaa !41, !alias.scope !47, !noalias !50
failed to deduce type of load %210 = load i64, i64 addrspace(11)* %209, align 8, !dbg !200, !tbaa !41, !alias.scope !47, !noalias !50
failed to deduce type of load %244 = load i64, i64 addrspace(11)* %243, align 8, !dbg !207, !tbaa !41, !alias.scope !47, !noalias !50
SE could not compute loop limit of L93 of fwddiffejulia_deg0_eval_3114lim: ***COULDNOTCOMPUTE*** maxlim: (-1 + %30)<nsw>
failed to deduce type of load %204 = load i64, i64 addrspace(11)* %203, align 8, !dbg !206, !tbaa !48, !alias.scope !54, !noalias !57
failed to deduce type of load %210 = load i64, i64 addrspace(11)* %209, align 8, !dbg !206, !tbaa !48, !alias.scope !54, !noalias !57
failed to deduce type of load %244 = load i64, i64 addrspace(11)* %243, align 8, !dbg !213, !tbaa !48, !alias.scope !54, !noalias !57
The deg0_eval it references is as follows:
function deg0_eval(
    tree::Node{T}, cX::AbstractMatrix{T}
)::Tuple{AbstractVector{T},Bool} where {T<:Number}
    if tree.constant
        return (fill_similar(tree.val::T, cX, axes(cX, 2)), true)
    else
        return (cX[tree.feature, :], true)
    end
end
where fill_similar is defined as
@inline fill_similar(value, array, args...) = fill!(similar(array, args...), value)
One other potential clue is that I get nearly identical performance to my hand-rolled gradient when differentiating with respect to the input data. It is only when differentiating with respect to the leaves of my binary tree that I see a 30% performance hit.
Perhaps this could be because I need to duplicate both the binary tree structure and the array storing the leaf constants?
function enzyme_forward_diff_constants(
    i::Int,
    tree::Node{T},
    constants::AbstractVector{T},
    X::AbstractMatrix{T},
    operators::OperatorEnum,
) where {T}
    nrows = size(X, 2)
    result = zeros(T, nrows)
    dresult = ones(T, nrows)
    dconstants = zeros(T, length(constants))
    dconstants[i] = T(1)
    autodiff(
        Forward,
        _eval_tree_array!,
        Duplicated(result, dresult),
        Duplicated(tree, copy_node(tree)),  # tree must be duplicated since evaluation mutates its constants
        Duplicated(constants, dconstants),
        Const(X),
        Const(operators),
    )
    return dresult
end
(Code in question: https://github.com/SymbolicML/DynamicExpressions.jl/blob/658900e0596700736cccfa8a23715c6fecef5ed9/src/EnzymeInterface.jl)
Not sure why this is, but if I replace the code:
return (cX[tree.feature, :], true)
with the code
out = similar(cX, axes(cX, 2))
feature = tree.feature
@inbounds @simd for j in axes(cX, 2)
    out[j] = cX[feature, j]
end
return (out, true)
the loop limit warning (SE could not compute loop limit of L93 of fwddiffejulia_deg0_eval_3114lim) goes away.
Edit: unchanged performance.
@MilesCranmer is this issue still live?