
Factoring loss out of distribution objects

Open lostella opened this issue 3 years ago • 6 comments

Problem: It is not clear how one can select a different loss to train a model, other than by customizing the network's code.

(Some proposal/discussion on this problem is also in #1043.)

Models (intended as networks) in gluonts can in principle have all sorts of different outputs: sample paths, marginal samples, parametric distributions (intended as distribution objects, or the tuple of tensors containing their parameters), point forecasts, quantiles... None of these is intrinsically tied to a "loss", despite the fact that e.g. by default we compute the negative log-likelihood. So it might make sense to isolate the concept of "loss" as "something that compares (some form of) prediction to ground truth":

class Loss:
    def __call__(self, prediction, ground_truth):
        raise NotImplementedError

(Type annotations are omitted on purpose.) For example, for a gluonts Distribution some losses could be:

import numpy as np
import gluonts.mx.distribution

class NegativeLogLikelihood(Loss):
    def __call__(self, prediction: gluonts.mx.distribution.Distribution, ground_truth: np.ndarray):
        return -prediction.log_prob(ground_truth)
        # one can also do any (weighted) averaging across the time dimension here

class ContinuousRankedProbabilityScore(Loss):
    def __call__(self, prediction: gluonts.mx.distribution.Distribution, ground_truth: np.ndarray):
        return prediction.crps(ground_truth)
        # one can also do any (weighted) averaging across the time dimension here

And defining more would be extremely simple. These objects could be injected into any of the model training components (estimator, network, trainer, depending on who's doing the loss computation), and one could pass a custom Loss object of the same nature as the default one (i.e. with the same type for the prediction argument); see the usage sketch below.
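As a usage sketch (hypothetical: no estimator currently accepts such an argument, the loss parameter name here is an assumption), injecting a loss at the estimator level could look like:

from gluonts.model.deepar import DeepAREstimator

estimator = DeepAREstimator(
    freq="H",
    prediction_length=24,
    loss=ContinuousRankedProbabilityScore(),  # hypothetical `loss` argument
)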

lostella · May 20 '21 08:05

Couldn't we get away with just using functions? Or when would you initialise a loss with some custom parameters?

jaheba · May 20 '21 10:05

Also, I guess the losses would be framework-specific, i.e. we would have different implementations for mxnet and torch.
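For instance, a torch counterpart of the NegativeLogLikelihood above might look like this (a sketch, assuming the same Loss base class; torch.distributions.Distribution exposes log_prob just like the mxnet one):

import torch

class TorchNegativeLogLikelihood(Loss):
    def __call__(self, prediction: torch.distributions.Distribution, ground_truth: torch.Tensor):
        # same contract as the mxnet version, only the tensor/distribution types differ
        return -prediction.log_prob(ground_truth)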

jaheba · May 20 '21 10:05

> Couldn't we get away with just using functions? Or when would you initialise a loss with some custom parameters?

Of course, one should be able to use just a function for that. When the loss has parameters, one can write a class as above, or fix the parameters with functools.partial; a sketch of both options follows. To some extent it is a matter of style, but if you need to serialize it then a class with a @validated constructor might be handy.
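A minimal sketch of both options (the quantile/pinball loss here is a hypothetical example, not an existing gluonts class):

import numpy as np
from functools import partial

# option 1: a plain function, with parameters fixed via partial
def quantile_loss(prediction: np.ndarray, ground_truth: np.ndarray, q: float = 0.5) -> np.ndarray:
    diff = ground_truth - prediction
    return np.maximum(q * diff, (q - 1) * diff)  # pinball loss at level q

pinball_90 = partial(quantile_loss, q=0.9)

# option 2: a class carrying its parameters in the constructor
class QuantileLoss(Loss):
    def __init__(self, q: float = 0.5):
        self.q = q

    def __call__(self, prediction, ground_truth):
        return quantile_loss(prediction, ground_truth, q=self.q)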

lostella · May 20 '21 11:05

I think we really want to stop using validated everywhere. In many cases you should be able to get away with a simple pydantic model.
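For example, the hypothetical QuantileLoss above could be a pydantic model, which makes it serializable without @validated (a sketch, assuming pydantic v1 style):

import numpy as np
from pydantic import BaseModel

class QuantileLoss(BaseModel):
    q: float = 0.5

    def __call__(self, prediction: np.ndarray, ground_truth: np.ndarray) -> np.ndarray:
        diff = ground_truth - prediction
        return np.maximum(self.q * diff, (self.q - 1) * diff)

QuantileLoss(q=0.9).json()  # '{"q": 0.9}'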

jaheba · May 20 '21 11:05

Yes, whatever works

lostella · May 20 '21 11:05

@lostella Hi! Apologies for reviving a zombie thread, but is there any update on this issue?

baharian · Feb 22 '24 20:02