skorch
Allow passing `**predict_params` to `predict`, as with `fit`
It would be useful to be able to pass custom predict params to the `predict` method, just as `fit` already accepts `**fit_params`.
Only four skorch methods need to change, starting from `predict`: `predict`, `predict_proba`, `forward_iter`, and `evaluation_step`. Here is a simple implementation I wrote as a `NeuralNetRegressor` subclass:
```python
import numpy as np
import torch
import skorch
from skorch.dataset import unpack_data
from skorch.utils import to_device, to_numpy


class NeuralNetRegressor(skorch.NeuralNetRegressor):
    def __init__(self, module, *args, criterion=torch.nn.MSELoss, **kwargs):
        super().__init__(module, *args, criterion=criterion, **kwargs)

    def predict(self, X, **predict_params):
        # Forward the extra keyword arguments down the prediction call chain.
        return self.predict_proba(X, **predict_params)

    def predict_proba(self, X, **predict_params):
        nonlin = self._get_predict_nonlinearity()
        y_probas = []
        for yp in self.forward_iter(X, training=False, **predict_params):
            # If the module returns a tuple, only its first element is used.
            yp = yp[0] if isinstance(yp, tuple) else yp
            yp = nonlin(yp)
            y_probas.append(to_numpy(yp))
        return np.concatenate(y_probas, 0)

    def forward_iter(self, X, training=False, device='cpu', **params):
        dataset = self.get_dataset(X)
        iterator = self.get_iterator(dataset, training=training)
        for batch in iterator:
            # The same params are passed along for every batch.
            yp = self.evaluation_step(batch, training=training, **params)
            yield to_device(yp, device=device)

    def evaluation_step(self, batch, training=False, **eval_params):
        self.check_is_fitted()
        Xi, _ = unpack_data(batch)
        with torch.set_grad_enabled(training):
            self._set_training(training)
            # infer() hands the keyword arguments on to the module's forward().
            return self.infer(Xi, **eval_params)
```
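To make the flow concrete, here is a dependency-free sketch (pure Python, no torch or skorch) of the call chain that the patch modifies. `ToyNet` and `scaled_identity` are hypothetical stand-ins, not skorch code: batching is a plain list slice, and there is no device handling, nonlinearity, or gradient control. It only illustrates how keyword arguments given to `predict` reach the module's forward call for every batch.

```python
from typing import Callable, Iterator, List, Sequence


class ToyNet:
    """Hypothetical, dependency-free stand-in for the patched skorch call chain.

    Not skorch code: it only mirrors predict -> predict_proba -> forward_iter
    -> evaluation_step -> infer to show where **predict_params travel.
    """

    def __init__(self, module: Callable[..., List[float]], batch_size: int = 2):
        self.module = module
        self.batch_size = batch_size

    def infer(self, Xi: Sequence[float], **params) -> List[float]:
        # The extra keyword arguments finally reach the module's forward call.
        return self.module(Xi, **params)

    def evaluation_step(self, batch: Sequence[float], **params) -> List[float]:
        return self.infer(batch, **params)

    def forward_iter(self, X: Sequence[float], **params) -> Iterator[List[float]]:
        # Split X into mini-batches, forwarding the same params to every batch.
        for start in range(0, len(X), self.batch_size):
            yield self.evaluation_step(X[start:start + self.batch_size], **params)

    def predict_proba(self, X: Sequence[float], **predict_params) -> List[float]:
        out: List[float] = []
        for yp in self.forward_iter(X, **predict_params):
            out.extend(yp)
        return out

    def predict(self, X: Sequence[float], **predict_params) -> List[float]:
        return self.predict_proba(X, **predict_params)


def scaled_identity(xs: Sequence[float], scale: float = 1.0) -> List[float]:
    # Hypothetical "module" whose forward accepts an extra keyword argument.
    return [x * scale for x in xs]


net = ToyNet(scaled_identity)
print(net.predict([1.0, 2.0, 3.0]))              # [1.0, 2.0, 3.0]
print(net.predict([1.0, 2.0, 3.0], scale=10.0))  # [10.0, 20.0, 30.0]
```

Calling `predict` without extra arguments keeps the module's default, so existing callers would be unaffected.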
Hey! Thanks for the suggestion.
Indeed, this is not symmetrical with `.fit()`.
There are sklearn classifiers and transformers that support additional parameters, so I don't see an immediate reason against it.
Are you interested in working on this and submitting a PR? :)
I believe this can already be achieved by making X a dict and using a proper collate function.
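The dict-based workaround can also be sketched without skorch. As I understand skorch, when X is a dict each value is indexed per sample and the batch is passed to the module as keyword arguments (`module_(**x_dict)`), so a constant such as a hypothetical `scale` would have to be repeated once per row; a custom collate function (for example, set via `iterator_valid__collate_fn`) could then collapse it back to a single value. The helpers below (`iterate_dict_batches`, `collapse_constants`, `forward`) are hypothetical pure-Python stand-ins for the dataset, the collate function, and the module, not skorch's own code.

```python
from typing import Dict, Iterable, List, Sequence


def iterate_dict_batches(X: Dict[str, Sequence], batch_size: int) -> Iterable[Dict[str, list]]:
    # Like a dataset built from a dict X: every value is sliced per sample,
    # so each key must hold one entry per row.
    n = len(next(iter(X.values())))
    for start in range(0, n, batch_size):
        yield {k: list(v[start:start + batch_size]) for k, v in X.items()}


def collapse_constants(batch: Dict[str, list], constant_keys: Sequence[str]) -> Dict[str, object]:
    # Hypothetical collate step: a key that repeats one value for every sample
    # is reduced to that single value so forward() receives a scalar.
    out: Dict[str, object] = dict(batch)
    for key in constant_keys:
        values = batch[key]
        if any(v != values[0] for v in values):
            raise ValueError(f"{key!r} is not constant within the batch")
        out[key] = values[0]
    return out


def forward(X: List[float], scale: float = 1.0) -> List[float]:
    # Hypothetical module forward with an extra keyword argument.
    return [x * scale for x in X]


X = {"X": [1.0, 2.0, 3.0], "scale": [10.0] * 3}
preds: List[float] = []
for batch in iterate_dict_batches(X, batch_size=2):
    # Mapping keys become keyword arguments, as with module(**x_dict).
    preds.extend(forward(**collapse_constants(batch, ["scale"])))
print(preds)  # [10.0, 20.0, 30.0]
```

The trade-off compared with a `predict(X, **predict_params)` signature is that the caller must replicate the constant per sample and keep a matching collate function in sync with the module's arguments.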