Torch studentT distribution Output not compatible with StudentT

Open karthickgopalswamy opened this issue 2 years ago • 3 comments

Description

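The StudentT distribution from torch requires the scale and df parameters to be strictly positive (x > 0). The current implementation maps the raw input through softplus(input), but for large negative inputs the result underflows to exactly zero in float32: softplus(-120) > 0 evaluates to False.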
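
The underflow described above can be checked in isolation. This is a minimal sketch, assuming PyTorch's default float32 dtype is what the model uses; the variable names are illustrative and not taken from GluonTS.

```python
import torch
import torch.nn.functional as F

# The same raw input in single and double precision.
raw32 = torch.tensor([-120.0], dtype=torch.float32)
raw64 = torch.tensor([-120.0], dtype=torch.float64)

# softplus(x) = log(1 + exp(x)). exp(-120) is roughly 7.7e-53, which is
# smaller than the smallest positive float32 value, so it rounds to zero
# and the softplus result is exactly 0.0 in float32.
scale32 = F.softplus(raw32)
scale64 = F.softplus(raw64)

print(bool((scale32 > 0).all()))  # False: the constraint scale > 0 is violated
print(bool((scale64 > 0).all()))  # True: float64 still resolves the tiny value
```

So the failure depends on the floating-point precision rather than on the mathematical definition of softplus, which is strictly positive everywhere.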

To Reproduce

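from gluonts.torch.modules.distribution_output import StudentTOutput
import torch

s = StudentTOutput()
# Raw inputs in the order (df, loc, scale); the scale input is strongly negative.
args = s.domain_map(torch.Tensor([2]), torch.Tensor([20]), torch.Tensor([-200]))
# Raises ValueError: softplus(-200) underflows to 0.0, which violates scale > 0.
s.distribution(args)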

Error message or code output

Traceback (most recent call last):
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.8/lib/python3.8/code.py", line 90, in runcode
    exec(code, self.locals)
  File "<input>", line 1, in <module>
  File "/Users/kgopalsw/Ensemble/src/Ensembler/src/gluonts/torch/modules/distribution_output.py", line 139, in distribution
    distr = self._base_distribution(distr_args)
  File "/Users/kgopalsw/Ensemble/src/Ensembler/src/gluonts/torch/modules/distribution_output.py", line 116, in _base_distribution
    return self.distr_cls(*distr_args)
  File "/Users/kgopalsw/Ensemble/src/Ensembler/venv/lib/python3.8/site-packages/torch/distributions/studentT.py", line 48, in __init__
    super(StudentT, self).__init__(batch_shape, validate_args=validate_args)
  File "/Users/kgopalsw/Ensemble/src/Ensembler/venv/lib/python3.8/site-packages/torch/distributions/distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter scale (Tensor of shape ()) of distribution StudentT(df: 4.126928329467773, loc: 20.0, scale: 0.0) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
0.0

Environment

  • Operating system: macOS
  • Python version: 3.8
  • GluonTS version: '0.0.0'
  • MXNet version: '1.9.1'

The issue is version independent.

karthickgopalswamy, Jul 07 '22 05:07

Thanks @karthickgopalswamy! In practice this may not constitute an issue, but I agree this could be made more robust.

I think the fix could be the same as what is done for other distributions; see https://github.com/awslabs/gluon-ts/blob/5e64be79182d0f4237632b2710bd3d357f57a9bc/src/gluonts/torch/distributions/distribution_output.py#L219

Note that NormalOutput also seems to be affected.

@karthickgopalswamy do you want to open a PR?

lostella, Jul 07 '22 05:07

I would suggest we move to squareplus (#1894), which is much more robust and does not suffer from the above issue. I can send a PR, as I have that lying around.

kashif, Jul 07 '22 06:07

@kashif I think I would go for the "epsilon" change for the time being, mainly for consistency with the mxnet-based models, which I would not update just yet for backward-compatibility reasons (switching to squareplus would be slightly breaking for existing trained models).
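
For illustration, one way an "epsilon" change could look is sketched below. The helper `positive_with_eps` is hypothetical, and the choice of machine epsilon as the offset is an assumption; it is not necessarily what GluonTS ended up implementing. The df value is an arbitrary valid number.

```python
import torch
import torch.nn.functional as F
from torch.distributions import StudentT


def positive_with_eps(raw: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softplus shifted by machine epsilon so it never hits 0."""
    eps = torch.finfo(raw.dtype).eps  # assumption: machine epsilon as the offset
    return F.softplus(raw) + eps


df = torch.tensor(4.0)           # arbitrary valid degrees of freedom
loc = torch.tensor(20.0)
raw_scale = torch.tensor(-200.0)  # raw scale input from the reproduction above

# Plain softplus underflows to 0.0, so torch rejects the parameters.
try:
    StudentT(df, loc, F.softplus(raw_scale), validate_args=True)
    plain_ok = True
except ValueError:
    plain_ok = False

# With the epsilon offset the scale stays strictly positive and validation passes.
dist = StudentT(df, loc, positive_with_eps(raw_scale), validate_args=True)

print(plain_ok)                 # False
print(dist.scale.item() > 0.0)  # True
```

Because the offset is tiny, outputs for inputs in the normal range change only negligibly, which fits the backward-compatibility concern raised here.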

Squareplus is interesting because it requires neither log nor exp, so I can see it being much faster to compute.
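
As a rough sketch of the idea, squareplus can be written as (x + sqrt(x^2 + b)) / 2. The function below is an illustrative standalone version, and b = 4 is an assumed setting; the value used in #1894 is not stated in this thread.

```python
import torch


def squareplus(x: torch.Tensor, b: float = 4.0) -> torch.Tensor:
    """Algebraic softplus-like map: only multiply, add and sqrt, no log or exp.

    b = 4.0 is an assumption made for this sketch.
    """
    return (x + torch.sqrt(x * x + b)) / 2.0


# Inputs include the raw values discussed in this issue.
raw = torch.tensor([-200.0, -120.0, 0.0, 3.0])
out = squareplus(raw)

print(bool((out > 0).all()))  # True: stays positive at -200 and -120 in float32
print(out[2].item())          # 1.0, since sqrt(4) / 2 = 1
```

At the magnitudes from this issue the output stays strictly positive in float32. For far larger negative inputs the sum x + sqrt(x^2 + b) can still cancel to zero in float32, so a lower bound may still be worth keeping.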

lostella, Jul 07 '22 06:07

Implemented in #2791; closing this bug.

karthickgopalswamy, May 29 '23 05:05