Factoring loss out of distribution objects
Problem: It is not clear how one can select a different loss to train a model, other than by customizing the network's code.
(Some proposal/discussion on this problem is also in #1043.)
Models (intended as networks) in gluonts can in principle have all sorts of different outputs: sample paths, marginal samples, parametric distributions (intended as distribution objects, or the tuple of tensors containing their parameters), point forecasts, quantiles... None of these is intrinsically tied to a "loss", despite the fact that e.g. by default we compute the negative log-likelihood. So it might make sense to isolate the concept of "loss" as "something that compares (some form of) prediction to ground truth":
```python
class Loss:
    def __call__(self, prediction, ground_truth):
        [...]
```
(Type annotations are omitted on purpose.) For example, for a gluonts `Distribution`, some losses could be:
```python
import numpy as np

import gluonts.mx.distribution


class NegativeLogLikelihood(Loss):
    def __call__(self, prediction: gluonts.mx.distribution.Distribution, ground_truth: np.ndarray):
        # one can also do any (weighted) averaging across the time dimension here
        return -prediction.log_prob(ground_truth)


class ContinuousRankedProbabilityScore(Loss):
    def __call__(self, prediction: gluonts.mx.distribution.Distribution, ground_truth: np.ndarray):
        # one can also do any (weighted) averaging across the time dimension here
        return prediction.crps(ground_truth)
```
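For concreteness, here is a sketch of how these could be used, assuming the MXNet-based `Gaussian` from `gluonts.mx.distribution` (module paths may differ across gluonts versions):

```python
import mxnet as mx
from gluonts.mx.distribution import Gaussian

# Any Loss applies to the network's distribution output, without the
# network having to know which loss was chosen.
distr = Gaussian(mu=mx.nd.array([0.0]), sigma=mx.nd.array([1.0]))
target = mx.nd.array([0.5])

print(NegativeLogLikelihood()(distr, target))             # negative log-density at 0.5
print(ContinuousRankedProbabilityScore()(distr, target))  # CRPS at 0.5
```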
And defining more would be extremely simple. These objects could be injected into any of the model training components (estimator, network, or trainer, depending on who's doing the loss computation), and one could pass a custom Loss object of the same nature as the default one (i.e. with the same type for the prediction argument).
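For example, a training network could take the loss as a constructor argument; `MyTrainingNetwork` and its `get_distribution` helper below are purely hypothetical:

```python
import mxnet as mx


class MyTrainingNetwork(mx.gluon.HybridBlock):
    """Hypothetical network where the loss is injected instead of hard-coded."""

    def __init__(self, loss: Loss = NegativeLogLikelihood(), **kwargs):
        super().__init__(**kwargs)
        self.loss = loss

    def hybrid_forward(self, F, past_target, future_target):
        # hypothetical helper mapping inputs to a Distribution object
        distr = self.get_distribution(past_target)
        return self.loss(distr, future_target)
```

Switching to CRPS training would then be a one-line change: `MyTrainingNetwork(loss=ContinuousRankedProbabilityScore())`.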
Couldn't we get away with just using functions? Or when would you initialise a loss with some custom parameters?
Also, I guess the losses would be framework-specific, i.e. we would have different implementations for mxnet and torch.
> Couldn't we get away with just using functions? Or when would you initialise a loss with some custom parameters?
Of course, one should be able to use just a function for that. When the loss has parameters, one can write a class as above or fix the parameters with `partial`. To some extent it is a matter of style, but if you need to serialize it then a class with a `@validated` constructor might be handy.
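To make the function-vs-`partial` point concrete, here is a sketch with a hypothetical parameterized quantile loss:

```python
from functools import partial

import numpy as np


def quantile_loss(prediction: np.ndarray, ground_truth: np.ndarray, q: float) -> np.ndarray:
    # pinball loss at quantile level q
    error = ground_truth - prediction
    return np.maximum(q * error, (q - 1) * error)


# Fixing q with functools.partial yields a two-argument callable, matching
# the Loss.__call__(prediction, ground_truth) shape proposed above.
pinball_90 = partial(quantile_loss, q=0.9)
```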
I think we really want to stop using `validated` everywhere. In many cases you should get away with a simple pydantic model.
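For instance, a minimal sketch of a parameterized loss as a plain pydantic model (no gluonts-specific machinery assumed):

```python
import numpy as np
from pydantic import BaseModel


class QuantileLoss(BaseModel):
    # quantile level; validated by pydantic on construction,
    # and easy to serialize via pydantic's standard export methods
    q: float = 0.5

    def __call__(self, prediction: np.ndarray, ground_truth: np.ndarray) -> np.ndarray:
        error = ground_truth - prediction
        return np.maximum(self.q * error, (self.q - 1) * error)


loss = QuantileLoss(q=0.9)
```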
Yes, whatever works
@lostella Hi! Apologies for reviving a zombie thread, but is there any update on this issue?