Gradient of MinimizeOp fails with certain parameter shapes

Open: Michal-Novomestsky opened this issue 4 months ago (19 comments)

Describe the issue:

When backpropagating through the logp of a graph that contains a MinimizeOp, an error is raised because the gradient code tries to concatenate tensors of different ranks. I believe this can be remedied by flattening the atleast_2d output in lines 333-337 of pytensor/tensor/optimize.py:

df_dtheta = concatenate(
    [
        atleast_2d(jac_col, left=False).flatten()
        for jac_col in cast(list[TensorVariable], df_dtheta_columns)
    ]
)
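To see why the ranks clash and what the proposed flatten does, here is a numpy-only sketch. It assumes each Jacobian column has shape (n_x, *param.shape) and uses a hypothetical stand-in, atleast_2d_right, for atleast_2d(..., left=False); numpy raises ValueError where PyTensor's Join raises TypeError, but the failure mode is analogous.

```python
import numpy as np

def atleast_2d_right(a):
    # Hypothetical stand-in for atleast_2d(..., left=False): add missing axes on the right
    a = np.asarray(a)
    while a.ndim < 2:
        a = a[..., np.newaxis]
    return a

n_x = 3  # length of the optimized vector x
# Hypothetical Jacobian columns, each of shape (n_x, *param.shape)
jac_cols = [
    np.ones((n_x, 4)),        # parameter is a length-4 vector
    np.ones((n_x, 2, 2, 2)),  # parameter is a 2x2x2 tensor
]
padded = [atleast_2d_right(j) for j in jac_cols]  # ndims stay [2, 4]

concat_failed = False
try:
    np.concatenate(padded, axis=-1)
except ValueError:
    concat_failed = True  # analogous to the TypeError from Join.make_node

# The proposed patch: flatten each column first, then join
flat = np.concatenate([p.flatten() for p in padded])
print(concat_failed, flat.shape)  # True (36,)
```

The join now succeeds, but the result is a 1-D vector of length 36 rather than a matrix with n_x rows, so the row structure that the later solve relies on is gone.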

Reproducible code example:

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytensor.gradient as tg
import pytensor
from pytensor.tensor.optimize import minimize

rng = np.random.default_rng(12345)
n = 10000
d = 10

mu = np.ones(d)
cov = np.diag(np.ones(d))

# A simple Gaussian likelihood whose mean x has a Gaussian prior
with pm.Model() as model:
    x = pm.MvNormal("x", mu=mu, cov=cov)

    y_obs = rng.multivariate_normal(mean=mu, cov=cov, size=n)

    y = pm.MvNormal(
        "y",
        mu=x,
        cov=cov,
        observed=y_obs,
    )

    logp = model.logp()

    # Find the mean that maximizes the logp (i.e. minimizes -logp)
    x0, _ = minimize(
        objective=-logp,
        x=model.rvs_to_values[x],
        method="BFGS",
        optimizer_kwargs={"tol": 1e-8},
    )

    y = pytensor.graph.replace.graph_replace(y, {x: x0})
    
    # tg.grad throws the error
    for var in model.value_vars:
        logp = pt.sum(pm.logp(y, var))
        tg.grad(logp, var)

Error message:

Cell In[5], line 41
     39 var = model.rvs_to_values[x]
     40 logp = pt.sum(pm.logp(y, var))
---> 41 tg.grad(logp, var)

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:747, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    744     if hasattr(g.type, "dtype"):
    745         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 747 _rval: Sequence[Variable] = _populate_grad_dict(
    748     var_to_app_to_idx, grad_dict, _wrt, cost_name
    749 )
    751 rval: MutableSequence[Variable | None] = list(_rval)
    753 for i in range(len(_rval)):

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1541, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1538     # end if cache miss
   1539     return grad_dict[var]
-> 1541 rval = [access_grad_cache(elem) for elem in wrt]
   1543 return rval

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1496, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1494 for node in node_to_idx:
   1495     for idx in node_to_idx[node]:
-> 1496         term = access_term_cache(node)[idx]
   1498         if not isinstance(term, Variable):
   1499             raise TypeError(
   1500                 f"{node.op}.grad returned {type(term)}, expected"
   1501                 " Variable instance."
   1502             )

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1326, in _populate_grad_dict.<locals>.access_term_cache(node)
   1318         if o_shape != g_shape:
   1319             raise ValueError(
   1320                 "Got a gradient of shape "
   1321                 + str(o_shape)
   1322                 + " on an output of shape "
   1323                 + str(g_shape)
   1324             )
-> 1326 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1328 if input_grads is None:
   1329     raise TypeError(
   1330         f"{node.op}.grad returned NoneType, expected iterable."
   1331     )

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/optimize.py:565, in MinimizeOp.L_op(self, inputs, outputs, output_grads)
    560 implicit_f = grad(inner_fx, inner_x)
    562 df_dx, *df_dtheta_columns = jacobian(
    563     implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
    564 )
--> 565 grad_wrt_args = implict_optimization_grads(
    566     df_dx=df_dx,
    567     df_dtheta_columns=df_dtheta_columns,
    568     args=args,
    569     x_star=x_star,
    570     output_grad=output_grad,
    571     fgraph=self.fgraph,
    572 )
    574 return [zeros_like(x), *grad_wrt_args]

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/optimize.py:333, in implict_optimization_grads(df_dx, df_dtheta_columns, args, x_star, output_grad, fgraph)
    290 r"""
    291 Compute gradients of an optimization problem with respect to its parameters.
    292 
   (...)    329     The function graph that contains the inputs and outputs of the optimization problem.
    330 """
    331 df_dx = cast(TensorVariable, df_dx)
--> 333 df_dtheta = concatenate(
    334     [
    335         atleast_2d(jac_col, left=False)
    336         for jac_col in cast(list[TensorVariable], df_dtheta_columns)
    337     ],
    338     axis=-1,
    339 )
    341 replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
    343 df_dx_star, df_dtheta_star = cast(
    344     list[TensorVariable],
    345     graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
    346 )

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2980, in concatenate(tensor_list, axis)
   2973 if not isinstance(tensor_list, tuple | list):
   2974     raise TypeError(
   2975         "The 'tensors' argument must be either a tuple "
   2976         "or a list, make sure you did not forget () or [] around "
   2977         "arguments of concatenate.",
   2978         tensor_list,
   2979     )
-> 2980 return join(axis, *tensor_list)

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2794, in join(axis, *tensors_list)
   2792     return tensors_list[0]
   2793 else:
-> 2794     return _join(axis, *tensors_list)

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
    249 def __call__(
    250     self, *inputs: Any, name=None, return_list=False, **kwargs
    251 ) -> Variable | list[Variable]:
    252     r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254     This method is just a wrapper around :meth:`Op.make_node`.
   (...)    291 
    292     """
--> 293     node = self.make_node(*inputs, **kwargs)
    294     if name is not None:
    295         if len(node.outputs) == 1:

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2487, in Join.make_node(self, axis, *tensors)
   2484 ndim = tensors[0].type.ndim
   2486 if not builtins.all(x.ndim == ndim for x in tensors):
-> 2487     raise TypeError(
   2488         "Only tensors with the same number of dimensions can be joined. "
   2489         f"Input ndims were: {[x.ndim for x in tensors]}"
   2490     )
   2492 try:
   2493     static_axis = int(get_scalar_constant_value(axis))

TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 2, 3, 4, 2, 2, 2, 3, 3, 4]

PyTensor version information:

Numpy: 1.26.4
PyMC: 0+untagged.10301.gdc7cfee.dirty
PyTensor: 2.31.7

Context for the issue:

The bug was originally encountered when running pm.sample on a model whose logp contains a MinimizeOp.

Michal-Novomestsky, Jul 25 '25 07:07

The whole setup is a bit unorthodox. What is this trying to achieve?

    # tg.grad throws the error
    for var in model.value_vars:
        var = model.rvs_to_values[x]
        logp = pt.sum(pm.logp(y, var))
        tg.grad(logp, var)

You immediately overwrite var, so this is y's logp with respect to x, and y_obs?

ricardoV94, Jul 25 '25 10:07

Oh whoops, sorry, I didn't notice I still had that there (it was a remnant from something else I was testing). In this simple example, model.value_vars contains only the value variable of x, so it doesn't make a difference. I've updated the original post now.

As for that whole codeblock, it's essentially trying to replicate the structure of the original problem with as little clutter as possible. It's a reference to the following in mcmc.py: https://github.com/pymc-devs/pymc/blob/dc7cfeeaba98330f2881d627ee988ac267f18df2/pymc/sampling/mcmc.py#L247-L255

The sum over pm.logp() of y was there to reduce the ndim of pm.logp() from 1 to 0 so that minimize doesn't complain. The reason I didn't simply use model.logp() here is that I suspect it doesn't pick up the graph_replace that substitutes x0, since I can't find a MinimizeOp in the graph of model.logp(). In my actual INLA code, I define a logp rewrite for a MarginalRV, so model.logp() works there, but for this minimal example I had to concoct a weird workaround.

Michal-Novomestsky, Jul 25 '25 10:07

It would be great if you could convert this to a pure PyTensor bug. PyMC just gets in the way.

ricardoV94, Jul 25 '25 10:07

I'll give that a go. I suspect it has something to do with how pm.logp is calculated, however, since simply taking tg.grad of a pure MinimizeOp seems to work fine.

Michal-Novomestsky, Jul 25 '25 10:07

Ah, here's a pure-PyTensor example:

import pytensor.tensor as pt
import pytensor.gradient as tg
import pytensor
from pytensor.tensor.optimize import minimize

d = 10
x = pt.vector('x', shape=(d,))
cov = pt.matrix('cov', shape=(d,d))

_, logdet = pt.nlinalg.slogdet(cov)
y = x.T @ cov @ x + logdet

x0, _ = minimize(
    objective=-y,
    x=x,
    method="BFGS",
    optimizer_kwargs={"tol": 1e-8},
)

y = pytensor.graph.replace.graph_replace(y, {x: x0})
tg.grad(y, x)

This throws the same error as above, just with a smaller set of tensors:

TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 3]

I think it can't concatenate the Jacobians because they have different ranks (here ndim 2 and 3), or something to that effect.

Michal-Novomestsky, Jul 25 '25 10:07

~~Seems to identify a bug, although you should take grad with respect to x0? x is not part of the graph of y after the replace anymore?~~

~~In that case the error disappears, but still a bug was detected there~~

Nvm, x0 still depends on x; x is like the initial value?

ricardoV94, Jul 25 '25 10:07

Yep

Michal-Novomestsky, Jul 25 '25 10:07

Note that in the real case, the gradient would be taken with respect to the hyperparameters, not the latent field x; here I've just collapsed it for simplicity.

Michal-Novomestsky, Jul 25 '25 10:07

The problem seems to be that MinimizeOp expects each Jacobian to be at most 2D, but that's no longer the case after: https://github.com/pymc-devs/pytensor/pull/1228

Before that, I suspect it would have errored out on the call to jacobian; now it results in this more surprising error when it tries to concatenate everything.
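The "full-shape" Jacobian layout can be illustrated with a numpy finite-difference sketch: the Jacobian of a length-d vector with respect to a d x d matrix has shape (d, d, d). It assumes, per the comment above, that jacobian() now returns this layout; fd_jacobian is a hypothetical helper, not a PyTensor function. The implicit_f below is the gradient with respect to x of the objective from the pure-PyTensor example.

```python
import numpy as np

def fd_jacobian(f, theta, eps=1e-6):
    # Hypothetical helper: central finite differences with output shape
    # f(theta).shape + theta.shape (the "full-shape" layout)
    theta = np.asarray(theta, dtype=float)
    out = np.asarray(f(theta))
    jac = np.zeros(out.shape + theta.shape)
    for idx in np.ndindex(*theta.shape):
        step = np.zeros_like(theta)
        step[idx] = eps
        jac[(Ellipsis,) + idx] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return jac

d = 3
x = np.array([1.0, 2.0, 3.0])

def implicit_f(cov):
    # Gradient w.r.t. x of -(x^T cov x + logdet(cov)); logdet(cov) does not depend on x
    return -(cov + cov.T) @ x

J = fd_jacobian(implicit_f, np.eye(d))
print(J.shape)  # (3, 3, 3)
```

A column like this is still 3-D after atleast_2d, so it cannot be joined with the 2-D columns that MinimizeOp expects.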

ricardoV94, Jul 25 '25 11:07

I tried simply flattening inside the concatenate (as suggested in the OP), but that leads to a later error: the tensor shapes don't quite match up when it does an inner product, so they need to be reordered somehow.
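A numpy sketch of why flattening alone isn't enough: the flattened vector's length is the total size of all Jacobian columns, which has nothing to do with n_x, so a linear solve against df_dx cannot line up. This is illustrative only; it uses np.linalg.solve and is not the exact operation where PyTensor fails.

```python
import numpy as np

n_x = 3
df_dx = 2.0 * np.eye(n_x)  # stand-in for df/dx at the optimum, shape (n_x, n_x)
jac_cols = [np.ones((n_x, 4)), np.ones((n_x, 2, 2, 2))]

# Flattened as in the OP's patch: one 1-D vector of length 12 + 24 = 36
flat = np.concatenate([c.flatten() for c in jac_cols])

solve_failed = False
try:
    np.linalg.solve(df_dx, -flat)  # the right-hand side needs n_x rows
except ValueError:
    solve_failed = True
print(flat.shape, solve_failed)  # (36,) True
```

Keeping the n_x axis and raveling only the parameter axes avoids this mismatch.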

Michal-Novomestsky, Jul 25 '25 11:07

~~Doesn't avoid it, but isn't it required that the minimization function be scalar?~~ It is, I was looking at the internal grad

ricardoV94, Jul 25 '25 11:07

Here is a smaller reproducible example:

import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector('x')
scalar = pt.vector("scalar")
tensor3 = pt.tensor3('tensor3')

x0, _ = minimize(
    objective=((x * tensor3) + (x * scalar)).sum(),
    x=x,
)

pt.grad(x0.sum(), tensor3)

It tries to concatenate all the parameter Jacobians into a single df_dtheta matrix for the solve, but that doesn't work here because they have different ndims.

Also, I guess it should eagerly ignore disconnected parameters, for performance?
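For reference, here is a worked numpy example of the implicit gradient that the stacked df_dtheta feeds into. It assumes MinimizeOp applies the standard implicit-function-theorem formula dx*/dtheta = -(df/dx)^{-1} df/dtheta with f = grad_x(objective), as the traceback's implicit_f = grad(inner_fx, inner_x) suggests. The quadratic objective and its parameters b and s are illustrative choices with closed-form answers to check against.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
M = rng.normal(size=(n, n))
A = M @ M.T + n * np.eye(n)  # symmetric positive definite
b = rng.normal(size=n)       # vector parameter
s = 2.0                      # scalar parameter

# Objective 0.5 * s * x^T A x - b^T x has the closed-form minimizer x* = (s A)^{-1} b
x_star = np.linalg.solve(s * A, b)

# f(x, theta) = grad_x objective = s * A @ x - b, which vanishes at x_star
df_dx = s * A                          # (n, n)
df_db = -np.eye(n)                     # (n, n)
df_ds = (A @ x_star)[:, np.newaxis]    # (n, 1): scalar parameter padded on the right
df_dtheta = np.concatenate([df_db, df_ds], axis=-1)  # (n, n + 1), every block is 2-D

# Implicit function theorem: dx*/dtheta = -(df/dx)^{-1} @ df/dtheta
dx_dtheta = -np.linalg.solve(df_dx, df_dtheta)
print(dx_dtheta.shape)  # (3, 4)
```

Here every per-parameter block is already 2-D, so the join works; the bug appears once a block has a higher rank.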

ricardoV94, Jul 25 '25 11:07

@jessegrabowski shouldn't we solve one parameter at a time? That would automatically prune branches from disconnected inputs as well.

Either way, I guess we want the Jacobian with respect to the raveled parameters (or, equivalently, jacobian(...).reshape(x.shape[0], -1)). Those should be safe to concatenate into one large matrix.
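A numpy sketch of the raveled approach: reshape every Jacobian column to (n_x, -1), concatenate, solve once, then split and reshape back to each parameter's shape. It assumes x is a vector of length n_x and that columns use the (n_x, *param.shape) layout; implicit_grads_raveled is a hypothetical helper, not PyTensor API.

```python
import numpy as np

def implicit_grads_raveled(df_dx, df_dtheta_columns):
    # Hypothetical helper. Each column has shape (n_x, *param.shape).
    n_x = df_dx.shape[0]
    param_shapes = [col.shape[1:] for col in df_dtheta_columns]
    flat_cols = [col.reshape(n_x, -1) for col in df_dtheta_columns]  # all (n_x, size_i)
    df_dtheta = np.concatenate(flat_cols, axis=-1)                    # (n_x, total_size)
    dx_dtheta = -np.linalg.solve(df_dx, df_dtheta)                    # (n_x, total_size)
    # Unravel: split the columns per parameter and restore each parameter's shape
    split_points = np.cumsum([c.shape[1] for c in flat_cols])[:-1]
    blocks = np.split(dx_dtheta, split_points, axis=-1)
    return [blk.reshape((n_x, *shape)) for blk, shape in zip(blocks, param_shapes)]

n_x = 3
df_dx = 2.0 * np.eye(n_x)
columns = [
    np.ones((n_x,)),          # scalar parameter
    np.ones((n_x, 4)),        # vector parameter
    np.ones((n_x, 2, 2, 2)),  # tensor3 parameter
]
grads = implicit_grads_raveled(df_dx, columns)
print([g.shape for g in grads])  # [(3,), (3, 4), (3, 2, 2, 2)]
```

Because every raveled column keeps the n_x rows, the concatenation is always 2-D regardless of the parameters' ranks.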

ricardoV94, Jul 25 '25 11:07

Yes, we need logic for raveling and unraveling the inputs when making the jacobian.

When you say solve one at a time, do you mean here? I don't think that works: it would assume that the cross-terms in df_dtheta_star are zero, which won't be true in general.

jessegrabowski, Jul 25 '25 11:07

The ravel-jac-unravel pattern happens pretty often; it would be nice to have some helpers for it (an Op like pymc.blocking.RaveledVars).

jessegrabowski, Jul 25 '25 11:07

> Yes, we need logic for raveling and unraveling the inputs when making the jacobian.
>
> When you say solve one at a time, you mean here? I don't think that works, it would be assuming that cross-terms in df_dtheta_star are zero, which won't be true in general.

You're right. Is it true for disconnected gradients though?

ricardoV94, Jul 25 '25 11:07

> the ravel-jac-unravel pattern happens pretty often, it would be nice if we had some helpers for it (an Op like pymc.blocking.RaveledVars)

Something similar to einops pack/unpack: https://einops.rocks/api/pack_unpack/

pack returns the information needed to later unpack into the right shapes (basically a tuple of the original shapes).
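A minimal numpy sketch of such a helper pair, loosely modeled on einops pack/unpack but simplified to fully ravel each array. The names pack_flat and unpack_flat are hypothetical; a PyTensor version would have to work with symbolic shapes, which this sketch does not attempt.

```python
import numpy as np
from math import prod

def pack_flat(arrays):
    # Hypothetical helper: ravel and concatenate, returning the original shapes
    # so the flat vector can be unpacked later
    shapes = [np.shape(a) for a in arrays]
    flat = np.concatenate([np.ravel(a) for a in arrays])
    return flat, shapes

def unpack_flat(flat, shapes):
    # Inverse of pack_flat: slice the flat vector and restore each original shape
    out, start = [], 0
    for shape in shapes:
        size = prod(shape)  # prod(()) == 1, so scalars take one slot
        out.append(flat[start:start + size].reshape(shape))
        start += size
    return out

params = [np.array(2.0), np.arange(4.0), np.arange(8.0).reshape(2, 2, 2)]
flat, shapes = pack_flat(params)
restored = unpack_flat(flat, shapes)
print(flat.shape, shapes)  # (13,) [(), (4,), (2, 2, 2)]
```

The returned shapes list plays the role described above: it is all that unpack needs to recover the original arrays.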

ricardoV94, Jul 25 '25 11:07

> You're right. Is it true for disconnected gradients though?

We should be filtering those out, yeah

> Something similar to einops pack/unpack

Yes, I want exactly this.

jessegrabowski, Jul 25 '25 11:07

The general pack/unpack may also address #1552

ricardoV94, Jul 25 '25 11:07