pytensor
Replace dot of 1 and x -> x
Description
Here I aim to replace the output of a graph that computes the dot of `x` and `[1]` (for example `x @ [1]`) with `x` itself.
I am not sure this is the correct way to do it. It works fine when `x = pt.col('x')`. However, for `x = fmatrix("x")`, `x = vector("x")`, or `x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")`, I kept running into this error:
TypeError: The type of the replacement (Matrix(float64, shape=(3, 2))) must be compatible with the type of the original Variable (Matrix(float64, shape=(3, 1))).
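The compatibility rule behind this error can be modelled in a few lines of plain Python. This is only a simplified sketch, not PyTensor's implementation: the helper name `static_shape_is_super` is hypothetical, and it assumes that `None` marks a dimension whose length is unknown at graph-construction time and that dtypes already match.

```python
from typing import Optional, Sequence

# A static shape: one entry per dimension, None meaning "unknown length".
Shape = Sequence[Optional[int]]


def static_shape_is_super(original: Shape, replacement: Shape) -> bool:
    """Hypothetical model of the check: the original's static shape must be
    at least as general as the replacement's, dimension by dimension."""
    if len(original) != len(replacement):
        return False  # a different number of dimensions is never compatible
    return all(o is None or o == r for o, r in zip(original, replacement))


# The case from the error message: (3, 1) cannot accept a (3, 2) replacement.
print(static_shape_is_super((3, 1), (3, 2)))        # False
# A fully unknown static shape accepts any matrix.
print(static_shape_is_super((None, None), (3, 2)))  # True
```

Under this model, the replacement is rejected because its second dimension is statically 2 while the original variable promises 1.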
To solve this, I introduced:
    if not old_out.type.is_super(new_out.type):
        new_out = alloc_like(new_out, old_out, fgraph)
However, `x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")` still failed with:
ValueError: Alloc static input type and target shape are incompatible: Matrix(float64, shape=(3, 2)) vs (3, 1)
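One way to read this failure is through NumPy-style broadcasting, under which only dimensions of length 1 can be expanded to a larger target. The sketch below assumes the allocation follows those rules; `can_broadcast_to` is a hypothetical helper written for illustration, not a PyTensor function.

```python
from typing import Sequence


def can_broadcast_to(source: Sequence[int], target: Sequence[int]) -> bool:
    """Return True if `source` can be broadcast to `target` under
    NumPy-style rules (hypothetical helper, for illustration only)."""
    if len(source) > len(target):
        return False  # broadcasting never drops dimensions
    # Compare dimensions from the trailing end; a source dimension must
    # either match the target or have length 1.
    for s, t in zip(reversed(source), reversed(target)):
        if s != t and s != 1:
            return False
    return True


# The case from the error message: a (3, 2) value cannot fill a (3, 1) target.
print(can_broadcast_to((3, 2), (3, 1)))  # False
# The opposite direction is fine, since the length-1 dimension expands.
print(can_broadcast_to((3, 1), (3, 2)))  # True
```

Seen this way, the allocation approach can only grow a value to a larger shape, so it cannot turn a (3, 2) replacement into the (3, 1) shape of the original output.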
Finally, the following works fine, but it introduces a `Reshape` op in the graph:
    if not old_out.type.is_super(new_out.type):
        new_out = new_out.reshape(old_out.shape)
Related Issue
- [x] Closes #638
- [ ] Related to #
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [x] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):