Add `WalkingNestedGraphRewriter` to apply node rewrites to `Scan`s and `OpFromGraph` inner graphs

Open aerubanov opened this issue 2 years ago • 2 comments

Description

While working on https://github.com/pymc-devs/pymc/pull/6996 I added some functionality to rewrite inner graphs, but we want to move it into pytensor (see https://github.com/pymc-devs/pymc/pull/6996#discussion_r1411821274). It should be a `WalkingNestedGraphRewriter` that applies the same `node_rewriter` to both the outer graph and the inner graphs.

The idea is that you pass in the existing rewrite, which does not distinguish between the core case and a `Scan` or `OpFromGraph`. The `WalkingNestedGraphRewriter` is what handles the nesting, regardless of which `NodeRewriter` it is given.

I'm going to work on it and am creating this issue for tracking purposes.
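To make the intended behaviour concrete, here is a minimal pure-Python sketch of a walker that applies one node rewrite to an outer graph and to any inner graphs it finds. `Node`, `Graph`, `walk_nested` and `stabilize_softplus` are toy stand-ins invented for this illustration; none of them are PyTensor classes or functions, and the real rewriter would work on `FunctionGraph`s and `Apply` nodes instead.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class Node:
    """Toy node: an op name plus an optional inner graph (Scan/OpFromGraph-like)."""
    op: str
    inner: Optional["Graph"] = None


@dataclass
class Graph:
    """Toy graph: just an ordered list of nodes."""
    nodes: List[Node] = field(default_factory=list)


# A node rewrite returns a replacement node, or None to leave the node as-is.
NodeRewrite = Callable[[Node], Optional[Node]]


def walk_nested(graph: Graph, rewrite: NodeRewrite) -> int:
    """Apply `rewrite` to every node of `graph` and, recursively, of inner graphs.

    Returns the number of nodes that were replaced.
    """
    replaced = 0
    for i, node in enumerate(graph.nodes):
        # Descend first, so the rewrite itself never needs to know about nesting.
        if node.inner is not None:
            replaced += walk_nested(node.inner, rewrite)
        new_node = rewrite(node)
        if new_node is not None:
            graph.nodes[i] = new_node
            replaced += 1
    return replaced


def stabilize_softplus(node: Node) -> Optional[Node]:
    """Hypothetical rewrite that only looks at a single node."""
    if node.op == "softplus_naive":
        return Node("softplus_stable")
    return None


inner = Graph([Node("softplus_naive"), Node("add")])
outer = Graph([Node("softplus_naive"), Node("scan", inner=inner), Node("mul")])
n_replaced = walk_nested(outer, stabilize_softplus)
```

The toy mutates the inner graph in place for brevity; a real implementation would probably need to rebuild the enclosing op with the rewritten inner graph rather than modify it.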

aerubanov avatar Dec 01 '23 16:12 aerubanov

It may be worth thinking about how optional inner rewrites work right now. Scan objects take a Mode, which is annoying for users to define, but could be a nice way to apply inner rewrites: you simply add a rewrite to its mode, and if/when that Scan is later compiled, the rewrite will be applied there.

OpFromGraph doesn't have a mode, but allows passing kwargs (silly). An option would be to add a Mode to that. Then the only thing this WalkingNestedGraphRewriter would need to do when it finds either Op is to register the rewrites in the mode of those Ops.

Sometimes though, we do want to apply a bunch of rewrites inside the Op, and not wait for it to be compiled. In that case something like this would be pretty useful. I think we can implement it, just wanted to add some context.
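The deferred alternative described above can be sketched the same way: instead of rewriting the inner graph immediately, the walker only registers the rewrite on each nested op's mode, and it runs when that op is compiled. `ToyMode`, `ToyInnerGraphOp`, `register_in_nested_ops` and `stabilize` are hypothetical names for this illustration and do not reflect how PyTensor's `Mode`, `Scan` or `OpFromGraph` are actually implemented.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyMode:
    """Toy mode: holds rewrites to run when the owning op is compiled."""
    rewrites: List[Callable[[str], str]] = field(default_factory=list)


@dataclass
class ToyInnerGraphOp:
    """Toy Scan/OpFromGraph-like op: inner op names plus a mode."""
    inner_ops: List[str]
    mode: ToyMode = field(default_factory=ToyMode)

    def compile(self) -> List[str]:
        # Registered rewrites are applied only here, at compile time.
        ops = list(self.inner_ops)
        for rewrite in self.mode.rewrites:
            ops = [rewrite(op) for op in ops]
        return ops


def register_in_nested_ops(outer_ops: list, rewrite: Callable[[str], str]) -> int:
    """Register `rewrite` on every nested op's mode without touching inner graphs."""
    registered = 0
    for op in outer_ops:
        if isinstance(op, ToyInnerGraphOp):
            op.mode.rewrites.append(rewrite)
            registered += 1
    return registered


def stabilize(op_name: str) -> str:
    """Hypothetical rewrite acting on a single op name."""
    return "softplus_stable" if op_name == "softplus_naive" else op_name


scan_like = ToyInnerGraphOp(inner_ops=["softplus_naive", "add"])
outer_ops = ["mul", scan_like]
n_registered = register_in_nested_ops(outer_ops, stabilize)
ops_before_compile = list(scan_like.inner_ops)
compiled_ops = scan_like.compile()
```

The trade-off in this sketch is that the inner graph stays untouched until compilation, so it does not cover the case where the rewritten inner graph is needed earlier.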

ricardoV94 avatar Dec 01 '23 18:12 ricardoV94

Might also be worth checking out this draft PR from Aesara: https://github.com/aesara-devs/aesara/pull/824

ricardoV94 avatar Dec 13 '23 10:12 ricardoV94