
Weighted loss function

Open cbonnett opened this issue 9 years ago • 5 comments

As discussed, it would be useful to be able to easily add a custom weighted cost function. A custom loss function is easy to implement at the moment, but a weighted loss function is not supported. A possible scikit-learn-compatible API for this would be to pass the weights in the fit call, like this:

fit(X, y, sample_weight=None)

I have tried to implement it, but so far without success. Here is the diff:

https://github.com/cbonnett/nolearn/commit/f4ad011397d966b05dfbbe517a6845c40e7c490a

cbonnett avatar Jun 02 '15 20:06 cbonnett

The PR "Add ability to weight training samples and predefine training/validation splits" #123 has an implementation for this. The pull request has problems though which are discussed here: https://github.com/dnouri/nolearn/pull/123#issuecomment-124687911

dnouri avatar Jul 24 '15 21:07 dnouri

The best way to go about this might be to create a second input layer for the weights, and then use that layer inside your objective function to multiply the loss by the weights.

Instead of calling fit like this:

fit(X, y, sample_weight=weights)

You would do:

fit({'input': X, 'weights': weights}, y)
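
A rough sketch of what such an objective could look like (the layer names 'input', 'weights', and 'output', the function name weighted_objective, and the exact keyword arguments are illustrative, not something nolearn prescribes):

from lasagne.layers import get_output
from lasagne.objectives import aggregate

def weighted_objective(layers, loss_function, target,
                       aggregate=aggregate, deterministic=False,
                       get_output_kw=None, **kwargs):
    get_output_kw = get_output_kw or {}
    # 'weights' is the extra InputLayer carrying one weight per sample
    network_output = get_output(
        layers['output'], deterministic=deterministic, **get_output_kw)
    sample_weights = get_output(layers['weights'])
    # Scale each sample's loss by its weight before aggregating
    return aggregate(loss_function(network_output, target),
                     weights=sample_weights.flatten())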

dnouri avatar Mar 11 '16 04:03 dnouri

Would you consider an API to specify the weights per label instead of per training example? That might be a more common use case (at least for me).

felixlaumon avatar Mar 11 '16 04:03 felixlaumon

Hmm, would that be possible by creating a custom objective function alone?

dnouri avatar Mar 11 '16 05:03 dnouri

Yes, that's basically what I have been doing. If I remember correctly, here's what I did:

def custom_objective(...):
    ...
    # Per-label weights, indexed by the integer class labels in `target`
    # (aggregate here is lasagne.objectives.aggregate)
    weights_per_label = theano.shared(lasagne.utils.floatX([0.8, 0.1, 0.1]))
    weights = weights_per_label[target]  # This is a bit non-obvious
    loss = aggregate(loss_function(network_output, target), weights=weights)
    ...
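
For reference, hooking such an objective up would then just be a matter of passing it to NeuralNet; a minimal sketch, assuming nolearn's objective keyword and with my_layers as a placeholder for your usual layer definition:

from nolearn.lasagne import NeuralNet

net = NeuralNet(
    layers=my_layers,             # placeholder for the usual layer definition
    objective=custom_objective,   # plug in the weighted objective from above
    max_epochs=100,
)
net.fit(X, y)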

In most of my experiments, though, I found that rebalancing the dataset was more effective.

felixlaumon avatar Mar 14 '16 15:03 felixlaumon