
Allow simulation output to be cached

MichaelClerx opened this issue 4 years ago · 2 comments

There could be use cases for caching at several levels, e.g.

  • CachingForwardModel(model): checks times and parameters against a (limited-size?) dict and returns cached results where possible. Not sure when you'd use this.
  • CachingSingle/MultiSeriesProblem(problem): checks parameters against a (limited-size?) dict and returns cached results where possible. Useful in #1403 .
  • CachingLikelihood and CachingError: check parameters against a (limited-size?) dict and return cached results where possible. Useful if we expect methods to evaluate the same parameter sets multiple times (do we?)

I've written them as wrapper classes here, as that's the minimal-effort solution for developers (no changes to the underlying classes). But we could also consider adding some reusable caching code to the base classes and having all derived classes use it.
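As a rough illustration of the "limited-size dict" idea, here's a minimal standalone sketch (not actual pints code; `MiniProblem` and `CachingProblem` are hypothetical stand-ins) of a wrapper that keeps the most recent results in a bounded `OrderedDict`, evicting the least recently used entry when full:

```python
from collections import OrderedDict


class MiniProblem:
    """Hypothetical stand-in for a pints problem: counts evaluations."""

    def __init__(self):
        self.n_evaluations = 0

    def evaluate(self, parameters):
        self.n_evaluations += 1
        return sum(p * p for p in parameters)


class CachingProblem:
    """Wraps a problem, caching up to ``maxsize`` recent evaluations."""

    def __init__(self, problem, maxsize=32):
        self._problem = problem
        self._maxsize = maxsize
        self._cache = OrderedDict()

    def evaluate(self, parameters):
        # Parameters must be converted to a hashable type (tuple) to
        # serve as a dict key
        key = tuple(parameters)
        if key in self._cache:
            self._cache.move_to_end(key)   # mark as most recently used
            return self._cache[key]
        result = self._problem.evaluate(parameters)
        self._cache[key] = result
        if len(self._cache) > self._maxsize:
            self._cache.popitem(last=False)  # evict least recently used
        return result
```

Calling `evaluate` twice with the same parameters then only evaluates the wrapped problem once.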

MichaelClerx avatar Oct 12 '21 10:10 MichaelClerx

We might only need it for Error and LogPdf?

Here's an example for an error. It uses the lru_cache decorator.

The only complication is that, because ``lru_cache`` needs to use ``x`` as a dict key, ``x`` must be hashable. Numpy arrays are not hashable, so the wrapper converts them to tuples first.
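The hashability point can be checked directly: mutable sequences such as lists (and numpy arrays) raise a ``TypeError`` when hashed, while tuples work fine as dict keys:

```python
# Mutable sequences (lists, numpy arrays) are not hashable, so they
# cannot be used as dict keys or as lru_cache arguments:
try:
    hash([1.0, 2.0])
    hashable = True
except TypeError:
    hashable = False
print(hashable)  # False

# Converting to an (immutable) tuple makes the value hashable:
key = tuple([1.0, 2.0])
cache = {key: 'cached result'}
print(cache[(1.0, 2.0)])  # 'cached result'
```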

import functools

import pints


class CachingError(pints.ErrorMeasure):
    """
    Wraps around another error measure and provides a ``sensitivities()``
    method for use with e.g. ``scipy``.

    All calls to the error and to :meth:`sensitivities` are mapped to
    :meth:`evaluateS1`. To reduce redundant calculations, up to 32 results of
    calling ``evaluateS1`` are kept in a cache.

    Note: Using this wrapper for methods that don't require sensitivities will
    result in very poor performance.
    """

    def __init__(self, error):
        self._e = error

    def __call__(self, x):
        # Convert x to a (hashable) tuple, return only the error value
        return self._both(tuple(x))[0]

    @functools.lru_cache(maxsize=32)
    def _both(self, x):
        # Cached: returns the (error, sensitivities) tuple from the
        # wrapped error measure
        return self._e.evaluateS1(x)

    def evaluateS1(self, x):
        return self._both(tuple(x))

    def n_parameters(self):
        return self._e.n_parameters()

    def sensitivities(self, x):
        # Return only the sensitivities part
        return self._both(tuple(x))[1]
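A side note on the pattern used above: ``functools.lru_cache`` works on instance methods because the cache key is the ``(self, x)`` pair (though this also keeps ``self`` alive for as long as the entry is cached). A standalone toy class, unrelated to pints, shows the effect and how ``cache_info()`` exposes hits and misses:

```python
import functools


class Squarer:
    """Toy class demonstrating lru_cache on an instance method."""

    def __init__(self):
        self.calls = 0

    @functools.lru_cache(maxsize=32)
    def square(self, x):
        # The cache key includes self, so each instance gets its own
        # entries in the shared cache
        self.calls += 1
        return x * x


s = Squarer()
s.square(3.0)
s.square(3.0)          # served from the cache
print(s.calls)         # 1: the method body only ran once
print(Squarer.square.cache_info())
```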

MichaelClerx avatar Aug 13 '22 11:08 MichaelClerx

Just noticed that this is not required for scipy: you can set jac=True instead and have it call evaluateS1 directly :D
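A minimal sketch of what that looks like (assuming scipy and numpy are available; the quadratic objective here is made up for illustration, standing in for an ``evaluateS1``-style function that returns a ``(value, gradient)`` pair):

```python
import numpy as np
from scipy.optimize import minimize


def error_and_sensitivities(x):
    """Returns (error, gradient), like pints' evaluateS1()."""
    r = x - 3.0
    return float(np.sum(r * r)), 2.0 * r


# With jac=True, scipy treats the callable as returning both the
# objective value and its gradient, so no caching wrapper is needed
# to avoid evaluating the model twice per point.
result = minimize(error_and_sensitivities, x0=np.zeros(2), jac=True)
print(result.x)  # close to [3, 3]
```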

MichaelClerx avatar Aug 13 '22 14:08 MichaelClerx