SparseConnectivityTracer.jl
`DataInterpolations` support
Is this the way to go with dispatching on callable structs? Or should it be done per interpolation type separately?
Wow, this PR came impressively quickly after merging #177!
> Is this the way to go with dispatching on callable structs?
I'm afraid this might just overload the constructors.
You can inspect the generated code by omitting the `eval` statement, e.g. calling

```julia
SCT.overload_gradient_1_to_1(:DataInterpolations, AbstractInterpolation)
```

which will most likely return a function similar to

```julia
function DataInterpolations.AbstractInterpolation(t::SparseConnectivityTracer.GradientTracer)
    return SparseConnectivityTracer.gradient_tracer_1_to_1(t, false)
end
```
which unfortunately is not what you want.
I'll try to figure out how to work around it. The simplest solution would be to add a new specialized `overload_gradient_callable_1_to_1` code generation utility.
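To illustrate the idea, here is a sketch of what such a utility could emit. Everything below is hypothetical: `overload_gradient_callable_1_to_1` doesn't exist yet, and the body just mirrors the generated constructor snippet above while targeting instances instead of the type itself.

```julia
# Hypothetical code-generation utility (name from the comment above).
# Unlike overload_gradient_1_to_1, the returned expression defines a call
# overload on *instances* of T rather than overloading T's constructor.
# M is kept for symmetry with the existing utility but is unused here,
# since the spliced type T already carries its module.
function overload_gradient_callable_1_to_1(M::Symbol, T)
    return quote
        function (interp::$T)(t::SparseConnectivityTracer.GradientTracer)
            return SparseConnectivityTracer.gradient_tracer_1_to_1(t, false)
        end
    end
end
```

Evaluating the returned expression (as I assume the existing `eval` loop does for the other utilities) would then define the callable-struct method instead of a constructor method.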
There must be an easier way to get all the interpolation types. Also, for some reason the tests cannot find DataInterpolations.
Edit: never mind, I'll try some different things.
I think your initial attempt was the intuitive way to do it, it just needs some work from our side.
I think my latest approach is the best; this is also the level on which other DataInterpolations extensions operate. It still doesn't work, though, and looking at the generated methods I think I know why:
```julia
[6]  _interpolate(tx::T, ty::T) where T<:SparseConnectivityTracer.GradientTracer
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:119
[7]  _interpolate(tx::T, ty::T) where T<:SparseConnectivityTracer.HessianTracer
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:188
[8]  _interpolate(tx::SparseConnectivityTracer.GradientTracer, ::Real)
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:125
[9]  _interpolate(::Real, ty::SparseConnectivityTracer.GradientTracer)
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:129
[10] _interpolate(dx::D, y::Real) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:163
[11] _interpolate(dx::D, dy::D) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:145
[12] _interpolate(dx::D, dy::D) where {P, T<:SparseConnectivityTracer.HessianTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:227
[13] _interpolate(x::Real, dy::D) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:196
[14] _interpolate(dx::D, y::Real) where {P, T<:SparseConnectivityTracer.HessianTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:256
[15] _interpolate(x::Real, ty::SparseConnectivityTracer.HessianTracer)
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:204
[16] _interpolate(tx::SparseConnectivityTracer.HessianTracer, y::Real)
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:200
[17] _interpolate(x::Real, dy::D) where {P, T<:SparseConnectivityTracer.HessianTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:284
```
The original methods are e.g. `_interpolate(A::LinearInterpolation, t)`, so the first argument is not a real number at all. However, the generated methods assume that it is. I think this is a quite general issue you have to address.
Yes, this kind of function requires writing manual overloads. I should add to the documentation that our "N-to-M operators" assume Real inputs.
Judging by its name, `_interpolate` seems to be an internal function. We should avoid touching these, as they can break without SemVer notice.
Overloading on callable `AbstractInterpolation` structs has the additional advantage of them being "1-to-1 operators" on `Real`s.
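For concreteness, a hand-written version of that overload might look like the following sketch. It covers only the `GradientTracer` case; `HessianTracer` and `Dual` would need analogous methods, and the `false` flag mirrors the generated snippet earlier in the thread.

```julia
import SparseConnectivityTracer as SCT
using DataInterpolations: AbstractInterpolation

# Sketch: treat any interpolant as a 1-to-1 operator on Reals by
# overloading the call on instances instead of the constructor.
function (interp::AbstractInterpolation)(t::SCT.GradientTracer)
    return SCT.gradient_tracer_1_to_1(t, false)
end
```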
For my application I only need scalar-to-scalar, but for many interpolation types the output can also be a vector. Or is that still considered 1-to-1?
Are interpolations generic functions $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$? If so, we'll have to think about what the sparsity patterns should look like and implement the overloads manually.
An overly conservative estimate could be obtained by using tools from src/overloads/arrays.jl:
- if all scalar entries in the input vector interact with each other, the union of their index sets can be computed using `first_order_or` or `second_order_or`
- we can then return a `Fill` vector from FillArrays.jl that has the combined tracer on each entry

(This is how we handle matrix inversion, for example.)
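A minimal sketch of that idea, assuming `first_order_or` is a binary union of two gradient tracers (its exact signature is an assumption here) and using `Fill` from FillArrays.jl:

```julia
using FillArrays: Fill
import SparseConnectivityTracer as SCT

# Sketch: conservative pattern for a vector-valued operation on tracers.
# Every output entry is assumed to depend on every input entry.
function conservative_pattern(xs::AbstractVector{<:SCT.GradientTracer}, m::Integer)
    combined = reduce(SCT.first_order_or, xs)  # union of all input index sets (assumed binary helper)
    return Fill(combined, m)                   # lazy length-m vector sharing the combined tracer
end
```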
DataInterpolations' input is always 1D, but the output can be in any vector space. Most support is for scalar/vector output, though.
Could you rebase this PR on main and move the test cases to the new test/ext folder?
I’ll add the needed overloads for you tomorrow. Any tests you add will help me out greatly.
In general, what I think should be supported (maybe each can be a separate issue/PR):
- scalar to scalar interpolation: that's just calling the interpolation object as a callable struct
- scalar to array interpolation: idem, but with array output (vectors and matrices are also somewhat supported)
- code optimization: all derivatives of constant interpolation are zero, and the second derivative of linear interpolation is zero
- derivatives and integrals: computed as e.g. `derivative(A::LinearInterpolation, t)`, `integral(A::LinearInterpolation, t)` or `integral(A::LinearInterpolation, t1, t2)` (see the usage sketch after this list)
- local sparsity: in certain intervals, certain derivatives of certain interpolation types are 0. Might not be worth the effort though
- derivatives w.r.t. constructor input: quite niche, but sometimes people want to compute derivatives w.r.t. the input data (e.g. for optimization). Here the sparsity pattern depends on the interpolation type.
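For reference, the derivative/integral entry points from the list above are used like this in plain DataInterpolations (output values worked out by hand for a linear interpolant):

```julia
using DataInterpolations

t = [0.0, 1.0, 2.0]
u = [0.0, 1.0, 4.0]
A = LinearInterpolation(u, t)

DataInterpolations.derivative(A, 1.5)    # slope of the second segment: 3.0
DataInterpolations.integral(A, 0.0, 2.0) # trapezoid areas 0.5 + 2.5 = 3.0
```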
Thanks for the list! I think that should all be doable in here.
By the way, what's the philosophy on where extensions are located? DataInterpolations has its own extensions for Symbolics, Zygote etc. For those extensions it's also not that bad if they use DataInterpolations internals.
I'd say we should have them in here for now:
- SCT is less stable and less established than DataInterpolations
- the overloads won't use internals from DataInterpolations but (in this case) will have to use internals from SCT
Codecov Report
Attention: Patch coverage is 87.12871% with 26 lines in your changes missing coverage. Please review.
Project coverage is 90.42%. Comparing base (c7355d3) to head (23ee8ab). Report is 1 commit behind head on main.
```
@@            Coverage Diff             @@
##             main     #178      +/-   ##
==========================================
- Coverage   91.36%   90.42%   -0.94%
==========================================
  Files          41       44       +3
  Lines        1772     2037     +265
==========================================
+ Hits         1619     1842     +223
- Misses        153      195      +42
```
1D interpolations should be working now. Do you have an example for N-dimensional interpolations? I couldn't find any in the DataInterpolations docs.
Vector output:

```julia
using DataInterpolations

t = [1.0, 2.0, 3.0, 4.0]

# Input defined as matrix
u = [1.0 4.0 9.0 16.0; 1.0 4.0 9.0 16.0]
A = QuadraticInterpolation(u, t)
A(2.5) # [6.25, 6.25]

# Input defined as vector of vectors
u = [u[:, i] for i in 1:4]
A = QuadraticInterpolation(u, t)
A(2.5) # [6.25, 6.25]
```
Looks like we can dispatch on the parametric types of each interpolant:
```julia
julia> uvec = [1.0, 2.0, 5.0];

julia> umat = [1.0 4.0 9.0; 1.0 4.0 9.0];

julia> QuadraticInterpolation(uvec, t)
QuadraticInterpolation with 3 points, Forward mode
┌──────┬─────┐
│ time │  u  │
├──────┼─────┤
│ 0.0  │ 1.0 │
│ 1.0  │ 2.0 │
│ 3.0  │ 5.0 │
└──────┴─────┘

julia> typeof(ans)
QuadraticInterpolation{Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}

julia> QuadraticInterpolation(umat, t)
QuadraticInterpolation with 3 points, Forward mode
┌──────┬─────┬─────┐
│ time │ u1  │ u2  │
├──────┼─────┼─────┤
│ 0.0  │ 1.0 │ 1.0 │
│ 1.0  │ 4.0 │ 4.0 │
│ 3.0  │ 9.0 │ 9.0 │
└──────┴─────┴─────┘

julia> typeof(ans)
QuadraticInterpolation{Matrix{Float64}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, Float64}
```
Unfortunately for us, they are all `AbstractInterpolation{eltype(u)}`, not `AbstractInterpolation{typeof(u)}`, so we'll have to duplicate a bit of code.
But at least, it only extends to matrices:
```julia
julia> uten = rand(3, 3, 3);

julia> QuadraticInterpolation(uten, t)
ERROR: MethodError: no method matching munge_data(::Array{Float64, 3}, ::Vector{Float64})
```
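Based on the type parameters above, a dispatch like the following sketch could distinguish scalar from vector output. The helper name `has_scalar_output` is made up for illustration, and the vector-of-vectors storage from the earlier example would need a third method:

```julia
using DataInterpolations: QuadraticInterpolation

# Sketch: the first type parameter is the storage type of u, so Vector{<:Real}
# storage means scalar output and Matrix storage means vector output.
has_scalar_output(::QuadraticInterpolation{<:AbstractVector{<:Real}}) = true
has_scalar_output(::QuadraticInterpolation{<:AbstractMatrix{<:Real}}) = false
```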
The code could have been similar to the initial 20 LOC if DataInterpolations' `AbstractInterpolation` supertype had some indicator of whether outputs are scalar or not.
But N-dim interpolations should be supported now.
Can I ask you for a final review before we merge this @SouthEndMusic?
Urgh, I wasn't aware DataInterpolations had separate methods for AbstractArray inputs...
What about things like this:
```julia
import SparseConnectivityTracer as SCT
using DataInterpolations

method = SCT.TracerLocalSparsityDetector()
t = [0.0, 1.0, 3.0]
f(u) = ConstantInterpolation(u, t)(2.0)
u = [1.0, 2.0, 5.0]
SCT.jacobian_sparsity(f, u, method)
# 1×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1 stored entry:
#  ⋅  1  ⋅
```

As you can see, it already works for some methods, but for other methods it doesn't yet (e.g. `LinearInterpolation`).
Wait, this is an entirely different use-case than what we are currently implementing. You don't want to differentiate through the interpolation, you want to differentiate through the creation of the interpolant?
> As you can see, it already works for some methods
I would say that with the additional overloads on tracers, this is unintended behavior: this PR only returns the input indices of the "interpolation query tracers" `t`, not of the potential "interpolant constructor tracers" `u`. Pushing tracers through interpolants constructed from tracers would require additional methods.
To not bloat this PR, we should overload the constructors to throw a `NotImplementedError` on tracers.
The differentiation of the result of an interpolation w.r.t. the input arguments of its constructor is also a bit more nuanced than what we are currently doing.
> You don't want to differentiate through the interpolation, you want to differentiate through the creation of the interpolant?
As usual, I assume that this works out of the box for local tracers, and that only global tracers are an issue?
> The differentiation of the result of an interpolation w.r.t. the input arguments of its constructor is also a bit more nuanced than what we are currently doing.
IIUC this would require constructing each field of the corresponding Interpolation type (some of which may not even be public API) and filling them with the appropriate tracers in a mathematically sound way. So there's no hope of doing this in a batched way. I think this is out of scope for the current PR, although I understand the appeal.
Of course there's also the crazy option: make `LinearInterpolation(u::Vector{Tracer}, t::Vector{Tracer})` return a new `TracerLinearInterpolation` object without all the caches and SciML shenanigans.
> As usual, I assume that this works out of the box for local tracers, and that only global tracers are an issue?
Maybe before this PR, but with the new overloads we possibly return the wrong patterns.
> IIUC this would require constructing each field of the corresponding Interpolation type (some of which may not even be public API) and filling them with the appropriate tracers in a mathematically sound way.
I'm not even sure whether there is a way for global tracers to return any kind of sparsity. Let's take `LinearInterpolation` as an example: for a given query `t`, the output depends on the closest lower and higher data points. Global tracers don't have the required primal value for such an ordering.
If a dense pattern is the correct solution for global tracers, the solution might not be too complicated: collect all tracers from the data (if there are any) and take a union over all of them. Then take a union with the "query tracers" and return a `Fill` array of the correct size.
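A sketch of that plan, with the same caveat as before that `first_order_or`'s signature is an assumption and `Fill` comes from FillArrays.jl:

```julia
using FillArrays: Fill
import SparseConnectivityTracer as SCT

# Sketch: dense fallback when global tracers enter through the data `u`.
# Union all data tracers, union with the query tracer, fill the output.
function dense_interpolation_pattern(u_tracers, t_tracer, output_length)
    combined = reduce(SCT.first_order_or, u_tracers; init = t_tracer)
    return Fill(combined, output_length)
end
```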
> I'm not even sure whether there is a way for global tracers to return any kind of sparsity
You're right about that: sparsity of an interpolation call w.r.t. the input `u` is only meaningful locally. For global tracers you cannot say anything, or you would have to assume dense dependency on the whole of `u`.
> To not bloat this PR, we should overload the constructors to throw a `NotImplementedError` on tracers.
That is fine by me; this is quite a niche use case. This was also implemented for DataInterpolations + Enzyme only partially and relatively recently, because someone specifically asked for it. I think the most important use cases left are `derivative` and `integral`.
> Global tracers don't have the required primal value for such an ordering.
That depends: if `t` is assumed to be ordered, then we have that neighborness information.
> That depends: if `t` is assumed to be ordered, then we have that neighborness information.
No, the query is entirely independent of the data. We don't know which data points are closest to the query without primals.