# Gradient of OpFromGraph fails
The gradients of OpFromGraph seem a bit fragile. I saw the following failures:
## Multiple outputs
```python
from pytensor.compile.builders import OpFromGraph
import pytensor.tensor as at

x, y = at.scalars("x", "y")
out1 = x + y
out2 = x * y

op = OpFromGraph([x, y], [out1, out2])
outs = op(x, y)
# Taking the gradient of only the first output fails
at.grad(outs[0].sum(), x)
```
```
Traceback (most recent call last):
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-ebcb546bdac3>", line 9, in <module>
    at.grad(outs[0].sum(), x)
  File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 623, in grad
    _rval: Sequence[Variable] = _populate_grad_dict(
  File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1434, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1434, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1213, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
  File "/home/ricardo/Documents/Projects/aesara/aesara/compile/builders.py", line 744, in L_op
    ret_ofg_l = self._lop_op(*inps, return_list=True)
  File "/home/ricardo/Documents/Projects/aesara/aesara/compile/builders.py", line 769, in __call__
    return super().__call__(*actual_inputs, **kwargs)
  File "/home/ricardo/Documents/Projects/aesara/aesara/graph/op.py", line 297, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/ricardo/Documents/Projects/aesara/aesara/compile/builders.py", line 784, in make_node
    non_shared_inputs = [
  File "/home/ricardo/Documents/Projects/aesara/aesara/compile/builders.py", line 785, in <listcomp>
    inp_t.filter_variable(inp)
  File "/home/ricardo/Documents/Projects/aesara/aesara/tensor/type.py", line 262, in filter_variable
    other2 = self.convert_variable(other)
  File "/home/ricardo/Documents/Projects/aesara/aesara/tensor/type.py", line 328, in convert_variable
    if (self.ndim == var.type.ndim) and (self.dtype == var.type.dtype):
AttributeError: 'DisconnectedType' object has no attribute 'ndim'
```
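A possible workaround for this first failure (my assumption, not verified on the affected versions): connect every output to the cost, so that no output gradient arriving at `OpFromGraph.L_op` is a `DisconnectedType`. Continuing the snippet above, a zero coefficient leaves the gradient value unchanged:

```python
# Hypothetical workaround (assumption, not verified): the zero-weighted term
# makes outs[1] formally connected to the cost, so its output gradient is an
# ordinary zero tensor rather than a DisconnectedType.
cost = outs[0].sum() + 0.0 * outs[1].sum()
at.grad(cost, x)
```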
## Single output involving a discrete Elemwise input
```python
from aesara.compile.builders import OpFromGraph
import aesara.tensor as at

x = at.scalar("x")
y = at.lscalar("y")
out1 = x + at.switch(at.eq(y, 0), -1, 1)
at.grad(out1, x)  # Fine

op = OpFromGraph([x, y], [out1])
out2 = op(x, y)
at.grad(out2, x)  # Fails
```
```
Traceback (most recent call last):
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-b1c4038d13ee>", line 11, in <module>
    at.grad(out2, x)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 521, in grad
    var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, _wrt, consider_constant)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 968, in _populate_var_to_app_to_idx
    account_for(output)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 939, in account_for
    connection_pattern = _node_to_pattern(app)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 817, in _node_to_pattern
    connection_pattern = node.op.connection_pattern(node)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/compile/builders.py", line 851, in connection_pattern
    lop_op = self.get_lop_op()
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/compile/builders.py", line 700, in get_lop_op
    self._recompute_lop_op()
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
    return f(*args, **kwargs)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/compile/builders.py", line 495, in _recompute_lop_op
    gdefaults_l = fn_grad(wrt=local_inputs)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 623, in grad
    _rval: Sequence[Variable] = _populate_grad_dict(
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 1434, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 1434, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/gradient.py", line 1213, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/tensor/elemwise.py", line 548, in L_op
    rval = self._bgrad(inputs, outs, ograds)
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/tensor/elemwise.py", line 648, in _bgrad
    ret.append(transform(scalar_igrad))
  File "/home/ricardo/miniconda3/envs/pymcx/lib/python3.10/site-packages/aesara/tensor/elemwise.py", line 621, in transform
    if isinstance(r.type, (NullType, DisconnectedType)):
AttributeError: 'float' object has no attribute 'type'
```
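The traceback points at the Elemwise gradient itself rather than at OpFromGraph: the inner call to `grad` has to backpropagate through the `switch` node, whose scalar gradient returns a bare Python float for discrete outputs. If that reading is right, the same error should be reproducible without OpFromGraph by differentiating with respect to the discrete input. Continuing the snippet above (a presumed trigger, not re-run on the affected versions):

```python
# Presumed minimal trigger without OpFromGraph: differentiating w.r.t. the
# discrete input y forces the gradient traversal through the switch node.
at.grad(out1, y, disconnected_inputs="ignore")
# AttributeError: 'float' object has no attribute 'type'
```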
Other OpFromGraph-related problems: https://discourse.pymc.io/t/mixture-of-censored-iid-normals/13234/19?u=ricardov94
Apparently this was a known limitation: https://github.com/pymc-devs/pytensor/blob/97317a50f27c54db44891a60265362fedaf5700a/pytensor/compile/builders.py#L168-L169
The second issue was seen in OpFromGraph but was actually caused by a bug that has already been fixed: https://github.com/pymc-devs/pytensor/pull/331
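For cases like the second one, where the automatic derivation of the inner gradient fails, the documented `lop_overrides` parameter of `OpFromGraph` may serve as an escape hatch, since it replaces the inner call to `grad` entirely. A minimal sketch, assuming the callable form of `lop_overrides` (taking inputs, outputs, and output gradients) and not verified against the affected versions:

```python
from aesara.compile.builders import OpFromGraph
from aesara.gradient import DisconnectedType
import aesara.tensor as at

x = at.scalar("x")
y = at.lscalar("y")
out1 = x + at.switch(at.eq(y, 0), -1, 1)

def lop_ov(inputs, outputs, output_grads):
    # Hand-written L_op: d(out1)/dx = 1, and y is discrete, hence disconnected.
    (g,) = output_grads
    return [g, DisconnectedType()()]

op = OpFromGraph([x, y], [out1], lop_overrides=lop_ov)
at.grad(op(x, y), x)  # should never reach the buggy automatic inner gradient
```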