Torch studentT distribution Output not compatible with StudentT
Description
The StudentT distribution from torch enforces a positive (x > 0) constraint on the scale and df parameters. The current implementation maps raw inputs through softplus, but softplus underflows to exactly zero for large negative inputs, so softplus(-120) > 0 evaluates to False.
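For reference, a minimal sketch showing the underflow directly (assuming the default float32 dtype and torch.nn.functional.softplus, which is the same mapping used by the output class):

import torch
import torch.nn.functional as F

# exp(-120) underflows in float32, so log1p(exp(x)) collapses to exactly 0
print(F.softplus(torch.tensor(-120.0)))      # tensor(0.)
print(F.softplus(torch.tensor(-120.0)) > 0)  # tensor(False)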
To Reproduce
from gluonts.torch.modules.distribution_output import StudentTOutput
import torch

s = StudentTOutput()
# domain_map maps raw network outputs to (df, loc, scale);
# the large negative last input makes softplus(scale) underflow to 0
args = s.domain_map(torch.Tensor([2]), torch.Tensor([20]), torch.Tensor([-200]))
s.distribution(args)
Error message or code output
Traceback (most recent call last):
File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.8/lib/python3.8/code.py", line 90, in runcode
exec(code, self.locals)
File "<input>", line 1, in <module>
File "/Users/kgopalsw/Ensemble/src/Ensembler/src/gluonts/torch/modules/distribution_output.py", line 139, in distribution
distr = self._base_distribution(distr_args)
File "/Users/kgopalsw/Ensemble/src/Ensembler/src/gluonts/torch/modules/distribution_output.py", line 116, in _base_distribution
return self.distr_cls(*distr_args)
File "/Users/kgopalsw/Ensemble/src/Ensembler/venv/lib/python3.8/site-packages/torch/distributions/studentT.py", line 48, in __init__
super(StudentT, self).__init__(batch_shape, validate_args=validate_args)
File "/Users/kgopalsw/Ensemble/src/Ensembler/venv/lib/python3.8/site-packages/torch/distributions/distribution.py", line 55, in __init__
raise ValueError(
ValueError: Expected parameter scale (Tensor of shape ()) of distribution StudentT(df: 4.126928329467773, loc: 20.0, scale: 0.0) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
0.0
Environment
- Operating system: macOS
- Python version: 3.8
- GluonTS version: '0.0.0'
- MXNet version:'1.9.1'
The issue is version-independent.
Thanks @karthickgopalswamy! In practice this may not constitute an issue, but I agree this could be made more robust.
I think the fix could be the same as what's done for other distributions, see https://github.com/awslabs/gluon-ts/blob/5e64be79182d0f4237632b2710bd3d357f57a9bc/src/gluonts/torch/distributions/distribution_output.py#L219
Note that NormalOutput also seems to be affected.
@karthickgopalswamy do you want to open a PR?
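A rough sketch of what the "epsilon" fix could look like for StudentTOutput, modeled on the linked domain_map implementations (the function name and exact structure below are illustrative, not the actual gluonts code):

import numpy as np
import torch
import torch.nn.functional as F

def student_t_domain_map(df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
    # Adding machine epsilon after softplus keeps scale strictly positive
    # even when softplus underflows to 0 for large negative inputs.
    epsilon = np.finfo(np.float32).eps
    scale = F.softplus(scale) + epsilon
    # df is already kept >= 2 by the constant offset, so it stays positive
    df = 2.0 + F.softplus(df)
    return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)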
I would suggest we move to squareplus #1894, which is much more robust and does not suffer from the above issue. I can send a PR as I have that lying around...
@kashif I think I would go for the "epsilon" change for the time being, mainly for consistency with the MXNet-based models, which I would not update just yet for backward compatibility (turning to squareplus would be slightly breaking for existing trained models).
The squareplus is interesting because it requires neither log nor exp, so I can see it being much faster to compute.
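For illustration only (not the actual PR), a quick comparison assuming the common definition squareplus(x) = (x + sqrt(x^2 + b)) / 2 with b = 4:

import torch
import torch.nn.functional as F

def squareplus(x: torch.Tensor) -> torch.Tensor:
    # Uses only elementwise arithmetic and sqrt (no log/exp) and stays
    # positive for inputs in the range seen above.
    return (x + torch.sqrt(x * x + 4.0)) / 2.0

x = torch.tensor(-200.0)
print(F.softplus(x))   # tensor(0.) -- underflows
print(squareplus(x))   # approximately 0.005 -- stays positive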
Implemented in #2791, closing this bug.