ChainRulesCore.jl

Projection for `x::AbstractRange`

Open. mcabbott opened this issue 4 years ago (1 comment).

If we want something like https://github.com/JuliaArrays/FillArrays.jl/pull/153 to project the gradient of a Fill onto a one-dimensional subspace (the constant vectors), then I think we probably want something similar for the gradient of a range, projecting onto a two-dimensional subspace (the affine vectors) parameterised by the endpoints. Before I lose the bit of scrap paper I wrote this on, I think this would look as follows:

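using Statistics: mean  # mean is not in Base

ProjectTo(x::AbstractRange) = ProjectTo{AbstractRange}()

function (project::ProjectTo{AbstractRange})(dx::AbstractVector)
    L = length(dx)
    μ = mean(dx)  # average value, which becomes the midpoint of the range
    # δ = -sum(diff(dx))/2
    δ = sum(Base.splat(-), zip(dx, @view dx[2:end]))/2  # telescopes to (dx[1] - dx[end])/2
    return LinRange(μ + δ, μ - δ, L)  # endpoints μ ± δ
end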

(project::ProjectTo{AbstractRange})(dx::AbstractRange) = dx

Using LinRange allows for zero slope (e.g. for constant dx) and skips the high-precision machinery which StepRangeLen uses to hit endpoints exactly, as I don't think we're concerned about the last digit here. This isn't yet careful about element types etc.
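As a quick check of the proposal, here is a standalone version of the endpoint formula above that does not depend on ChainRulesCore (the name endpoint_project is hypothetical). Since the telescoping sum reduces to (first(dx) - last(dx))/2, it is written that way here. It shows that an affine vector comes back as a LinRange with the same entries, and that a constant vector gives a zero-slope LinRange.

```julia
using Statistics: mean

# Hypothetical standalone copy of the endpoint formula proposed above.
function endpoint_project(dx::AbstractVector)
    L = length(dx)
    μ = mean(dx)
    δ = (first(dx) - last(dx)) / 2   # same as -sum(diff(dx))/2
    return LinRange(μ + δ, μ - δ, L)
end

# An affine vector is reproduced (up to rounding):
r = endpoint_project([1.0, 1.25, 1.5, 1.75, 2.0])

# A constant vector gives a zero-slope LinRange:
c = endpoint_project(fill(3.0, 4))
```

Note that the slope here depends only on the first and last entries of dx.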

mcabbott, Aug 16 '21

The above formula is wrong. Correct versions are here: https://github.com/mcabbott/OddArrays.jl/blob/6c3ef3ab5ebf05c8aa6aa030590456200715be0f/src/OddArrays.jl#L814-L838
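A minimal sketch of an orthogonal (least-squares) projection onto the affine vectors, assuming that is the intended projection; it is not copied from OddArrays.jl, and lsq_project is a hypothetical name. With centred positions t_i = i - (L+1)/2, the slope is b = Σ t_i dx_i / Σ t_i², and the endpoints are μ ∓ b(L-1)/2, where μ is the mean. Because t is orthogonal to the constant vector, the mean and slope can be fitted independently.

```julia
using Statistics: mean

# Hypothetical least-squares projection of dx onto vectors of the form a .+ b .* (1:L).
function lsq_project(dx::AbstractVector{<:Real})
    L = length(dx)
    μ = mean(dx)
    L == 1 && return LinRange(μ, μ, 1)   # slope is undefined for a single point
    t = (1:L) .- (L + 1) / 2             # centred positions, sum(t) == 0
    b = sum(t .* dx) / sum(abs2, t)      # least-squares slope
    h = b * (L - 1) / 2                  # half the rise from first to last entry
    return LinRange(μ - h, μ + h, L)
end

dx = [0.0, 1.0, 0.0, 0.0]
p = lsq_project(dx)
resid = dx .- p   # should be orthogonal to both the constant vector and t
```

For [0.0, 1.0, 0.0, 0.0] this gives the decreasing range 0.4, 0.3, 0.2, 0.1. The endpoint formula in the original post returns a constant range for the same input, because its first and last entries are equal.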

And the motivation is things like this:

julia> gradient(x -> (2 .* x)[1], 0:0.2:1)  # natural
([2.0, 0.0, 0.0, 0.0, 0.0, 0.0],)

julia> gradient(first, LinRange(0,1,5))  # structural
((start = 1.0, stop = nothing, len = nothing, lendiv = nothing),)

julia> gradient(x -> first(2 .* x), LinRange(0,1,5))
ERROR: DimensionMismatch("x and y are of different lengths!")
Stacktrace:
  [1] dot(x::Tangent{Any, NamedTuple{(:start, :stop, :len, :lendiv), Tuple{Float64, ZeroTangent, ZeroTangent, ZeroTangent}}}, y::LinRange{Float64, Int64})

julia> gradient(x -> first(LinRange(x,1,5)), 0)
(1.0,)

julia> gradient(x -> (2 .* LinRange(x,1,5))[1], 0)
ERROR: Need an adjoint for constructor LinRange{Float64, Int64}. Gradient is of type Vector{Float64}

mcabbott, Nov 08 '21