Invalid value when calling log_prob after sample
My code looks like this:
m = TruncatedNormal(loc, scale, 0, 1)
action_pt = m.sample()
return m.log_prob(action_pt)
It looks like action_pt can take the value 1.0, which makes log_prob raise an error:
111 def log_prob(self, value):
112 if self._validate_args:
--> 113 self._validate_sample(value)
114 return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
115
~/.pyenv/versions/3.8.8/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
291 valid = support.check(value)
292 if not valid.all():
--> 293 raise ValueError(
I don't know whether the bug is:
- that the value 1.0 should never be sampled, or
- that 1.0 lies within the support and should not be flagged as invalid.
The exception is raised by the support check, which means that 1.0 does not fall within the support interval. Since loc and scale aren't given in the snippet, it is hard to say whether this is a precision issue or incorrect usage of the parameters. The interface was designed to follow the conventions of the similar scipy function.
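The support check itself is just an elementwise interval comparison. A minimal sketch with torch.distributions.constraints.interval, assuming (as the traceback suggests) that log_prob validates the standardized value (x - loc) / scale against standardized bounds a and b. The parameters (loc = 2, scale = 0.2, truncation to [0, 1]) and the value 1.0015 come from the reproducible example in this thread; 0.99 is an arbitrary in-range value added for contrast:

```python
import torch
from torch.distributions import constraints

loc, scale = torch.tensor(2.0), torch.tensor(0.2)
# Standardized truncation bounds for [0, 1] (assumed to mirror what the library computes)
a = (torch.tensor(0.0) - loc) / scale
b = (torch.tensor(1.0) - loc) / scale
support = constraints.interval(a, b)

# Candidate samples, standardized the same way before the check
x = torch.tensor([0.99, 1.0, 1.0015])
z = (x - loc) / scale
valid = support.check(z)  # elementwise: a <= z <= b
print(valid)
```

Because the interval is closed and 1.0 is standardized with exactly the same float32 operations as the bound, 1.0 itself passes here. A sample that is printed as 1.0 but is actually slightly above it would fail, which is what the 1.0015 in the later output suggests.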
Here is a reproducible example:
m = TruncatedNormal(torch.full((1000,), 2.), torch.full((1000,), .2), 0, 1)
m.log_prob(m.sample())
Some sampled values are greater than 1 and some are -inf:
tensor([0.9667, 0.9819, 0.9819, 0.9160, 0.9819, 0.9974, 0.9411, 0.9929, 0.9667,
0.9819, 0.9667, 0.9160, 0.9667, -inf, 0.9560, -inf, 0.9560, 0.9160,
0.9160, 0.9751, 0.9160, 0.9411, 1.0015,...
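The pattern in this output (a handful of repeated values, a few -inf, some values just above 1) is what float32 rounding in an erf/erfinv-based inverse CDF would produce. Below is a standalone float32 re-creation of that sampling scheme, with Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) and Phi^-1(y) = sqrt(2) * erfinv(2y - 1). This is an assumption about how the sampler works, not the library's code. With standardized bounds a = -10 and b = -5, Phi(b) is about 2.9e-7, so 2y - 1 sits within a few float32 ulps of -1 and can only take about ten distinct values; whenever it rounds to exactly -1, erfinv returns -inf.

```python
import math
import torch

torch.manual_seed(0)
loc, scale = 2.0, 0.2
# Standardized bounds for truncating Normal(2, 0.2) to [0, 1]
a, b = torch.tensor(-10.0), torch.tensor(-5.0)


def big_phi(x):
    # Standard normal CDF via erf; 1 + erf(...) cancels badly in the lower tail
    return 0.5 * (1 + torch.erf(x / math.sqrt(2)))


def inv_big_phi(y):
    # Inverse standard normal CDF via erfinv; 2y - 1 rounds to a coarse grid near -1
    return math.sqrt(2) * torch.erfinv(2 * y - 1)


u = torch.rand(10_000)  # float32 uniforms in [0, 1)
y = big_phi(a) + u * (big_phi(b) - big_phi(a))
x = inv_big_phi(y) * scale + loc

print(torch.unique(x))  # only a few distinct values, including -inf
```

Whether the largest samples also land above 1 depends on how erf rounds Phi(-5) in float32: if the computed Phi(b) comes out above the true value (about 2.87e-7), the top of the grid maps to a point past the upper bound.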
One thing I'd try first is to plug these values into the unit test at https://github.com/toshas/torch_truncnorm/blob/main/tests/test.py#L97 and see whether it passes the check against scipy. If it doesn't, there is a bug.
I added the line self._test_numerical(2.0, 0.2, 0.0, 1.0) to the test.
It gives:
======================================================================
FAIL: test_simple (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
File "testa.py", line 103, in test_simple
self._test_numerical(2.0, 0.2, 0.0, 1.0)
File "testa.py", line 74, in _test_numerical
self.assertRelativelyEqual(mean_sc, mean_pt)
File "testa.py", line 66, in assertRelativelyEqual
raise self.failureException(msg)
AssertionError: array(0.96269921) != array(1.0022793, dtype=float32) within tol=1e-06 abs=1e-05 (rel=0.03949006605978869 diff=0.039580075041381724)
======================================================================
FAIL: test_support (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
File "testa.py", line 131, in test_support
self.assertEqual(
AssertionError: 'Expected value argument (Tensor of shape [157 chars]10.0' != 'The value argument must be within the support'
+ The value argument must be within the support
- Expected value argument (Tensor of shape ()) to be within the support (Interval(lower_bound=-1.0, upper_bound=2.0)) of the distribution TruncatedNormal(a: -1.0, b: 2.0), but found invalid values:
- -10.0
The second failure is not caused by the line I added; you might want to fix it in a separate issue. The first one IS a bug.
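The scipy reference mean in the first failure (0.96269921) can be reproduced from the closed-form mean of a truncated normal, loc + scale * (phi(a) - phi(b)) / (Phi(b) - Phi(a)), evaluated in float64 with Phi computed through erfc so the lower tail keeps its precision. The second function evaluates the same formula in float32 with the erf-based Phi; it is a hypothetical stand-in for the library's computation, meant to show how sensitive the normalizer Phi(b) - Phi(a) is, not a copy of the library code.

```python
import math
import torch

SQRT2 = math.sqrt(2.0)
INV_SQRT_2PI = 1.0 / math.sqrt(2.0 * math.pi)


def mean_float64(loc, scale, low, high):
    """Closed-form truncated-normal mean in double precision.
    Phi is evaluated through erfc so the lower tail keeps its relative precision."""
    a, b = (low - loc) / scale, (high - loc) / scale
    pdf_a = INV_SQRT_2PI * math.exp(-0.5 * a * a)
    pdf_b = INV_SQRT_2PI * math.exp(-0.5 * b * b)
    norm = 0.5 * math.erfc(-b / SQRT2) - 0.5 * math.erfc(-a / SQRT2)
    return loc + scale * (pdf_a - pdf_b) / norm


def _pdf32(x):
    return INV_SQRT_2PI * torch.exp(-0.5 * x * x)


def _cdf32_erf(x):
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))): near x = -5 the sum cancels in float32
    return 0.5 * (1 + torch.erf(x / SQRT2))


def mean_float32_erf(loc, scale, low, high):
    """Same formula in float32 with the erf-based Phi (hypothetical stand-in)."""
    loc_t, scale_t, low_t, high_t = (
        torch.tensor(v, dtype=torch.float32) for v in (loc, scale, low, high)
    )
    a, b = (low_t - loc_t) / scale_t, (high_t - loc_t) / scale_t
    norm = _cdf32_erf(b) - _cdf32_erf(a)
    return (loc_t + scale_t * (_pdf32(a) - _pdf32(b)) / norm).item()


ref = mean_float64(2.0, 0.2, 0.0, 1.0)
approx = mean_float32_erf(2.0, 0.2, 0.0, 1.0)
print(f"float64/erfc: {ref:.8f}  float32/erf: {approx:.8f}")
```

If float32 erf rounds Phi(-5) up to about 2.98e-7, the float32 formula gives roughly 1.0023, close to the 1.0022793 in the failure, which suggests (but does not prove) that the normalizer is where the precision is lost.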
The cause seems to be that extreme input values to the icdf function should be clamped; see https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
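A minimal sketch of a sampler that applies this kind of clamping, written as a standalone function rather than a patch to either linked file (sample_truncnorm is a hypothetical name). It also swaps the erf/erfinv formulation for torch.special.ndtr and torch.special.ndtri (available in recent PyTorch releases), which evaluate Phi and its inverse directly and avoid the 2y - 1 cancellation in the lower tail; the clamps then only absorb the remaining rounding error. loc, scale, low and high are assumed to be float tensors.

```python
import torch


def sample_truncnorm(loc, scale, low, high, sample_shape=torch.Size()):
    """Inverse-CDF sampling from Normal(loc, scale) truncated to [low, high] (sketch)."""
    a = (low - loc) / scale
    b = (high - loc) / scale
    phi_a, phi_b = torch.special.ndtr(a), torch.special.ndtr(b)
    shape = torch.Size(sample_shape) + a.shape
    u = torch.rand(shape, dtype=a.dtype, device=a.device)
    p = phi_a + u * (phi_b - phi_a)
    # Keep p strictly inside (0, 1) so ndtri cannot return +-inf
    finfo = torch.finfo(p.dtype)
    p = p.clamp(finfo.tiny, 1 - finfo.eps)
    z = torch.special.ndtri(p)
    # Clamp the sample into [low, high] to absorb any remaining rounding error
    x = loc + scale * z
    return torch.maximum(torch.minimum(x, high), low)


torch.manual_seed(0)
loc = torch.full((1000,), 2.0)
scale = torch.full((1000,), 0.2)
samples = sample_truncnorm(loc, scale, torch.tensor(0.0), torch.tensor(1.0))
print(samples.min(), samples.max(), samples.mean())
```

Clamping the final sample to [low, high], rather than only the standardized value, is meant to keep validation happy: recomputing (x - loc) / scale with the same tensors is monotone under float rounding, so it cannot exceed the standardized bound. Truncation far in the upper tail, where ndtr rounds to 1, would need more care than this sketch provides.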