Add `WalkingNestedGraphRewriter` to apply node rewrites to `Scan`s and `OpFromGraph` inner graphs
Description
While working on https://github.com/pymc-devs/pymc/pull/6996 I added some functionality to rewrite inner graphs, but we want to move it into PyTensor (see https://github.com/pymc-devs/pymc/pull/6996#discussion_r1411821274). It should be a `WalkingNestedGraphRewriter` which applies the same `node_rewriter` to both the outer graph and inner graphs. The idea is that you would pass the previous rewrite, which doesn't distinguish between the core case and a `Scan` or `OpFromGraph`. It would be the `WalkingNestedGraphRewriter` that applies that logic, regardless of which `NodeRewriter` it's given.
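To make the intended behaviour concrete, here is a rough sketch on a toy graph model. None of this is PyTensor API: `ToyNode`, `walk_nested` and `log_to_log1p` are hypothetical names made up for illustration, and a real implementation would presumably operate on a `FunctionGraph` instead. The only point is that the node rewriter knows nothing about inner graphs, and the walker takes care of descending into them.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class ToyNode:
    """Hypothetical stand-in for an Apply node (not a PyTensor class)."""
    op: str
    inputs: list = field(default_factory=list)  # ToyNode instances or str leaves
    inner: Optional["ToyNode"] = None  # inner graph output, for Scan/OpFromGraph-like nodes


# A node rewriter returns a replacement node, or None when it does not apply.
NodeRewriter = Callable[[ToyNode], Optional[ToyNode]]


def walk_nested(node, rewriter: NodeRewriter):
    """Apply `rewriter` bottom-up to the outer graph and to any inner graphs."""
    if not isinstance(node, ToyNode):
        return node  # leaf variable, nothing to rewrite
    new_inputs = [walk_nested(inp, rewriter) for inp in node.inputs]
    # Descend into the inner graph, if this node carries one.
    new_inner = walk_nested(node.inner, rewriter) if node.inner is not None else None
    rebuilt = ToyNode(node.op, new_inputs, new_inner)
    replacement = rewriter(rebuilt)
    return rebuilt if replacement is None else replacement


def log_to_log1p(node: ToyNode) -> Optional[ToyNode]:
    """Toy rewrite: log(add("1", x)) -> log1p(x). Unaware of inner graphs."""
    if node.op == "log" and len(node.inputs) == 1:
        arg = node.inputs[0]
        if isinstance(arg, ToyNode) and arg.op == "add" and len(arg.inputs) == 2:
            if arg.inputs[0] == "1":
                return ToyNode("log1p", [arg.inputs[1]])
    return None


# The same pattern appears in the outer graph and inside a scan-like node.
inner = ToyNode("log", [ToyNode("add", ["1", "x"])])
scan_like = ToyNode("scan", ["y"], inner=inner)
outer = ToyNode("log", [ToyNode("add", ["1", scan_like])])

result = walk_nested(outer, log_to_log1p)
```

After the walk, both the outer `log` and the one inside the scan-like node have been replaced, even though `log_to_log1p` itself only ever looks at a single node.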
I'm going to work on it and am creating this issue for tracking purposes.
It may be worth thinking about how optional inner rewrites work right now. `Scan` objects take a `Mode`, which is annoying for users to define but could be a nice way to apply inner rewrites: you simply add a rewrite to its mode, and if/when that `Scan` is later compiled the rewrite will be applied there.
`OpFromGraph` doesn't have a mode, but allows passing kwargs (silly). An option would be to add a `Mode` to it. Then the only thing this `WalkingGraphRewriter` would need to do when it finds either `Op` is register the rewrites in the mode of that `Op`.
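The deferred variant described above could look roughly like this. Again this is a toy model, not PyTensor's `Mode` API: `ToyMode`, `ToyInnerGraphOp` and `register_inner_rewrite` are hypothetical names, and the inner graph is reduced to a flat list of op names to keep the sketch short.

```python
from dataclasses import dataclass, field


@dataclass
class ToyMode:
    """Hypothetical mode: just a list of rewrites to run at compile time."""
    rewrites: list = field(default_factory=list)


@dataclass
class ToyInnerGraphOp:
    """Hypothetical Scan/OpFromGraph-like Op carrying its own mode."""
    name: str
    inner_ops: list  # inner graph, simplified to a list of op names
    mode: ToyMode = field(default_factory=ToyMode)

    def compile(self):
        # Registered rewrites only run here, when the Op is compiled.
        ops = list(self.inner_ops)
        for rewrite in self.mode.rewrites:
            ops = [rewrite(op) for op in ops]
        return ops


def register_inner_rewrite(outer_ops, rewrite):
    """Walk the outer graph and register `rewrite` on every inner-graph Op."""
    for op in outer_ops:
        if isinstance(op, ToyInnerGraphOp):
            op.mode.rewrites.append(rewrite)


scan_like = ToyInnerGraphOp("scan", ["log", "exp"])
register_inner_rewrite([scan_like, "add"], lambda op: "log1p" if op == "log" else op)

untouched = list(scan_like.inner_ops)  # the inner graph is not rewritten yet
compiled = scan_like.compile()  # the rewrite is applied only at this point
```

The walker never touches the inner graph here, it only records what should happen later. That is cheap, but it also means nothing is rewritten until the Op is actually compiled.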
Sometimes though, we do want to apply a bunch of rewrites inside the Op, and not wait for it to be compiled. In that case something like this would be pretty useful. I think we can implement it, just wanted to add some context.
Might also be worth checking out this draft PR from Aesara: https://github.com/aesara-devs/aesara/pull/824