skorch
Disable caching in scoring when using user-supplied functions
Whenever the user uses subsets of the validation data (or different data altogether), the caching will produce wrong results. Sometimes this is caught automatically because of a dimension mismatch, but that is not guaranteed; when, for example, the underlying model is swapped for one that produces data of the same shape but with different values, the problem might go undetected.
To prevent this from happening I propose to change the default behavior: when *Scoring is supplied with a user-defined function (or None value) we disable caching as a precaution and possibly notify the user about it. Caching can then be re-activated once the user is sure that this will not break the scoring.
Two problems:
- How to detect user-defined functions reliably without too much overhead (a possible check is sketched below)
- How to force caching
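For the first problem, a cheap heuristic could be to treat anything that is not a string metric name as user-defined. A rough sketch (the helper name is made up; None handling follows the proposal above):

def is_user_defined_scoring(scoring):
    # Illustrative only: built-in sklearn scorers are requested by a string
    # name such as "accuracy"; a callable or None is treated as user-supplied
    # and would disable caching by default under this proposal.
    return not isinstance(scoring, str)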
> when *Scoring is supplied with a user-defined function (or None value) we disable caching as a precaution and possibly notify the user about it
This sounds a little bit too much for me. But there are certainly edge cases where caching leads to unexpected behavior (e.g. when the scoring function predicts twice, the second time the prediction is empty). Maybe there is a more fundamental way that we can go about this that would solve a couple of existing issues.
> when *Scoring is supplied with a user-defined function (or None value) we disable caching as a precaution and possibly notify the user about it
>
> This sounds a little bit too much for me. But there are certainly edge cases where caching leads to unexpected behavior (e.g. when the scoring function predicts twice, the second time the prediction is empty).
Well, the benefit would be that we are on the safe side in that all user-defined code will return what the user intended. There are, of course, more subtle ways, for example issuing a warning when this combination is met, but that would be more annoying than helpful, I think.
> Maybe there is a more fundamental way that we can go about this that would solve a couple of existing issues.
OK, so at least for caching let's gather some issues, maybe there's a less radical solution available. Related to caching I can think of the following issues:
- predicting twice in a scoring function depletes the cache, leading to an exception
- predicting with more data in a scoring function may deplete the cache
- predicting with less data in a scoring function breaks with an exception (y_true.shape != y_pred.shape)
- predicting with other data in a scoring function will not return the results for this data
Examples:
def oops1(net, X, y): return net.predict(X[X['mask']])            # (3) less data
def oops2(net, X, y): return net.predict(X) + net.predict(X)      # (1) predicts twice
def oops3(net, X, y): return net.predict(X[shuffled_indices(X)])  # (4) other data, silent
def oops4(net, X, y): return net.predict(concat(X, X))            # (2) more data
One thing all these issues have in common is that the input data changed. If there was a cheap way of detecting that the input data changed (shape or contents) then this could be used as a signal to invalidate (or ignore) the cache in the scoring callback.
But maybe I'm too focused on the current structure and there is a solution waiting a level higher up.
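To illustrate what I mean by "cheap", something along these lines could serve as a change signal (purely a sketch, names made up, and it only handles array-like X): remember a fingerprint of the data the cache was filled with and ignore the cache when the fingerprint of the data passed to predict differs.

import hashlib
import numpy as np

def cheap_fingerprint(X, n_rows=16):
    # Illustrative only: length and shape plus a hash of a few evenly spaced
    # rows; a mismatch with the fingerprint taken when the cache was filled
    # would be the signal to invalidate (or ignore) the cache.
    X = np.asarray(X)
    idx = np.linspace(0, len(X) - 1, num=min(n_rows, len(X)), dtype=int)
    digest = hashlib.blake2b(
        np.ascontiguousarray(X[idx]).tobytes(), digest_size=8).hexdigest()
    return len(X), X.shape[1:], digest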
One thing all these issues have in common is that the input data changed.
Almost, since (1) does not necessarily imply that, but the exhaustion problem could be solved by using TeeGenerator.
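The idea behind TeeGenerator is essentially itertools.tee: every iteration over the cached predictions gets its own copy instead of exhausting a shared generator. A minimal sketch of the concept (not the actual skorch implementation):

from itertools import tee

class TeeGen:
    # Conceptual sketch: tee off a fresh copy on every iteration, so a scoring
    # function that calls predict twice does not hit an empty generator the
    # second time.
    def __init__(self, gen):
        self.gen = gen

    def __iter__(self):
        self.gen, copy = tee(self.gen)
        return copy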
Regarding the changing data, in theory we could hash the data and then compare it but this has of course a couple of issues, so let's try to find a better solution.
But maybe I'm too focused on the current structure and there is a solution waiting a level higher up.
This will almost certainly mean that we need to think about how we do caching, opening this old can of worms. But I agree that something needs to be done.
When do you see a user change the validation data in the middle of an epoch? Specifically, a user would need to change the validation set after the y_pred is calculated, and before *Scoring callbacks are called.
> When do you see a user change the validation data in the middle of an epoch? Specifically, a user would need to change the validation set after the y_pred is calculated, and before *Scoring callbacks are called.
Consider the case where the input data is two-fold, one half supervised, one half unsupervised (a semi-supervised setting). Some scores only work with labels, which is why the user needs to do some filtering in the scoring function. Therefore, data different from the cached data is presented during scoring.
For example:
def score(self, net, X, y):
    labelled = X['is_labelled']
    X_ = multi_indexing(X, labelled)
    y_ = y[labelled]
    return accuracy(y_, net.predict(X_))
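To make this concrete, here is roughly how it could be wired up as a standalone scoring function with caching turned off on the callback (a sketch only; accuracy_score, the callback name, and the import paths are my additions, and the net/module setup is omitted):

from sklearn.metrics import accuracy_score
from skorch.callbacks import EpochScoring
from skorch.utils import multi_indexing

def semi_supervised_accuracy(net, X, y):
    # Score only the labelled part of the data; since this is not the data
    # that was cached, caching is switched off on the callback below.
    labelled = X['is_labelled']
    X_ = multi_indexing(X, labelled)
    y_ = y[labelled]
    return accuracy_score(y_, net.predict(X_))

callbacks = [EpochScoring(semi_supervised_accuracy, use_caching=False, name='acc_labelled')]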
@githubnemo Thank you for the example.
> Almost, since (1) does not necessarily imply that, but the exhaustion problem could be solved by using TeeGenerator.
This works.
I propose the following to help with the "incorrect amount of data" issue:
- Store the amount of data that is cached.
- Patch net.predict to get the amount of data passed in.
- Issue a warning if they are not equal.
For the case of shuffled data, I think only good documentation, to get a user to set use_caching=False, would be best.
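A minimal sketch of the length check above (purely illustrative; the real change would live inside the scoring callback, and len(X) is only a stand-in for a proper length check):

import warnings
from functools import wraps

def warn_on_length_mismatch(net, n_cached):
    # Illustrative: wrap net.predict so that calling it with a different
    # number of samples than were cached triggers a warning.
    original_predict = net.predict

    @wraps(original_predict)
    def predict(X, *args, **kwargs):
        if len(X) != n_cached:
            warnings.warn(
                "predict() called with {} samples but {} predictions were "
                "cached; cached scores may be wrong".format(len(X), n_cached))
        return original_predict(X, *args, **kwargs)

    net.predict = predict
    return net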
I would really like to see a more complete solution than just looking at the amount of data. As mentioned, it is easy to run into errors that are not covered by just looking at the length. If length is covered but other things are not, users could be lulled into a false sense of security when no error is raised.
Another issue that I sometimes come across is that the value of X during scoring depends on caching:
import numpy as np
from sklearn.datasets import make_classification
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring

X, y = make_classification()
X = X.astype(np.float32)

def my_accuracy(model, X, y):
    return (model.predict(X) == y).mean()

net = NeuralNetClassifier(MyModule, callbacks=[EpochScoring(my_accuracy)])
# in the scoring function, X is a Subset, y is a numpy array
net = NeuralNetClassifier(MyModule, callbacks=[EpochScoring(my_accuracy, use_caching=False)])
# in the scoring function, X is a numpy array, y is a numpy array
This can lead to confusion. Of course, one could argue that with caching, I don't need to care about X, but since it is there I could believe it matters when in reality it could be anything*. Several times already I have fallen into the trap of doing something with X only to notice much later that it didn't matter.
*not quite, since net.get_dataset(X) is called, so it must have a len.
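For what it's worth, this is roughly the workaround I end up with to see the same X with and without caching (a sketch only; it assumes that with caching on, X is a torch Subset wrapping a skorch Dataset that stores the original data, and the import paths are from memory):

from torch.utils.data import Subset
from skorch.utils import multi_indexing

def unwrap_X(X):
    # Workaround sketch: with caching on, the scoring callback hands over a
    # Subset; pull the original data back out so the scoring function sees
    # the same thing regardless of the caching setting.
    if isinstance(X, Subset):
        return multi_indexing(X.dataset.X, X.indices)
    return X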
I do not see a good way for us to know when the cache is invalid. We can come up with a bunch of heuristics, but they would never cover all use cases. I think @ottonemo's initial suggestion may be a good way forward:
> To prevent this from happening I propose to change the default behavior: when *Scoring is supplied with a user-defined function (or None value) we disable caching as a precaution and possibly notify the user about it.
We could even just change the default of caching to False, and consider caching an "advanced feature".
> We could even just change the default of caching to False, and consider caching an "advanced feature".
I would be reluctant to disable something that probably works just fine for the majority of users, and that is probably an expected feature.
If we had a completely free choice about how to implement this, I would try to make it a very conscious choice for the user to use cached data or not when they write their custom scorers:
def my_accuracy_with_cache(net_without_cache, X, y, y_cached):
    y_pred = np.argmax(y_cached, axis=1)
    return accuracy_score(y, y_pred)

def my_accuracy_without_cache(net_without_cache, X, y, y_cached):
    y_pred = net_without_cache.predict(X)
    return accuracy_score(y, y_pred)
But then the scoring functions have a different signature than sklearn scoring functions, and we should try to keep the custom code needed for skorch to a minimum.
PS: It can become more complicated than that, e.g. sometimes I need the y (not y_pred) that comes out of the dataset/dataloader, sometimes the y as in net.fit(X, y). When caching is on, I get the Subset and can extract the y from it; without caching it's more complicated.
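Concretely, with caching on I can do something like this to get the y that comes out of the dataset (a sketch, assuming the Subset wraps a skorch Dataset that exposes the targets as a y attribute); without caching there is no Subset and this path is not available:

import numpy as np

def y_from_subset(subset):
    # Sketch only: `subset` is the torch Subset passed to the scoring function
    # when caching is on; its underlying skorch Dataset is assumed to store
    # the targets on a `y` attribute.
    return np.asarray(subset.dataset.y)[subset.indices]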
I admit the following is a fairly strange proposal: while in the scoring function, we could temporarily replace the predict function with NeuralNet.predict(..., cache=False) to allow scorers to control when they want to avoid the cache.
Could you elaborate on that?
On second thought, patching out the predict function is not ideal. We need a way to allow a user to disable the cache in the scoring function. Maybe another context manager:
def score_something(net, X, y):
    with skorch.helper.disable_cache(net):
        y_pred = net.predict(X)
    ...
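skorch.helper.disable_cache does not exist yet; if we went down this route, one way it could work is to temporarily undo the instance-level patch applied by the scoring callback. A rough sketch, assuming the cache is implemented by shadowing forward_iter directly on the net instance:

from contextlib import contextmanager

@contextmanager
def disable_cache(net):
    # Hypothetical helper: if the scoring callback caches by setting a patched
    # forward_iter on the instance (shadowing the class-level method), popping
    # it from the instance dict makes predictions inside the block use the
    # real data again; the patch is restored afterwards.
    patched = net.__dict__.pop('forward_iter', None)
    try:
        yield net
    finally:
        if patched is not None:
            net.forward_iter = patched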
We started out with the premise that it is easy to make an error with caching enabled and not noticing it. If users have to use a context manager, they are already aware of caching and might as well disable it on the callback directly.
As said earlier, it would be nice if we could force the user to make a conscious choice about whether to use caching or not, but my proposal involves breaking the API or having a bunch of custom code. Also, we shouldn't disrupt code that currently just works, so I would oppose changing the default.
Another issue for me is that the data you get in your scoring callback changes depending on whether caching is on or not, which is not at all ideal.