
Allow passing '**predict_params' to the 'predict' method, as is already possible for 'fit'

corradomio opened this issue 1 year ago • 2 comments

It would be useful to be able to pass custom 'predict_params' to the 'predict' method, just as 'fit' already accepts '**fit_params'.
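For context, a minimal sketch of the current asymmetry (MyModule, add_noise, and the data below are made up for illustration): keyword arguments passed to fit are forwarded all the way to the module's forward, while predict accepts only X.

import numpy as np
import torch
import skorch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(20, 1)

    def forward(self, X, add_noise=False):
        out = self.lin(X)
        if add_noise:
            out = out + 0.01 * torch.randn_like(out)
        return out

X = np.random.randn(64, 20).astype(np.float32)
y = np.random.randn(64, 1).astype(np.float32)

net = skorch.NeuralNetRegressor(MyModule, max_epochs=2, train_split=None)
net.fit(X, y, add_noise=True)  # **fit_params reach MyModule.forward
y_pred = net.predict(X)        # works, but there is no way to pass add_noise here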

Starting from 'predict', only four of skorch's methods need to change. This is a simple implementation I wrote:

import numpy as np
import torch
import skorch
from skorch.dataset import unpack_data
from skorch.utils import to_device, to_numpy


class NeuralNetRegressor(skorch.NeuralNetRegressor):
    def __init__(
            self,
            module,
            *args,
            criterion=torch.nn.MSELoss,
            **kwargs
    ):
        super(NeuralNetRegressor, self).__init__(
            module,
            *args,
            criterion=criterion,
            **kwargs
        )

    def predict(self, X, **predict_params):
        # forward predict_params instead of calling predict_proba(X) only
        return self.predict_proba(X, **predict_params)

    def predict_proba(self, X, **predict_params):
        nonlin = self._get_predict_nonlinearity()
        y_probas = []
        # pass predict_params down to forward_iter
        for yp in self.forward_iter(X, training=False, **predict_params):
            yp = yp[0] if isinstance(yp, tuple) else yp
            yp = nonlin(yp)
            y_probas.append(to_numpy(yp))
        y_proba = np.concatenate(y_probas, 0)
        return y_proba

    def forward_iter(self, X, training=False, device='cpu', **params):
        dataset = self.get_dataset(X)
        iterator = self.get_iterator(dataset, training=training)
        for batch in iterator:
            # pass the remaining params down to evaluation_step
            yp = self.evaluation_step(batch, training=training, **params)
            yield to_device(yp, device=device)

    def evaluation_step(self, batch, training=False, **eval_params):
        self.check_is_fitted()
        Xi, _ = unpack_data(batch)
        with torch.set_grad_enabled(training):
            self._set_training(training)
            # eval_params end up as extra keyword arguments to module_.forward
            return self.infer(Xi, **eval_params)
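
To illustrate, a hypothetical call with the subclass above, reusing the made-up MyModule, X and y from the earlier sketch; add_noise travels through predict -> predict_proba -> forward_iter -> evaluation_step -> infer and ends up in MyModule.forward.

net = NeuralNetRegressor(MyModule, max_epochs=2, train_split=None)
net.fit(X, y)
# the extra keyword argument now reaches MyModule.forward at inference time
y_pred = net.predict(X, add_noise=True)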

corradomio • Jan 26 '24

Hey! Thanks for the suggestion.

Indeed, this is something that is not symmetrical with .fit(). There are sklearn classifiers and transformers that support additional parameters, so I don't see an immediate reason against it.
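
For reference, a small example of an existing sklearn estimator whose predict already accepts extra keyword arguments (GaussianProcessRegressor here; the pattern, not the particular estimator, is the point):

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

X = np.random.randn(50, 3)
y = np.random.randn(50)

gpr = GaussianProcessRegressor().fit(X, y)
# return_std is a predict-time parameter, analogous to **predict_params
y_mean, y_std = gpr.predict(X, return_std=True)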

Are you interested in working on this and submitting a PR? :)

githubnemo • Jan 27 '24

I believe this can already be achieved by making X a dict and using a proper collate function.
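
A minimal sketch of that idea (DictModule, scale, and the data are made up): skorch passes the keys of a dict X as keyword arguments to module.forward via infer, so per-sample extra inputs can already reach the module at predict time; a custom collate_fn would only be needed for values that are not per-sample arrays.

import numpy as np
import torch
from skorch import NeuralNetRegressor

class DictModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(20, 1)

    # the keys of the dict X arrive here as keyword arguments
    def forward(self, X, scale):
        return self.lin(X) * scale

n = 64
X = {
    'X': np.random.randn(n, 20).astype(np.float32),
    # per-sample "parameter"; a true scalar would need a custom collate_fn
    'scale': np.full((n, 1), 2.0, dtype=np.float32),
}
y = np.random.randn(n, 1).astype(np.float32)

net = NeuralNetRegressor(DictModule, max_epochs=2, train_split=None)
net.fit(X, y)
y_pred = net.predict({'X': X['X'], 'scale': np.full((n, 1), 0.5, dtype=np.float32)})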

ramonamezquita • Mar 12 '24