
Add rewrite to merge multiple `SVD` `Op`s with different settings

Open · jessegrabowski opened this issue 1 year ago

Description

SVD comes with several keyword arguments, the most important of which is `compute_uv`. If False, it returns only the singular values of the matrix. This is nice if you want to save on computation, but it can actually be inefficient when the user wants gradients: in reverse mode we need to compute the U and V matrices anyway, and indeed the `L_op` for SVD (added in #614) adds a second SVD `Op` to the graph with `compute_uv=True`.

When two SVD `Op`s with the same inputs appear in a graph, differing only in `compute_uv`, we should change `compute_uv=False` to True everywhere. This will allow PyTensor to see that the outputs are equivalent and reuse them, rather than computing the decomposition multiple times.
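A quick demonstration (using NumPy rather than PyTensor, since the graph machinery isn't needed to make the point) that the singular values returned with `compute_uv=False` are identical to those from the full decomposition, which is why reusing the full-SVD output is safe:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))

# Singular values only, vs. the full decomposition
s_only = np.linalg.svd(A, compute_uv=False)
U, s_full, Vt = np.linalg.svd(A, compute_uv=True)

# The singular values agree, so a graph containing both variants
# only ever needs to run the compute_uv=True decomposition once.
assert np.allclose(s_only, s_full)
```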

jessegrabowski avatar Apr 28 '24 15:04 jessegrabowski

The rewrite can do the merge immediately; it's just a global rewrite rather than a local one.

Also, if an `Op` has `compute_uv=True` but the U and V arrays are not used in the graph, we can set it to False. That could be a local rewrite, but it's probably fine to handle both cases in the same global rewrite.
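The "drop unused U/V" idea can be sketched on a toy stand-in for a graph node (this is not the real PyTensor API; `ToySVD`, its fields, and `local_svd_dead_uv` are hypothetical names for illustration):

```python
from dataclasses import dataclass

@dataclass
class ToySVD:
    # Toy stand-in for an SVD node; not the real PyTensor Apply/Op types.
    compute_uv: bool
    # Whether each output (U, s, V) has downstream users in the graph
    u_used: bool
    s_used: bool
    v_used: bool

def local_svd_dead_uv(node):
    """Return a cheaper replacement node, or None if no rewrite applies."""
    if node.compute_uv and not (node.u_used or node.v_used):
        # Nothing consumes U or V, so the singular-values-only
        # variant produces everything the graph actually needs.
        return ToySVD(compute_uv=False, u_used=False,
                      s_used=node.s_used, v_used=False)
    return None
```

In the real rewrite, "has downstream users" would be checked via the `FunctionGraph` clients of each output, rather than stored flags.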

ricardoV94 avatar Apr 28 '24 16:04 ricardoV94

Hi, I want to work on this. As I understand it, I will need to add a class `SVDSimplify(GraphRewriter)` to the file `pytensor/tensor/rewriting/linalg.py` that checks for SVD `Op`s in the PyTensor graph and changes the keyword argument `compute_uv` of all of them to True if

  1. there is more than one SVD `Op`, and
  2. at least one of them has `compute_uv=True`.
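The two conditions above can be sketched as a toy global pass (again not the real PyTensor API; `ToySVD`, `input_id`, and `merge_svd_compute_uv` are hypothetical names for illustration):

```python
from collections import defaultdict
from dataclasses import dataclass

@dataclass
class ToySVD:
    # Toy stand-in for an SVD node; not the real PyTensor types.
    input_id: int      # stand-in for the node's input variable
    compute_uv: bool

def merge_svd_compute_uv(svd_nodes):
    """Global pass: group SVD nodes by input; if a group has more than
    one node and any of them has compute_uv=True, force it True on all,
    so the now-identical nodes can be deduplicated downstream."""
    groups = defaultdict(list)
    for node in svd_nodes:
        groups[node.input_id].append(node)
    for nodes in groups.values():
        if len(nodes) > 1 and any(n.compute_uv for n in nodes):
            for n in nodes:
                n.compute_uv = True
    return svd_nodes
```

After this pass, two SVD nodes on the same input are byte-for-byte identical, so PyTensor's ordinary merge optimization can collapse them into a single decomposition.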

From the documentation, I know roughly how to do it. I am still not 100% sure about all the decorators used (e.g., `@register_canonicalize`, `@register_stabilize`, `@register_specialize`, `@node_rewriter()`, etc.), but I will ask for your input in the draft PR.

HangenYuu avatar May 12 '24 12:05 HangenYuu

`local_det_chol` in `pytensor/tensor/rewriting/linalg.py` is a good rewrite to look at, because it also uses the full `FunctionGraph` (the first argument to the rewrite function, usually called `fgraph`) to perform the rewrite.

The decorators tell PyTensor at which step of the rewriting process the rewrite should be performed. This one can come last, so I guess it should be `@register_specialize`. It's a `@node_rewriter` because it changes a single node of computation (an SVD `Op` with `compute_uv=False`), as opposed to a `@graph_rewriter` that operates on a whole group of nodes.

Tag me on your draft PR and I'm happy to walk you through the sharp bits.

jessegrabowski avatar May 13 '24 05:05 jessegrabowski

> This one can come last, so I guess it should be `@register_specialize`

This one is pretty cheap, so we can run it in all 3 stages. It will only be triggered if there's an SVD `Op` in the graph anyway.

ricardoV94 avatar May 13 '24 08:05 ricardoV94