ImplicitDifferentiation.jl
Compatibility with scalar functions
At the moment, our implementation fails for scalar functions, even though they should be a special case of vector functions. We need to investigate why.
Note: whenever the function is scalar, use static vectors to store the output
After some reflection, I'm not sure this is a good idea. The "array in, array out" convention is fairly clear. If we allow scalars in the output, why not allow them in the input? It will be a lot of headaches and code complexity for a quality-of-life improvement
If we test that our implementation works for SVectors, we can document that to make it work for scalars, users must wrap the scalar in a static vector of length 1.
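A minimal sketch of that wrapping convention, assuming StaticArrays.jl is available. Here `vector_sqrt` and `scalar_sqrt` are hypothetical names: `vector_sqrt` stands in for any "array in, array out" function and is not part of ImplicitDifferentiation.jl.

```julia
using StaticArrays

# Hypothetical "array in, array out" function (an assumption for illustration):
# it returns y such that y .^ 2 == x, i.e. the elementwise square root.
vector_sqrt(x::AbstractVector) = sqrt.(x)

# Scalar convenience layer: wrap the scalar in a length-1 static vector,
# call the vector function, then unwrap the single output with `only`.
scalar_sqrt(x::Real) = only(vector_sqrt(@SVector [x]))

y = scalar_sqrt(4.0)
```

The wrapping and unwrapping stay on the user's side, so the vector-only function itself needs no scalar-specific code.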
Sounds good to me
Code suggestion from @mohamed82008 in #51:
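using ImplicitDifferentiation
using StaticArrays

# Wraps a vector-based ImplicitFunction so that it can be called on a scalar.
struct ScalarImplicitFunction{F}
    f::F
end

function ScalarImplicitFunction(forward, conditions, linear_solver)
    # Unwrap the length-1 input, call the scalar solver, rewrap the scalar output.
    _forward = x -> begin
        y, z = forward(only(x))
        return @SVector([y]), z
    end
    # Same idea for the conditions: scalars in, length-1 static vector out.
    _conditions = (x, y, z) -> @SVector [conditions(only(x), only(y), z)]
    f = ImplicitFunction(_forward, _conditions, linear_solver)
    return ScalarImplicitFunction(f)
end

function (f::ScalarImplicitFunction)(x::Real; kwargs...)
    return only(f.f(@SVector([x]); kwargs...))
end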
Won't the implementation above require the input to be scalar as well? It'd be great if all of these worked
\begin{aligned}
\mathbb{R} &\rightarrow \mathbb{R}\\
\mathbb{R}^n &\rightarrow \mathbb{R}\\
\mathbb{R} &\rightarrow \mathbb{R}^n\\
\mathbb{R}^n &\rightarrow \mathbb{R}^n
\end{aligned}
I agree that it would be great, but my failed attempt to implement the reverse rule (#51) showed me it's a bit harder than it appears. While I'm grateful for your forward rule implementation, I didn't want to merge only half of the solution, which is why I shelved it for now.
Our first priority with @mohamed82008 is getting the API stabilized by making the byproduct z optional throughout (#57). Once that is done, we'll test the StaticVector solution and thoroughly document it (which should also bring about performance improvements in the linear solve, as you noted). And after that, we'll reconsider how we can add dispatches that go from x to SVector(x) and back :)
I reopened the PR and tagged you on the main hurdle I faced, in case you come up with a good workaround!
Won't the implementation above require the input to be scalar as well? It'd be great if all of these worked
It's possible to modify the implementation to support all these cases if you pass in some sample inputs.
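One way this could look, as a rough sketch rather than the package's actual design: a sample input tells the wrapper whether the user-facing argument is a scalar or a vector, and dispatch handles the conversion in each direction. The helpers `to_vector`, `from_vector` and `vectorize` are hypothetical names that do not exist in ImplicitDifferentiation.jl.

```julia
using StaticArrays

# Hypothetical helpers: scalars become length-1 static vectors, vectors pass through.
to_vector(x::Real) = @SVector [x]
to_vector(x::AbstractVector) = x

# The first argument is a sample of the user-facing type; it decides whether to unwrap.
from_vector(::Real, v::AbstractVector) = only(v)
from_vector(::AbstractVector, v::AbstractVector) = v

# Turn a user function of any of the four shapes into a vector-to-vector function.
vectorize(f, sample_x) = v -> to_vector(f(from_vector(sample_x, v)))

f1(x) = x^2              # R   -> R
f2(x) = sum(abs2, x)     # R^n -> R
f3(x) = [x, 2x]          # R   -> R^n
f4(x) = 2 .* x           # R^n -> R^n

g1 = vectorize(f1, 1.0)
g2 = vectorize(f2, [1.0, 2.0])
g3 = vectorize(f3, 1.0)
g4 = vectorize(f4, [1.0, 2.0])
```

A vector-only core could then be built on `g1` through `g4`, with `from_vector` applied to a sample output to hand the result back in the user's original shape.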
The main branch now supports static arrays. Static arrays of size 1 should be very similar in performance to scalars.
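To illustrate why size-1 static arrays should cost about the same as scalars, here is a small sketch of the linear solve behind the implicit function theorem, written directly with StaticArrays.jl rather than through the package. The conditions c(x, y) = y^2 - x and the variable names are assumptions chosen for illustration.

```julia
using StaticArrays

# Conditions c(x, y) = y^2 - x = 0, so y = sqrt(x). Evaluate at x = 4.
x = @SVector [4.0]
y = sqrt.(x)

# Implicit function theorem: (dc/dy) * (dy/dx) = -(dc/dx).
A = SMatrix{1,1}((2 * only(y),))   # dc/dy = 2y, a 1x1 static matrix
b = @SVector [-1.0]                # dc/dx = -1

# The 1x1 static solve reduces to a single scalar division.
dydx = -only(A \ b)
```

The result agrees with the closed form 1 / (2 * sqrt(x)), and every array in the computation has a size known at compile time.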