ImplicitDifferentiation.jl

Compatibility with scalar functions

Open gdalle opened this issue 3 years ago • 10 comments

At the moment, our implementation fails for scalar functions, even though they should be a special case of vector functions. We need to investigate why.

gdalle avatar Aug 16 '22 20:08 gdalle

Note: whenever the function is scalar, use static vectors to store the output

gdalle avatar May 24 '23 07:05 gdalle

After some reflection, I'm not sure this is a good idea. The "array in, array out" convention is fairly clear. If we allow scalars in the output, why not allow them in the input? It will be a lot of headaches and code complexity for a quality-of-life improvement

gdalle avatar May 27 '23 10:05 gdalle

If we test that our implementation works for SVectors, we can document that to make it work for scalars, users must wrap the scalar in a static vector of length 1.

mohdibntarek avatar May 27 '23 14:05 mohdibntarek

Sounds good to me

gdalle avatar May 30 '23 08:05 gdalle

Code suggestion from @mohamed82008 in #51:

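using StaticArrays: @SVector
using ImplicitDifferentiation: ImplicitFunction

# Wraps a vector-based ImplicitFunction so that it can be called on scalars.
struct ScalarImplicitFunction{F}
  f::F
end
function ScalarImplicitFunction(forward, conditions, linear_solver)
  # Unwrap the length-1 input, call the scalar solver, rewrap the output.
  _forward = x -> begin
    y, z = forward(only(x))
    return @SVector([y]), z
  end
  # Same wrapping for the conditions; the byproduct z passes through unchanged.
  _conditions = (x, y, z) -> @SVector [conditions(only(x), only(y), z)]
  f = ImplicitFunction(_forward, _conditions, linear_solver)
  return ScalarImplicitFunction(f)
end
# Scalar in, scalar out: wrap x, call the inner function, unwrap the result.
function (f::ScalarImplicitFunction)(x::Real; kwargs...)
  return only(f.f(@SVector([x]); kwargs...))
end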

gdalle avatar May 30 '23 09:05 gdalle

Won't the implementation above require the input to be scalar as well? It'd be great if all of these worked:

\begin{aligned}
\mathbb{R} &\rightarrow \mathbb{R}\\
\mathbb{R}^n &\rightarrow \mathbb{R}\\
\mathbb{R} &\rightarrow \mathbb{R}^n\\
\mathbb{R}^n &\rightarrow \mathbb{R}^n
\end{aligned}

baggepinnen avatar May 30 '23 16:05 baggepinnen

I agree that it would be great, but my failed attempt to implement the reverse rule (#51) showed me it's a bit harder than it appears. While I'm grateful for your forward rule implementation, I didn't want to merge only half of the solution, which is why I shelved it for now.

Our first priority with @mohamed82008 is getting the API stabilized by making the byproduct z optional throughout (#57). Once that is done, we'll test the StaticVector solution and thoroughly document it (which should also bring about performance improvements in the linear solve, as you noted). And after that, we'll reconsider how we can add dispatches that go from x to SVector(x) and back :)

gdalle avatar May 30 '23 17:05 gdalle

I reopened the PR and tagged you on the main hurdle I faced, in case you come up with a good workaround!

gdalle avatar May 30 '23 17:05 gdalle

> Won't the implementation above require the input to be scalar as well? It'd be great if all of these worked

It's possible to modify the implementation to support all these cases if you pass in some sample inputs.

mohdibntarek avatar May 30 '23 20:05 mohdibntarek

The main branch now supports static arrays. Static arrays of size 1 should be very similar in performance to scalars.
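
A minimal sketch of the wrap-and-unwrap pattern, assuming only StaticArrays.jl. Here `vector_solver` is a hypothetical stand-in for any array-in, array-out function (it is not part of ImplicitDifferentiation.jl), and `scalar_call` is an illustrative helper name:

```julia
using StaticArrays

# Hypothetical array-in, array-out function standing in for an implicit
# function: returns the elementwise square root of its input.
vector_solver(x::AbstractVector) = sqrt.(x)

# Illustrative helper: wrap the scalar in a length-1 SVector, call the
# vector function, and unwrap the single output entry with `only`.
scalar_call(f, x::Real) = only(f(SVector(x)))

y = scalar_call(vector_solver, 4.0)
```

Because the length is part of the `SVector` type, the wrapper should not need a heap allocation, which is why the cost is expected to stay close to that of a plain scalar call.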

mohdibntarek avatar Jul 30 '23 13:07 mohdibntarek