numbakit-ode
API design: prevent unneeded recompilation
From: https://github.com/hgrecco/numbakit-ode/issues/12#issuecomment-1099136091
The solver shouldn't need to recompile when only the initial conditions `y0` change. However, the current API obscures when recompilation happens, because it accepts both jitted and non-jitted functions. In the latter case, the function is jitted inside `Solver.__init__`, which produces a "different" function as far as numba is concerned and triggers a recompilation of (some) parts of the Solver.
Instead, we could:
- raise an error when the passed function is not jitted;
- document this behaviour somewhere, as the API could still be "misused" as `Solver(numba.njit(func), ...)`, which jits a new function on every call.
A similar issue happens with functions that depend on parameters (`func(t, y, p)`). A closure (`rhs(t, y)`) is generated inside `Solver.__init__`, and numba considers it a different function even when the same parameters are used.
We could provide a helper function to produce the closure outside `Solver.__init__`:
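```python
import numba

# Plain Python factory: numba cannot return a function object from
# nopython-compiled code, so only the inner closure is jitted.
def closure(func, params):
    # func is a jitted func(t, y, p); params is frozen as a compile-time constant
    @numba.njit
    def rhs(t, y):
        return func(t, y, params)
    return rhs
```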
Later, we could change the internal implementation from `func(t, y)` to `func(t, y, p)`, passing an array `p` of parameters where needed.
I think we could start by raising a warning, and then move to raising an error when we switch to `func(t, y, p)`.