NNlib.jl
Can we get rid of auto-broadcasting of 0D arrays for activations?
(Ideally, I don't think it should auto-broadcast in the first place.) But just getting rid of 0-D array broadcasting would solve our problem over at https://github.com/EnzymeAD/Reactant.jl/issues/54
https://github.com/FluxML/NNlib.jl/blob/ba29c9044402d40d349379313599ca2621c6d6b2/src/activations.jl#L752-L755
Essentially, Reactant needs to treat scalars as 0-D TracedRArrays, but that causes a recursion loop: the AbstractArray method broadcasts, each broadcast element is again a 0-D TracedRArray, and so the call dispatches straight back to the same method. As expected, the IR ends up with an unreachable (https://github.com/EnzymeAD/Reactant.jl/issues/54#issuecomment-2383975256). As a result, the only way we can support NNlib activations is to manually copy over the code for every activation function.
I know there isn't a general way to "opt out" of broadcasting for 0-D arrays, but we could define the broadcasting methods only for N = 1..10 and hope no one is using an 11+-D tensor.
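A minimal sketch of the N = 1..10 idea, using a hypothetical `myrelu` and a hypothetical toy type `Traced0D` as a stand-in for Reactant's traced 0-D array (neither is real NNlib or Reactant code). Since no method covers `AbstractArray{T,0}`, a 0-D argument dispatches to the generic scalar method instead of broadcasting:

```julia
# Hypothetical stand-in for a traced 0-D array (NOT Reactant's TracedRArray).
struct Traced0D{T} <: AbstractArray{T,0}
    val::T
end
Base.size(::Traced0D) = ()
Base.getindex(x::Traced0D) = x.val
# Scalar-style op defined directly on the 0-D type, as a tracer would provide.
Base.max(a::Traced0D, b::Real) = Traced0D(max(a.val, b))

# Generic "scalar" definition, in the style of NNlib's relu.
myrelu(x) = max(x, 0)

# Auto-broadcast only for N = 1..10, so 0-D arrays fall through to the scalar method.
for N in 1:10
    @eval myrelu(x::AbstractArray{T,$N}) where {T} = myrelu.(x)
end

v = myrelu([-1.0, 2.0])     # broadcasts elementwise
t = myrelu(Traced0D(-3.0))  # hits myrelu(x) = max(x, 0) directly
```

The trade-off is that an ordinary `Array{Float64,0}` would also stop broadcasting and be passed straight to `max`, so this changes behavior for non-Reactant users as well.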
I don't like the auto-broadcast either but here we are.
The built-in opt-out is this function, which perhaps Reactant needs to know how to handle anyway?
julia> Base.Broadcast.broadcast_preserving_zero_d(sin, fill(pi/2))
0-dimensional Array{Float64, 0}:
1.0
julia> Base.Broadcast.broadcast_preserving_zero_d(sin, [0, pi/2])
2-element Vector{Float64}:
0.0
1.0
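To make the contrast with plain broadcasting explicit: dot-broadcasting over a 0-D array unwraps the result to a scalar, while `broadcast_preserving_zero_d` keeps it as a 0-D container. A small check using only `Base`:

```julia
x = fill(pi / 2)  # 0-dimensional Array{Float64, 0}

# Plain dot-broadcast: a 0-D result is unwrapped to the element itself.
y_dot = sin.(x)

# broadcast_preserving_zero_d keeps the 0-D container.
y_zero_d = Base.Broadcast.broadcast_preserving_zero_d(sin, x)

println(typeof(y_dot))
println(typeof(y_zero_d))
```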
This will still cause issues, right? I want the 0-D case to be forwarded to the original scalar call without any broadcasting. For example, for relu I want relu(x) = max(x, 0) to be called instead of relu(x::AbstractArray) = relu.(x).
But what this function is fed is some special fake 0-D array which Reactant invents? My hope is that Reactant could also be taught that broadcast_preserving_zero_d(sin, x::ZeroDimTrackedArray)::ZeroDimTrackedArray, while short-circuiting the present implementation used for Array{T,0}.
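A rough sketch of that short-circuit, treating `ZeroDimTrackedArray` as a hypothetical toy type that simply wraps a value. Reactant's real traced arrays hold no concrete value, so `f(x.val)` here only stands in for whatever op the tracer would record:

```julia
# Hypothetical 0-D wrapper type (illustration only, not a Reactant type).
struct ZeroDimTrackedArray{T} <: AbstractArray{T,0}
    val::T
end
Base.size(::ZeroDimTrackedArray) = ()
Base.getindex(x::ZeroDimTrackedArray) = x.val

# Short-circuit: apply f once and re-wrap the result, bypassing the
# generic broadcast machinery that Base uses for Array{T,0}.
function Base.Broadcast.broadcast_preserving_zero_d(f, x::ZeroDimTrackedArray)
    return ZeroDimTrackedArray(f(x.val))
end

y = Base.Broadcast.broadcast_preserving_zero_d(sin, ZeroDimTrackedArray(pi / 2))
```

Because the method is specialized on the wrapper's own type, it does not affect how `broadcast_preserving_zero_d` behaves for ordinary arrays.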
I found a solution that would do it:
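using NNlib
using Reactant: TracedRArray

# Send 0-D traced arrays to the generic (Tuple{Any}) scalar method of each
# activation instead of the auto-broadcasting AbstractArray method.
# The four activations listed in the tuple are excluded from this loop.
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, :σ))
    @eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
        return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
    end
end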
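The same `invoke` trick in isolation, as a self-contained sketch with hypothetical names (`Traced0D`, `myrelu`) standing in for `TracedRArray{T,0}` and an NNlib activation. `invoke(myrelu, Tuple{Any}, x)` forces dispatch to the untyped scalar method, so the 0-D value never reaches the broadcasting `AbstractArray` method:

```julia
# Hypothetical stand-in for a traced 0-D array (NOT Reactant's TracedRArray).
struct Traced0D{T} <: AbstractArray{T,0}
    val::T
end
Base.size(::Traced0D) = ()
Base.getindex(x::Traced0D) = x.val
# Scalar-style op on the 0-D type, as a tracer would provide.
Base.max(a::Traced0D, b::Real) = Traced0D(max(a.val, b))

myrelu(x) = max(x, 0)                  # generic "scalar" method
myrelu(x::AbstractArray) = myrelu.(x)  # NNlib-style auto-broadcast

# Opt the 0-D type out of broadcasting by calling the Tuple{Any} method directly.
myrelu(x::Traced0D) = invoke(myrelu, Tuple{Any}, x)

t = myrelu(Traced0D(-3.0))
v = myrelu([-1.0, 2.0])
```

The `Traced0D` method is more specific than the `AbstractArray` one, so it wins dispatch without touching the behavior for ordinary arrays.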
what this function is fed is some special fake 0D array which Reactant invents?
Correct. Reactant doesn't have a "Number" type, so we treat 0-D arrays as scalars.
With a rework of Reactant's scalar handling, this is now fixed without using invoke.