
Replace dot of 1 and x -> x

Open Dhruvanshu-Joshi opened this issue 8 months ago • 0 comments

Description

This PR aims to replace graphs of the form dot(x, [1]) (i.e. x @ [1]) with x. I am not sure this is the correct way to do so. With x = pt.col('x') it works fine. However, with x = fmatrix("x"), x = vector("x") or x = pt.tensor("x", shape=(1, 3, 2), dtype="float64"), I kept running into this error:

TypeError: The type of the replacement (Matrix(float64, shape=(3, 2))) must be compatible with the type of the original Variable (Matrix(float64, shape=(3, 1))).
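The mismatch is easier to see with concrete shapes. The sketch below uses NumPy as a stand-in for PyTensor (an assumption: NumPy's @ follows the same shape rules as the dot in the graph). It shows that x @ ones reproduces x only when the contracted axis has length 1, and that even then the output shape can differ from the shape of x, which is the kind of mismatch the type check rejects.

```python
import numpy as np

# Column vector: the contracted axis of x @ ones has length 1.
x_col = np.arange(3.0).reshape(3, 1)

# Dot with a (1, 1) block of ones returns x unchanged, with the same shape.
same = x_col @ np.ones((1, 1))     # shape (3, 1)

# Dot with a length-1 vector of ones keeps the values but drops an axis.
flat = x_col @ np.ones(1)          # shape (3,)

# Matrix with 2 columns: x @ ones computes row sums instead of returning x.
x_mat = np.arange(6.0).reshape(3, 2)
summed = x_mat @ np.ones((2, 1))   # shape (3, 1)

print(same.shape, flat.shape, summed.shape)
```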

To solve this, I introduced:

if not old_out.type.is_super(new_out.type):
    new_out = alloc_like(new_out, old_out, fgraph)
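The condition guarding alloc_like compares static types. A simplified model of the shape part of that comparison might look like the following; shape_is_super is a hypothetical helper, and it ignores dtype and any other checks PyTensor's is_super may perform.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def shape_is_super(old: Shape, new: Shape) -> bool:
    """Hypothetical model of a static-shape compatibility check.

    `old` accepts `new` when both have the same number of dimensions and
    every dimension of `old` is either unknown (None) or equal to `new`'s.
    """
    if len(old) != len(new):
        return False
    return all(o is None or o == n for o, n in zip(old, new))


print(shape_is_super((3, 1), (3, 2)))     # the failing case from the error
print(shape_is_super((None, 1), (5, 1)))  # unknown dims accept any length
print(shape_is_super((3, 1), (3,)))       # different number of dimensions
```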

However, x = pt.tensor("x", shape=(1, 3, 2), dtype="float64") still failed, now with:

raise ValueError(
ValueError: Alloc static input type and target shape are incompatible: Matrix(float64, shape=(3, 2)) vs (3, 1)
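Alloc fills a target shape from its input by broadcasting, so it can stretch length-1 axes but never shrink an axis. Assuming its static check follows NumPy's broadcasting rules, the same failure can be reproduced with np.broadcast_to; try_broadcast is a hypothetical helper for illustration.

```python
import numpy as np


def try_broadcast(arr, shape):
    """Return arr broadcast to shape, or None if the shapes are incompatible."""
    try:
        return np.broadcast_to(arr, shape)
    except ValueError:
        return None


# A (3, 2) value cannot be broadcast to (3, 1): only length-1 axes can stretch.
shrunk = try_broadcast(np.zeros((3, 2)), (3, 1))

# The reverse direction works: a (3, 1) value stretches to (3, 2).
stretched = try_broadcast(np.zeros((3, 1)), (3, 2))

print(shrunk is None, stretched.shape)
```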

Finally, reshaping the replacement to the shape of the original output instead:

if not old_out.type.is_super(new_out.type):
    new_out = new_out.reshape(old_out.shape)

works, but it adds an extra Reshape Op to the graph.
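Unlike the Alloc path, a Reshape makes no static promise that the new shape fits: it is valid only when both shapes hold the same number of elements, and when old_out.shape is symbolic that can in general only be checked at runtime. A NumPy sketch of the rule, assuming PyTensor's Reshape has the same size constraint as ndarray.reshape (reshape_ok is a hypothetical helper):

```python
import math

import numpy as np


def reshape_ok(src_shape, dst_shape):
    """A reshape is valid only when both shapes hold the same number of elements."""
    return math.prod(src_shape) == math.prod(dst_shape)


# (3, 1) -> (3,) keeps 3 elements, so the reshape succeeds.
ok_case = np.zeros((3, 1)).reshape(3)
print(reshape_ok((3, 1), (3,)), ok_case.shape)

# (3, 2) -> (3, 1) would have to drop 3 elements, so NumPy rejects it at runtime.
print(reshape_ok((3, 2), (3, 1)))
try:
    np.zeros((3, 2)).reshape(3, 1)
    bad_case_failed = False
except ValueError:
    bad_case_failed = True
print(bad_case_failed)
```

In other words, the Reshape satisfies the type check, but the rewrite is only correct when its input and output contain the same number of elements.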

Related Issue

  • [x] Closes #638
  • [ ] Related to #

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

Dhruvanshu-Joshi, Jun 07 '24 07:06