Deprecate `test_value` machinery
Description
This adds a lot of complexity for little user benefit (most users don't know about this functionality in the first place).
A first pass would be to just put `FutureWarning`s in the right places: when flags are changed or test values are accessed from tags.
As a smaller-scope alternative, it would be nice to have a helper that computes intermediate values for all variables in a graph so we can show them in `dprint`.
Something like:
def eval_intermediate_values(
    variables: Union[Sequence[Variable], FunctionGraph],
    vars_to_values: Mapping[Variable, Any],
) -> Mapping[Variable, Any]:
    ...
For instance:
x = pt.scalar("x")
y = x - 1
z = pt.log(y)
eval_intermediate_values(z, {x: 0.5})
# {x: 0.5, y: -0.5, z: nan}
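For concreteness, here is a minimal sketch of how such a helper might be implemented. It is one possible strategy (compile every computed variable in the graph as an output), not a committed design; `ancestors` from `pytensor.graph.basic` is used to collect the graph's variables:
import pytensor
from pytensor.graph.basic import ancestors

def eval_intermediate_values(outputs, vars_to_values):
    # Sketch: evaluate every computed variable in the graph in one
    # compiled function call.
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    inputs = list(vars_to_values)
    # Variables with an owner are computed nodes (intermediates and
    # outputs); root inputs and constants are excluded here.
    graph_vars = [var for var in ancestors(outputs) if var.owner is not None]
    fn = pytensor.function(inputs, graph_vars, on_unused_input="ignore")
    values = fn(*[vars_to_values[var] for var in inputs])
    # Merge the provided input values with the computed ones; keys are
    # the Variable objects themselves.
    return {**vars_to_values, **dict(zip(graph_vars, values))}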
Can this approach lead to out of memory in some scenarios?
Moreover, `eval_intermediate_values` seems useless: for small graphs you can `eval` by hand, and for large graphs there is no way to tell which variable corresponds to what and which one produces the nans.
The idea is that you could see the values in `dprint`, which you can already do with test values. That's useful because it shows which operations produced the nans.
> Can this approach lead to out of memory in some scenarios?
This wouldn't take up more memory than the current test value approach, so I don't think it's an important concern.
> The idea is that you could see the values in `dprint`, which you can already do with test values. That's useful because it shows which operations produced the nans.
Apparently I imagined that functionality. I still think it could be worth exploring, but that can be done in a separate issue.
Here is the kind of thing I had in mind:
import numpy as np
import pytensor
import pytensor.tensor as pt
pytensor.config.compute_test_value = "warn"
x = pt.vector("x")
x.tag.test_value = np.array([1, -2, 3])
y = pt.exp(pt.log(pt.tanh(x * 2)) + 3).sum()
pytensor.dprint(y)
# Sum{axes=None} [id A] nan
#  └─ Exp [id B] [19.36301155 nan 20.08529011]
#     └─ Add [id C] [2.96336463 nan 2.99998771]
#        ├─ Log [id D] [-3.66353747e-02 nan -1.22884247e-05]
#        │  └─ Tanh [id E] [ 0.96402758 -0.9993293 0.99998771]
#        │     └─ Mul [id F] [ 2. -4. 6.]
#        │        ├─ x [id G] [ 1. -2. 3.]
#        │        └─ ExpandDims{axis=0} [id H] [2]
#        │           └─ 2 [id I]
#        └─ ExpandDims{axis=0} [id J] [3]
#           └─ 3 [id K]
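Absent built-in support for this in `dprint`, a rough approximation is possible with the hypothetical `eval_intermediate_values` helper sketched earlier in this thread: evaluate every intermediate once and print each variable next to its value.
# Assuming the eval_intermediate_values sketch from above is in scope:
values = eval_intermediate_values(y, {x: np.array([1.0, -2.0, 3.0])})
for var, value in values.items():
    print(var, "->", value)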
Here is another idea about providing more useful test_value-like machinery, that need not be so ingrained in the PyTensor codebase: https://gist.github.com/ricardoV94/e8902b4c35c26e87e189ab477f8d9288
Hi @ricardoV94, regarding the following lines: https://github.com/pymc-devs/pytensor/blob/dbe0e09ae2a5f5d0119e0378259f843a95d36abc/pytensor/graph/op.py#L303-L304
do we want to add a warning whenever `config.compute_test_value != "off"`, or only in cases where `compute_test_value` is actually called?
Going by the first logic, in: https://github.com/pymc-devs/pytensor/blob/dbe0e09ae2a5f5d0119e0378259f843a95d36abc/pytensor/graph/op.py#L671-L677
we would raise a warning only when `assert action in ("ignore", "off")` fails, right? (Although I am not sure we should raise a warning when the assert itself fails.)
I think the warnings make sense when `config.compute_test_value != "off"` and when `tag.test_value` is accessed (which automatically includes the `get_test_value` cases). Is this correct?
Yes, that's probably a good start. The challenging part is then making sure we don't use those anywhere internally, other than in direct tests of the test_value machinery (and in those we just put a `with pytest.warns` every time we expect the functionality to be used).
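For concreteness, a sketch of the kind of warning placement being discussed; the helper name and exact call sites here are hypothetical, not the actual PyTensor code:
import warnings

def _warn_test_value_deprecation(what: str) -> None:
    # Hypothetical helper; the real change would live next to the
    # test-value accessors (e.g. in pytensor/graph/op.py).
    warnings.warn(
        f"{what} is deprecated and will be removed in a future release of PyTensor.",
        FutureWarning,
        stacklevel=3,
    )

# The two trigger points discussed above (pseudocode):
#   if config.compute_test_value != "off":
#       _warn_test_value_deprecation("The compute_test_value machinery")
#   ...and whenever `tag.test_value` is read (covers get_test_value too):
#       _warn_test_value_deprecation("Accessing `tag.test_value`")

# Tests that deliberately exercise the machinery would then wrap the
# call sites:
#   with pytest.warns(FutureWarning):
#       ...  # code that touches test values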