
Disable caching in scoring when using user supplied functions


Whenever the user scores on subsets of the validation data (or on different data altogether), the caching will produce wrong results. Sometimes this is caught automatically because of a dimension mismatch, but that is not guaranteed; when, for example, the underlying model is swapped for one that produces outputs of the same shape but with different values, the problem might not be detected at all.

To prevent this from happening I propose to change the default behavior: when *Scoring is supplied with a user-defined function (or None value) we disable caching as a precaution and possibly notify the user about it. Caching can then be re-activated once the user is sure that this will not break the scoring.

Two problems:

  1. How to detect user-defined functions reliably without too much overhead
  2. How to force caching
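
For reference, the *Scoring callbacks already take a use_caching argument, so a cautious user can opt out per callback today; a minimal sketch (MyModule stands in for any torch module, as elsewhere in this thread):

from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring

def my_score(net, X, y):
    # scores only the first half of the validation data; with caching enabled,
    # net.predict would serve stale cached predictions computed for the full
    # validation set instead of predicting on X[:n]
    n = len(y) // 2
    return (net.predict(X[:n]) == y[:n]).mean()

net = NeuralNetClassifier(
    MyModule,  # placeholder, not defined here
    callbacks=[EpochScoring(my_score, use_caching=False)],
)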

ottonemo avatar Nov 23 '18 10:11 ottonemo

when *Scoring is supplied with a user-defined function (or None value) we disable caching as a precaution and possibly notify the user about it

This sounds a little bit too much for me. But there are certainly edge cases where caching leads to unexpected behavior (e.g. when the scoring function predicts twice, the second time the prediction is empty). Maybe there is a more fundamental way that we can go about this that would solve a couple of existing issues.

BenjaminBossan avatar Nov 24 '18 10:11 BenjaminBossan

when *Scoring is supplied with a user-defined function (or None value) we disable caching as a precaution and possibly notify the user about it

This sounds a little bit too much for me. But there are certainly edge cases where caching leads to unexpected behavior (e.g. when the scoring function predicts twice, the second time the prediction is empty).

Well, the benefit would be that we are on the safe side: all user-defined code would return what the user intended. There are, of course, more subtle approaches, for example issuing a warning when this combination is met, but I think that would be more annoying than helpful.

Maybe there is a more fundamental way that we can go about this that would solve a couple of existing issues.

OK, so at least for caching let's gather some issues; maybe there's a less radical solution available. Related to caching, I can think of the following issues:

  1. predicting twice in a scoring function depletes the cache, leading to an exception
  2. predicting with more data in a scoring function may deplete the cache
  3. predicting with less data in a scoring function breaks with an exception (y_true.shape != y_pred.shape)
  4. predicting with other data in a scoring function will not return the results for this data

Examples:

def oops1(net, X, y): return net.predict(X[X['mask']])            # (3) less data than cached
def oops2(net, X, y): return net.predict(X) + net.predict(X)      # (1) predicts twice
def oops3(net, X, y): return net.predict(X[shuffled_indices(X)])  # (4) other data, silent
def oops4(net, X, y): return net.predict(concat(X, X))            # (2) more data than cached

One thing all these issues have in common is that the input data changed. If there were a cheap way of detecting that the input data has changed (shape or contents), this could be used as a signal to invalidate (or ignore) the cache in the scoring callback.
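
A rough sketch of that idea (none of these names are skorch API; just an illustration): fingerprint the validation data when the cache is filled and only serve cached predictions when the fingerprint still matches.

import hashlib
import numpy as np

def fingerprint(X):
    # cheap-ish check: shape plus a hash of the raw bytes; a real version would
    # also need to handle dicts, datasets and torch tensors
    arr = np.ascontiguousarray(X)
    return arr.shape, hashlib.md5(arr.tobytes()).hexdigest()

class PredictionCache:
    def __init__(self, X_valid, y_pred):
        self.key = fingerprint(X_valid)
        self.y_pred = y_pred

    def lookup(self, X):
        # None means cache miss: the caller should predict for real
        if fingerprint(X) == self.key:
            return self.y_pred
        return None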

But maybe I'm too focused on the current structure and there is a solution waiting a level higher up.

ottonemo avatar Nov 26 '18 08:11 ottonemo

One thing all these issues have in common is that the input data changed.

Almost, since 1. does not necessarily imply that, but the exhaustion problem could be solved by using TeeGenerator.
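
For reference, the idea behind such a TeeGenerator, sketched with itertools.tee (only an illustration of the concept, not skorch's actual implementation):

from itertools import tee

class TeeGen:
    # every iteration re-tees the wrapped generator, so a scorer that calls
    # net.predict twice (issue 1 above) sees the cached values both times
    # instead of an exhausted generator
    def __init__(self, gen):
        self.gen = gen

    def __iter__(self):
        self.gen, fresh = tee(self.gen)
        yield from fresh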

Regarding the changing data, in theory we could hash the data and then compare it, but this of course has a couple of issues, so let's try to find a better solution.

But maybe I'm too focused on the current structure and there is a solution waiting a level higher up.

This will almost certainly mean that we need to think about how we do caching, opening this old can of worms. But I agree that something needs to be done.

benjamin-work avatar Nov 26 '18 09:11 benjamin-work

When do you see a user change the validation data in the middle of an epoch? Specifically, a user would need to change the validation set after the y_pred is calculated, and before *Scoring callbacks are called.

thomasjpfan avatar Nov 26 '18 16:11 thomasjpfan

When do you see a user change the validation data in the middle of an epoch? Specifically, a user would need to change the validation set after the y_pred is calculated, and before *Scoring callbacks are called.

Consider the case where the input data is split in two: one half labelled, one half unlabelled (a semi-supervised setting). Some scores only work with labels, which is why the user needs to do some filtering in the scoring function. Therefore, different data than the cached data is presented during scoring.

For example:

def score(net, X, y):
    labelled = X['is_labelled']
    X_ = multi_indexing(X, labelled)  # skorch helper to index dict/array data
    y_ = y[labelled]
    return accuracy(y_, net.predict(X_))

githubnemo avatar Nov 26 '18 22:11 githubnemo

@githubnemo Thank you for the example.

Almost, since 1. does not necessarily imply that, but the exhaustion problem could be solved by using TeeGenerator.

This works.

I propose the following to help with the "incorrect amount of data issue":

  1. Store the amount of data that is cached.
  2. Patch net.predict to get the amount of data passed in.
  3. Issue a warning if they are not equal.

For the case of shuffled data, I think good documentation telling the user to set use_caching=False would be best.
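
A hypothetical sketch of the length check proposed above (the helper name is made up, not skorch API):

import warnings

def warn_on_cache_size_mismatch(n_cached, X_requested):
    # compare the amount of data the scorer passes to net.predict with the
    # amount that was cached during the validation pass (steps 1-3 above)
    n_requested = len(X_requested)
    if n_requested != n_cached:
        warnings.warn(
            "predict() received {} samples but {} predictions are cached; "
            "the cached results may not correspond to your data "
            "(consider use_caching=False).".format(n_requested, n_cached)
        )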

thomasjpfan avatar Nov 27 '18 02:11 thomasjpfan

I would really like to see a more complete solution than just looking at the amount of data. As mentioned, it is easy to run into errors that are not covered by just looking at the length. If length is covered but other things are not, users could be lulled into a false sense of security when no error is raised.

Another issue that I sometimes come across is that the value of X during scoring depends on caching:

import numpy as np
from sklearn.datasets import make_classification
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring

X, y = make_classification()
X = X.astype(np.float32)

def my_accuracy(model, X, y):
    return (model.predict(X) == y).mean()

net = NeuralNetClassifier(MyModule, callbacks=[EpochScoring(my_accuracy)])
# in the scoring function, X is a Subset, y is a numpy array

net = NeuralNetClassifier(MyModule, callbacks=[EpochScoring(my_accuracy, use_caching=False)])
# in the scoring function, X is a numpy array, y is a numpy array

This can lead to confusion. Of course, one could argue that with caching, I don't need to care about X, but since it is there, I could believe it matters when really it could be anything*. I have already fallen into this trap several times, doing something with X only to notice much later that it didn't matter.

*Not quite anything: net.get_dataset(X) is still called, so X must at least support len().

benjamin-work avatar Nov 28 '18 16:11 benjamin-work

I do not see a good way for us to know when the cache is invalid. We can come up with a bunch of heuristics, but they would never cover all use cases. I think @ottonemo's initial suggestion may be a good way forward:

To prevent this from happening I propose to change the default behavior: when *Scoring is supplied with a user-defined function (or None value) we disable caching as a precaution and possibly notify the user about it.

We could even just change the default of caching to false, and consider caching an "advanced feature".

thomasjpfan avatar Nov 28 '18 16:11 thomasjpfan

We could even just change the default of caching to false, and consider caching an "advanced feature".

I would be reluctant to disable something that probably works just fine for the majority of users, and that is probably an expected feature.

If we had a completely free choice about how to implement this, I would try to make it a very conscious choice for the user to use cached data or not when they write their custom scorers:

def my_accuracy_with_cache(net_without_cache, X, y, y_cached):
    # uses the cached raw predictions that are passed in explicitly
    y_pred = np.argmax(y_cached, axis=-1)
    return accuracy_score(y, y_pred)

def my_accuracy_without_cache(net_without_cache, X, y, y_cached):
    # ignores the cache and predicts from scratch
    y_pred = net_without_cache.predict(X)
    return accuracy_score(y, y_pred)

But then the scoring functions would have a different signature than sklearn scoring functions, and we should try to keep the custom code needed for skorch to a minimum.

PS: It can become more complicated than that, e.g. sometimes I need the y (not y_pred) that comes out of the dataset/dataloader, and sometimes the y as in net.fit(X, y). When caching is on, I get the Subset and can extract the y from it; without caching, it's more complicated.
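
For illustration, extracting that y when caching is on could look roughly like this (a sketch; it assumes X arrives as a torch.utils.data.Subset wrapping a skorch Dataset that stores its targets on .y):

def y_from_subset(subset):
    # pull the dataset-level y back out of the Subset via its indices
    ds, idx = subset.dataset, subset.indices
    return ds.y[idx]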

benjamin-work avatar Nov 29 '18 16:11 benjamin-work

I admit the following is a fairly strange proposal: while inside the scoring functions, we could temporarily replace the predict function with a NeuralNet.predict(..., cache=False) variant, to allow scorers to control when they want to avoid the cache.

thomasjpfan avatar Dec 09 '18 17:12 thomasjpfan

Could you elaborate on that?

benjamin-work avatar Dec 10 '18 13:12 benjamin-work

On second thought, patching out the predict function is not ideal. We need a way to allow a user to disable the cache in the scoring function. Maybe another context manager:

def score_something(net, X, y):
    with skorch.helper.disable_cache(net):
        y_pred = net.predict(X)
    ...
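
Such a helper does not exist yet; a sketch of what it could look like (the _disable_cache flag is invented here, and the scoring callbacks would actually have to check it):

from contextlib import contextmanager

@contextmanager
def disable_cache(net):
    # temporarily flag the net so that the scoring machinery bypasses its
    # prediction cache (sketch of the proposed skorch.helper.disable_cache)
    old = getattr(net, '_disable_cache', False)
    net._disable_cache = True
    try:
        yield net
    finally:
        net._disable_cache = old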

thomasjpfan avatar Dec 10 '18 23:12 thomasjpfan

We started out with the premise that it is easy to make an error with caching enabled without noticing it. If users have to use a context manager, they are already aware of caching and might as well disable it on the callback directly.

As said earlier, it would be nice if we could force the user to make a conscious choice about whether to use caching or not, but my proposal involves either breaking the API or having a bunch of custom code. Also, we shouldn't disrupt code that currently just works, so I would oppose changing the default.

Another issue for me is that the data you get in your scoring callback changes depending on whether caching is on or not, which is not at all ideal.

benjamin-work avatar Dec 11 '18 14:12 benjamin-work