torch_truncnorm

Invalid value when calling log_prob after sample

Open · louisabraham opened this issue 2 years ago · 5 comments

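My code looks like this:

    def action_log_prob(loc, scale):  # hypothetical wrapper; the original lines were a function body
        m = TruncatedNormal(loc, scale, 0, 1)  # Normal(loc, scale) truncated to [0, 1]
        action_pt = m.sample()
        return m.log_prob(action_pt)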

It looks like action_pt can take the value 1.0, which makes log_prob raise an error:

    111     def log_prob(self, value):
    112         if self._validate_args:
--> 113             self._validate_sample(value)
    114         return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
    115 

~/.pyenv/versions/3.8.8/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    291         valid = support.check(value)
    292         if not valid.all():
--> 293             raise ValueError(
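For reference, the expression returned on line 114 matches the log density of a standard normal truncated to [α, β], reading CONST_LOG_INV_SQRT_2PI as −½ log(2π) and _log_Z as log Z (an interpretation of the names, not checked against the library source). For a normal with location μ and scale σ, the standardized value picks up an extra −log σ:

```latex
\log p(x) = -\tfrac{1}{2}\log(2\pi) - \log Z - \tfrac{x^2}{2},
\qquad Z = \Phi(\beta) - \Phi(\alpha), \qquad \alpha \le x \le \beta

\log p_{\mu,\sigma}(y) = \log p\!\left(\tfrac{y - \mu}{\sigma}\right) - \log \sigma
```

The formula itself returns a finite number for any x, including values outside [α, β], which is presumably why log_prob validates the sample against the support first.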

I don't know whether the bug is:

  1. that the value 1.0 should never be sampled, or
  2. that 1.0 lies in the valid interval and shouldn't be rejected as outside the support.

louisabraham · Dec 23 '22 20:12

The exception is raised by the support check, which means the sampled value does not fall within the support interval. Since loc and scale aren't given in the snippet, it is hard to say whether this is a precision issue or incorrect use of the parameters. The interface was designed to follow the conventions of the similar scipy function (scipy.stats.truncnorm).
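The support check is a plain bounds test. The sketch below uses torch's own interval constraint, on the assumption that TruncatedNormal's support is an interval of this kind (the test output later in this thread prints it as Interval(lower_bound=..., upper_bound=...)). The interval is closed, so an exact 1.0 passes, while a value slightly above 1, such as the 1.0015 in the reproducible example, fails. That suggests the rejected sample was likely a little above 1.0 even if it printed as 1.0.

```python
import torch
from torch.distributions import constraints

# Closed interval [0, 1], the truncation bounds used in this issue.
support = constraints.interval(0.0, 1.0)

values = torch.tensor([0.9667, 1.0, 1.0015])
mask = support.check(values)  # elementwise lower <= v <= upper
print(mask)  # tensor([ True,  True, False])
```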

toshas · Dec 23 '22 20:12

Here is a reproducible example:

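    import torch
    # TruncatedNormal from torch_truncnorm; loc=2.0, scale=0.2, bounds [0, 1], 1000 copies
    m = TruncatedNormal(torch.full((1000,), 2.), torch.full((1000,), .2), 0, 1)
    m.log_prob(m.sample())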

Some values are greater than 1 and some are -inf:

tensor([0.9667, 0.9819, 0.9819, 0.9160, 0.9819, 0.9974, 0.9411, 0.9929, 0.9667,
        0.9819, 0.9667, 0.9160, 0.9667,   -inf, 0.9560,   -inf, 0.9560, 0.9160,
        0.9160, 0.9751, 0.9160, 0.9411, 1.0015,...
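One plausible contributor, offered as an assumption rather than a confirmed diagnosis: with loc=2, scale=0.2 and bounds [0, 1], the standardized bounds are α = (0 − 2)/0.2 = −10 and β = (1 − 2)/0.2 = −5, far in the left tail. If the normal CDF is evaluated as 0.5·(1 + erf(x/√2)) in float32, the sum 1 + erf(...) cancels almost completely there, while the equivalent erfc form keeps its relative precision. The helper names below are hypothetical; the sketch measures both forms against a float64 reference.

```python
import math
import torch

def big_phi_erf(x):
    # Standard normal CDF as 0.5 * (1 + erf(x / sqrt(2))).
    # For very negative x, erf(...) is close to -1, so 1 + erf(...) cancels.
    return 0.5 * (1 + torch.erf(x / math.sqrt(2)))

def big_phi_erfc(x):
    # Same CDF as 0.5 * erfc(-x / sqrt(2)); no subtraction of nearly equal numbers.
    return 0.5 * torch.erfc(-x / math.sqrt(2))

# Standardized bounds for loc=2, scale=0.2, a=0, b=1.
x32 = torch.tensor([-10.0, -5.0], dtype=torch.float32)
reference = big_phi_erfc(x32.double())  # float64 reference values

rel_err_erf = (big_phi_erf(x32).double() - reference).abs() / reference
rel_err_erfc = (big_phi_erfc(x32).double() - reference).abs() / reference
print(rel_err_erf)   # first entry is 1.0 (the CDF rounds to 0); second is at least a few percent
print(rel_err_erfc)  # both entries small
```

Using the erfc form in the left tail, or computing in float64, is one common remedy; whether it is the right fix for this library would need checking against its source.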

louisabraham · Dec 23 '22 22:12

One thing I'd try first is to plug these values into the unit test here, https://github.com/toshas/torch_truncnorm/blob/main/tests/test.py#L97, and see whether it passes the comparison against scipy. If not, there is a bug.
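The scipy reference value can also be checked without scipy. A minimal float64 sketch of the closed-form mean of a truncated normal, E[X] = μ + σ(φ(α) − φ(β)) / (Φ(β) − Φ(α)) with α = (a − μ)/σ and β = (b − μ)/σ. The helper names are hypothetical, and a, b are taken in the original (unstandardized) units, an assumption that is consistent with the scipy value reported in the failing test.

```python
import math

def std_normal_pdf(x):
    # Standard normal density phi(x).
    return math.exp(-0.5 * x * x) / math.sqrt(2 * math.pi)

def std_normal_cdf(x):
    # Phi(x) via erfc, which stays accurate for very negative x.
    return 0.5 * math.erfc(-x / math.sqrt(2))

def truncnorm_mean(loc, scale, a, b):
    """Mean of Normal(loc, scale) truncated to [a, b], evaluated in float64.

    Hypothetical helper for a sanity check; not part of torch_truncnorm or scipy.
    """
    alpha = (a - loc) / scale  # standardized lower bound
    beta = (b - loc) / scale   # standardized upper bound
    z = std_normal_cdf(beta) - std_normal_cdf(alpha)
    return loc + scale * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / z

print(truncnorm_mean(2.0, 0.2, 0.0, 1.0))  # ≈ 0.96269921
```

Since the mean of a distribution supported on [a, b] must itself lie in [a, b], any computed mean above 1 for these parameters is necessarily wrong.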

toshas · Dec 24 '22 00:12

I added the line self._test_numerical(2.0, 0.2, 0.0, 1.0) to the test.

It gives:

======================================================================
FAIL: test_simple (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "testa.py", line 103, in test_simple
    self._test_numerical(2.0, 0.2, 0.0, 1.0)
  File "testa.py", line 74, in _test_numerical
    self.assertRelativelyEqual(mean_sc, mean_pt)
  File "testa.py", line 66, in assertRelativelyEqual
    raise self.failureException(msg)
AssertionError: array(0.96269921) != array(1.0022793, dtype=float32) within tol=1e-06 abs=1e-05 (rel=0.03949006605978869 diff=0.039580075041381724)

======================================================================
FAIL: test_support (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "testa.py", line 131, in test_support
    self.assertEqual(
AssertionError: 'Expected value argument (Tensor of shape [157 chars]10.0' != 'The value argument must be within the support'
+ The value argument must be within the support
- Expected value argument (Tensor of shape ()) to be within the support (Interval(lower_bound=-1.0, upper_bound=2.0)) of the distribution TruncatedNormal(a: -1.0, b: 2.0), but found invalid values:
- -10.0

The second failure is not caused by my change (the error message the test expects no longer matches the one torch raises); you might want to fix it in a separate issue. The first one IS a bug.

louisabraham · Dec 24 '22 09:12

The cause seems to be that extreme values in the icdf function should be clamped; see the TorchRL version: https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
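A minimal sketch of the clamping idea, assuming inverse-CDF sampling of the form x = Φ⁻¹(Φ(α) + u·(Φ(β) − Φ(α))). It is not the TorchRL or torch_truncnorm code, and clamped_icdf, big_phi and inv_big_phi are hypothetical names. Clamping keeps every sample finite and inside [α, β], but it does not repair precision loss: with α = −10 and β = −5 in float32, Φ(α) rounds to 0, so a large share of the probabilities lands on the clamp floor and the sample distribution (and its mean) can still be off.

```python
import math
import torch

def big_phi(x):
    # Standard normal CDF.
    return 0.5 * (1 + torch.erf(x / math.sqrt(2)))

def inv_big_phi(p):
    # Standard normal inverse CDF.
    return math.sqrt(2) * torch.erfinv(2 * p - 1)

def clamped_icdf(u, alpha, beta):
    """Map uniforms u in [0, 1] to a standard normal truncated to [alpha, beta].

    Hypothetical helper illustrating clamping; not the torch_truncnorm or TorchRL code.
    """
    eps = torch.finfo(u.dtype).eps
    p = big_phi(alpha) + u * (big_phi(beta) - big_phi(alpha))
    # Keep p strictly inside (0, 1) so erfinv never returns +/-inf.
    p = p.clamp(eps, 1 - eps)
    x = inv_big_phi(p)
    # Rounding can still push x slightly past the bounds; pull it back into [alpha, beta].
    return torch.minimum(torch.maximum(x, alpha), beta)

u = torch.rand(1000)
x = clamped_icdf(u, torch.tensor(-10.0), torch.tensor(-5.0))
print(x.min().item(), x.max().item())  # both lie within [-10, -5]
```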

Wu-Chenyang · Apr 15 '23 03:04