
API design: prevent unneeded recompilation

Open maurosilber opened this issue 2 years ago • 1 comments

From: https://github.com/hgrecco/numbakit-ode/issues/12#issuecomment-1099136091

The solver shouldn't recompile when only the initial conditions y0 change. The current API hides this, because it accepts both jitted and non-jitted functions. A non-jitted function is jitted inside Solver.__init__, which produces a "different" function as far as numba is concerned and triggers a recompilation of (some) parts of the Solver.

Instead, we could:

  1. raise an error when the passed function is not jitted
  2. document this behaviour somewhere, as it could still be "misused" as Solver(numba.njit(func), ...)

A similar issue happens with functions that depend on parameters (func(t, y, p)). A closure (rhs(t, y)) is generated inside Solver.__init__, and each newly generated closure is considered a different function even if the same parameters are used.

We could provide a helper function to produce the closure outside Solver.__init__:

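import numba

def closure(func, params):
    # Plain Python factory: nopython mode cannot build and return a jitted closure.
    @numba.njit
    def rhs(t, y):
        return func(t, y, params)  # func must itself be jitted
    return rhs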

Later, we could change the internal implementation from func(t, y) to func(t, y, p), passing an array p of parameters where needed.

maurosilber • Apr 19 '22 13:04

I think we could start by raising a warning and then move to raising an error when we change to func(t, y, p).

hgrecco • Apr 19 '22 14:04