Gradient of MinimizeOp fails with certain parameter shapes
Describe the issue:
When attempting to backprop through the logp of a graph that contains a MinimizeOp, an error is raised because MinimizeOp's gradient tries to concatenate tensors of different ranks. I believe this can be remedied by flattening the output of atleast_2d in lines 333-337 of pytensor/tensor/optimize.py:
df_dtheta = concatenate(
    [
        atleast_2d(jac_col, left=False).flatten()
        for jac_col in cast(list[TensorVariable], df_dtheta_columns)
    ]
)
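A NumPy analogue of the failing join may help. This is a minimal sketch: the Jacobian shapes are hypothetical (a length-3 x with one vector and one matrix parameter), and np.concatenate stands in for PyTensor's concatenate/Join. Mismatched ranks make the join fail. Flattening every column lets the join succeed, but the result is 1-D, so the leading x axis is lost. That axis is presumably what the later linear algebra against df_dx relies on.

```python
import numpy as np

n_x = 3  # size of the optimized vector x (hypothetical)

# Jacobians of the length-n_x inner gradient w.r.t. each parameter:
# a vector parameter gives ndim 2, a matrix parameter gives ndim 3.
jac_vector_param = np.ones((n_x, 4))
jac_matrix_param = np.ones((n_x, 2, 5))

# Joining along the last axis fails because the ranks differ,
# mirroring the TypeError raised by PyTensor's Join.
try:
    np.concatenate([jac_vector_param, jac_matrix_param], axis=-1)
    join_failed = False
except ValueError:
    join_failed = True

# Flattening each column makes the join succeed,
# but the result is a single 1-D vector with no separate x axis.
flat = np.concatenate([j.flatten() for j in (jac_vector_param, jac_matrix_param)])

print(join_failed, flat.shape)
```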
Reproducible code example:
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytensor.gradient as tg
import pytensor
from pytensor.tensor.optimize import minimize
rng = np.random.default_rng(12345)
n = 10000
d = 10
mu = np.ones(d)
cov = np.diag(np.ones(d))
# Make a simple Gaussian whose mean x has a Gaussian prior
with pm.Model() as model:
    x = pm.MvNormal("x", mu=mu, cov=cov)
    y_obs = rng.multivariate_normal(mean=mu, cov=cov, size=n)
    y = pm.MvNormal(
        "y",
        mu=x,
        cov=cov,
        observed=y_obs,
    )
logp = model.logp()
# Find the mean that maximizes the logp (i.e. minimizes -logp)
x0, _ = minimize(
    objective=-logp,
    x=model.rvs_to_values[x],
    method="BFGS",
    optimizer_kwargs={"tol": 1e-8},
)
y = pytensor.graph.replace.graph_replace(y, {x: x0})
# tg.grad throws the error
for var in model.value_vars:
    logp = pt.sum(pm.logp(y, var))
    tg.grad(logp, var)
Error message:
Cell In[5], line 41
39 var = model.rvs_to_values[x]
40 logp = pt.sum(pm.logp(y, var))
---> 41 tg.grad(logp, var)
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:747, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
744 if hasattr(g.type, "dtype"):
745 assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 747 _rval: Sequence[Variable] = _populate_grad_dict(
748 var_to_app_to_idx, grad_dict, _wrt, cost_name
749 )
751 rval: MutableSequence[Variable | None] = list(_rval)
753 for i in range(len(_rval)):
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1541, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
1538 # end if cache miss
1539 return grad_dict[var]
-> 1541 rval = [access_grad_cache(elem) for elem in wrt]
1543 return rval
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1496, in _populate_grad_dict.<locals>.access_grad_cache(var)
1494 for node in node_to_idx:
1495 for idx in node_to_idx[node]:
-> 1496 term = access_term_cache(node)[idx]
1498 if not isinstance(term, Variable):
1499 raise TypeError(
1500 f"{node.op}.grad returned {type(term)}, expected"
1501 " Variable instance."
1502 )
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1326, in _populate_grad_dict.<locals>.access_term_cache(node)
1318 if o_shape != g_shape:
1319 raise ValueError(
1320 "Got a gradient of shape "
1321 + str(o_shape)
1322 + " on an output of shape "
1323 + str(g_shape)
1324 )
-> 1326 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
1328 if input_grads is None:
1329 raise TypeError(
1330 f"{node.op}.grad returned NoneType, expected iterable."
1331 )
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/optimize.py:565, in MinimizeOp.L_op(self, inputs, outputs, output_grads)
560 implicit_f = grad(inner_fx, inner_x)
562 df_dx, *df_dtheta_columns = jacobian(
563 implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
564 )
--> 565 grad_wrt_args = implict_optimization_grads(
566 df_dx=df_dx,
567 df_dtheta_columns=df_dtheta_columns,
568 args=args,
569 x_star=x_star,
570 output_grad=output_grad,
571 fgraph=self.fgraph,
572 )
574 return [zeros_like(x), *grad_wrt_args]
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/optimize.py:333, in implict_optimization_grads(df_dx, df_dtheta_columns, args, x_star, output_grad, fgraph)
290 r"""
291 Compute gradients of an optimization problem with respect to its parameters.
292
(...) 329 The function graph that contains the inputs and outputs of the optimization problem.
330 """
331 df_dx = cast(TensorVariable, df_dx)
--> 333 df_dtheta = concatenate(
334 [
335 atleast_2d(jac_col, left=False)
336 for jac_col in cast(list[TensorVariable], df_dtheta_columns)
337 ],
338 axis=-1,
339 )
341 replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
343 df_dx_star, df_dtheta_star = cast(
344 list[TensorVariable],
345 graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
346 )
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2980, in concatenate(tensor_list, axis)
2973 if not isinstance(tensor_list, tuple | list):
2974 raise TypeError(
2975 "The 'tensors' argument must be either a tuple "
2976 "or a list, make sure you did not forget () or [] around "
2977 "arguments of concatenate.",
2978 tensor_list,
2979 )
-> 2980 return join(axis, *tensor_list)
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2794, in join(axis, *tensors_list)
2792 return tensors_list[0]
2793 else:
-> 2794 return _join(axis, *tensors_list)
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
249 def __call__(
250 self, *inputs: Any, name=None, return_list=False, **kwargs
251 ) -> Variable | list[Variable]:
252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
253
254 This method is just a wrapper around :meth:`Op.make_node`.
(...) 291
292 """
--> 293 node = self.make_node(*inputs, **kwargs)
294 if name is not None:
295 if len(node.outputs) == 1:
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2487, in Join.make_node(self, axis, *tensors)
2484 ndim = tensors[0].type.ndim
2486 if not builtins.all(x.ndim == ndim for x in tensors):
-> 2487 raise TypeError(
2488 "Only tensors with the same number of dimensions can be joined. "
2489 f"Input ndims were: {[x.ndim for x in tensors]}"
2490 )
2492 try:
2493 static_axis = int(get_scalar_constant_value(axis))
TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 2, 3, 4, 2, 2, 2, 3, 3, 4]
PyTensor version information:
Numpy: 1.26.4
PyMC: 0+untagged.10301.gdc7cfee.dirty
PyTensor: 2.31.7
Context for the issue:
The bug originally surfaced when trying to run pm.sample on a model whose logp contains a MinimizeOp.
The whole setup is a bit unorthodox. What is this trying to achieve?
# tg.grad throws the error
for var in model.value_vars:
    var = model.rvs_to_values[x]
    logp = pt.sum(pm.logp(y, var))
    tg.grad(logp, var)
You immediately overwrite var, so this is the logp of y with respect to x, and y_obs?
Oh whoops, sorry, I didn't notice I still had that there (it was a remnant from something else I was testing). In this simple example to replicate the problem, var = [x], so it doesn't make a difference. I've updated the original post now.
As for that whole codeblock, it's essentially trying to replicate the structure of the original problem with as little clutter as possible. It's a reference to the following in mcmc.py: https://github.com/pymc-devs/pymc/blob/dc7cfeeaba98330f2881d627ee988ac267f18df2/pymc/sampling/mcmc.py#L247-L255
The sum over pm.logp() of y is there to reduce the dimension of pm.logp() from 1 to 0 so minimize doesn't complain. The reason I didn't simply use model.logp() here is that I suspect it doesn't register the graph_replace that swaps in x0, since I can't find a MinimizeOp in the graph of model.logp(). In my actual INLA code, I define a logp rewrite for a MarginalRV, so in that case model.logp() works, but for this minimal example I had to concoct a weird workaround.
It would be great if you could convert this to a pure PyTensor bug. PyMC just gets in the way.
I'll give that a go. I suspect it has something to do with how pm.logp is calculated, however, since simply taking tg.grad of a pure MinimizeOp seems to work fine.
Ah, here's a pure-pytensor example:
import pytensor.tensor as pt
import pytensor.gradient as tg
import pytensor
from pytensor.tensor.optimize import minimize
d = 10
x = pt.vector('x', shape=(d,))
cov = pt.matrix('cov', shape=(d,d))
_, logdet = pt.nlinalg.slogdet(cov)
y = x.T @ cov @ x + logdet
x0, _ = minimize(
    objective=-y,
    x=x,
    method="BFGS",
    optimizer_kwargs={"tol": 1e-8},
)
y = pytensor.graph.replace.graph_replace(y, {x: x0})
tg.grad(y, x)
This throws the same error as above, just with a smaller set of tensors:
TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 3]
I think it can't concatenate vectors with matrices or something to that effect.
~~Seems to identify a bug, although you should take grad with respect to x0? x is not part of the graph of y after the replace anymore?~~
~~In that case the error disappears, but still a bug was detected there~~
Nvm, x0 still depends on x; that's like the initial value?
Yep
Note that in the real case, the gradient would be with respect to the hyperparameters, not the latent field x; here I've just collapsed the two for simplicity.
The problem seems to be that MinimizeOp expects the Jacobian to be at most 2D, but that's no longer the case after https://github.com/pymc-devs/pytensor/pull/1228
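The rank of each Jacobian follows from its shape. For a length-n vector function (here the inner gradient of the objective), the Jacobian with respect to a parameter of shape S has shape (n, *S), so its ndim is 1 plus the parameter's ndim. The finite-difference sketch below makes this concrete. The helper name numerical_jacobian and the toy inner gradients are hypothetical and are not part of PyTensor.

```python
import numpy as np

def numerical_jacobian(f, theta, eps=1e-6):
    """Central-difference Jacobian of f with respect to an array theta.

    The result has shape f(theta).shape + theta.shape, the same layout a
    tensor Jacobian of a vector-valued function uses.
    """
    theta = np.asarray(theta, dtype=float)
    out_shape = np.shape(f(theta))
    jac = np.zeros(out_shape + theta.shape)
    for idx in np.ndindex(theta.shape):
        step = np.zeros_like(theta)
        step[idx] = eps
        jac[(...,) + idx] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return jac

x = np.array([0.0, 1.0, 2.0])  # fixed point at which the inner gradient is evaluated

# Toy inner gradients, each depending on a parameter of a different rank
jac_scalar = numerical_jacobian(lambda s: s * x, 2.0)          # scalar parameter
jac_vector = numerical_jacobian(lambda b: x - b, np.zeros(3))  # vector parameter
jac_matrix = numerical_jacobian(lambda m: m @ x, np.eye(3))    # matrix parameter

print([j.ndim for j in (jac_scalar, jac_vector, jac_matrix)])
```

Only the scalar and vector cases fit in a 2-D matrix after atleast_2d. The matrix case is already 3-D, which matches the ndims of 3 and 4 in the error message.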
Before that, I suspect it would have errored out on the call to jacobian; now it results in this more surprising bug when it tries to concatenate everything.
I tried simply flattening the inputs to concatenate (as stated in the OP), but that leads to a later bug: the tensor shapes don't quite match up when it does an inner product, so they need to be reordered somehow.
~~Doesn't avoid it, but isn't it required that the minimization function be scalar?~~ It is, I was looking at the internal grad
Here is a smaller reproducible example:
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize
x = pt.vector('x')
scalar = pt.vector("scalar")
tensor3 = pt.tensor3('tensor3')
x0, _ = minimize(
    objective=((x * tensor3) + (x * scalar)).sum(),
    x=x,
)
pt.grad(x0.sum(), tensor3)
It tries to stack all the parameter Jacobians into a single df_dtheta matrix to solve, but that doesn't work here because they have different ndim.
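Here is what df_dtheta is used for. At the optimum x*, the gradient f of the objective vanishes, so by the implicit function theorem dx*/dθ = -(∂f/∂x)^{-1} ∂f/∂θ. Here f plays the role of implicit_f in MinimizeOp.L_op, and ∂f/∂θ is every parameter's Jacobian stacked side by side. That stacking is the step that needs all columns to be 2-D. The check below uses a toy quadratic objective that is made up purely for illustration. It compares the implicit gradient against finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
M = rng.normal(size=(n, n))
A = M @ M.T + n * np.eye(n)  # symmetric positive definite Hessian
b = rng.normal(size=n)
s = 1.5

def x_star(b, s):
    # Minimizer of 0.5 x^T A x - s * b^T x, i.e. the root of A x - s b = 0
    return np.linalg.solve(A, s * b)

# Jacobians of the stationarity condition f(x, theta) = A x - s b
df_dx = A                        # (n, n)
df_db = -s * np.eye(n)           # (n, n): vector parameter
df_ds = (-b)[:, None]            # (n, 1): scalar parameter, made 2-D
df_dtheta = np.concatenate([df_db, df_ds], axis=-1)  # (n, n + 1)

# Implicit function theorem: dx*/dtheta = -(df/dx)^{-1} df/dtheta
dx_dtheta = -np.linalg.solve(df_dx, df_dtheta)

# Check against central finite differences of the closed-form minimizer
eps = 1e-6
fd = np.zeros((n, n + 1))
for j in range(n):
    e = np.zeros(n)
    e[j] = eps
    fd[:, j] = (x_star(b + e, s) - x_star(b - e, s)) / (2 * eps)
fd[:, n] = (x_star(b, s + eps) - x_star(b, s - eps)) / (2 * eps)

print(np.allclose(dx_dtheta, fd, atol=1e-6))
```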
Also I guess it should eagerly ignore the parameters that are disconnected, for performance?
@jessegrabowski shouldn't we solve one parameter at a time? That would automatically prune branches from disconnected inputs as well
Either way, I guess we want the Jacobian with respect to the raveled parameters (or equivalently do jacobian()...reshape(x.shape[0], -1)). That should be safe to concatenate into a large matrix.
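The ravel-then-concatenate idea can be sketched in NumPy, assuming the Jacobian of the length-n_x inner gradient w.r.t. a parameter of shape S has shape (n_x, *S). The parameter shapes and the stand-in Hessian below are hypothetical. Each Jacobian is reshaped to (n_x, -1), all are joined into one 2-D matrix and solved against df_dx in a single call, and the solution columns are split back into per-parameter shapes.

```python
import numpy as np

rng = np.random.default_rng(1)
n_x = 4
M = rng.normal(size=(n_x, n_x))
df_dx = M @ M.T + n_x * np.eye(n_x)  # stand-in Hessian of the objective w.r.t. x

# Hypothetical parameter shapes: a scalar, a vector and a matrix
param_shapes = [(), (3,), (2, 3)]
# Jacobian of the inner gradient w.r.t. each parameter: shape (n_x, *shape)
jacs = [rng.normal(size=(n_x, *shape)) for shape in param_shapes]

# Ravel every Jacobian to (n_x, -1) so they all share ndim 2, then join
df_dtheta = np.concatenate([j.reshape(n_x, -1) for j in jacs], axis=-1)

# One joint solve for dx*/dtheta over all raveled parameters
dx_dtheta = -np.linalg.solve(df_dx, df_dtheta)

# Unravel: split the columns by parameter size and restore each shape
sizes = [int(np.prod(shape)) for shape in param_shapes]
splits = np.cumsum(sizes)[:-1]
blocks = [
    block.reshape(n_x, *shape)
    for block, shape in zip(np.split(dx_dtheta, splits, axis=1), param_shapes)
]

print(df_dtheta.shape, [blk.shape for blk in blocks])
```

Each restored block still satisfies df_dx · block = -jac for its own parameter, so the reshape only changes the layout and does not change the linear system being solved.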
Yes, we need logic for raveling and unraveling the inputs when making the jacobian.
When you say solve one at a time, you mean here? I don't think that works, it would be assuming that cross-terms in df_dtheta_star are zero, which won't be true in general.
the ravel-jac-unravel pattern happens pretty often, it would be nice if we had some helpers for it (an Op like pymc.blocking.RaveledVars)
> Yes, we need logic for raveling and unraveling the inputs when making the jacobian.
> When you say solve one at a time, you mean here? I don't think that works, it would be assuming that cross-terms in df_dtheta_star are zero, which won't be true in general.
You're right. Is it true for disconnected gradients though?
> the ravel-jac-unravel pattern happens pretty often, it would be nice if we had some helpers for it (an Op like pymc.blocking.RaveledVars)
Something similar to einops pack/unpack: https://einops.rocks/api/pack_unpack/
Pack returns the object needed to then unpack into the right shapes (basically a tuple with the original shapes)
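A pack/unpack pair in that spirit can be sketched as follows. The helpers are hypothetical NumPy functions, only loosely modelled on the einops idea, and they do not reproduce the actual einops API or any existing PyTensor Op. pack keeps the leading axes, ravels the rest of each array, joins along the last axis, and returns the trailing shapes. unpack uses those shapes to undo the operation.

```python
import numpy as np

def pack(arrays, n_keep=1):
    """Ravel all but the first n_keep axes of each array and join them.

    Returns the packed array plus the trailing shapes needed to undo it
    (hypothetical helper, loosely modelled on einops.pack).
    """
    lead = arrays[0].shape[:n_keep]
    shapes = [a.shape[n_keep:] for a in arrays]
    packed = np.concatenate([a.reshape(*lead, -1) for a in arrays], axis=-1)
    return packed, shapes

def unpack(packed, shapes):
    """Split the last axis of packed and restore each trailing shape."""
    lead = packed.shape[:-1]
    sizes = [int(np.prod(s)) for s in shapes]
    pieces = np.split(packed, np.cumsum(sizes)[:-1], axis=-1)
    return [p.reshape(*lead, *s) for p, s in zip(pieces, shapes)]

# Jacobian-like arrays sharing a leading axis of length 5
arrays = [np.zeros(5), np.ones((5, 3)), np.full((5, 2, 2), 2.0)]
packed, shapes = pack(arrays)
restored = unpack(packed, shapes)
print(packed.shape, shapes)
```

Returning the shapes separately, as the comment above describes for einops, keeps pack stateless. The same shapes object can then unpack any array that shares the packed last-axis layout, such as the solution of the implicit-gradient solve.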
> You're right. Is it true for disconnected gradients though?
We should be filtering those out, yeah
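Filtering disconnected parameters before the join could look like the sketch below. It rests on one assumption: a disconnected parameter's Jacobian is represented as None. This is an illustration only and is not how PyTensor signals disconnected inputs. Such a parameter cannot affect x*, so its dx*/dθ is zero and it can skip the solve entirely. The function name implicit_jacobians is hypothetical.

```python
import numpy as np

def implicit_jacobians(df_dx, jacs, param_shapes):
    """dx*/dtheta per parameter, skipping disconnected parameters.

    Assumption for this sketch: a disconnected parameter's Jacobian is
    passed as None; its result is all zeros and it never enters the solve.
    """
    n_x = df_dx.shape[0]
    connected = [i for i, j in enumerate(jacs) if j is not None]
    results = [np.zeros((n_x, *shape)) for shape in param_shapes]
    if not connected:
        return results
    df_dtheta = np.concatenate(
        [jacs[i].reshape(n_x, -1) for i in connected], axis=-1
    )
    sol = -np.linalg.solve(df_dx, df_dtheta)
    sizes = [int(np.prod(param_shapes[i])) for i in connected]
    for i, block in zip(connected, np.split(sol, np.cumsum(sizes)[:-1], axis=1)):
        results[i] = block.reshape(n_x, *param_shapes[i])
    return results

df_dx = 2.0 * np.eye(3)
jacs = [np.ones((3, 2)), None, np.ones(3)]  # the second parameter is disconnected
param_shapes = [(2,), (4,), ()]
blocks = implicit_jacobians(df_dx, jacs, param_shapes)
print([blk.shape for blk in blocks])
```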
> Something similar to einops pack/unpack
Yes, I want exactly this.
The general pack/unpack may also address #1552