pytensor
Allow `TensorType(shape=(1,), broadcastable=(False,))`
Description
This requires reintroducing the broadcastable flags as independent of the static shape. This seems necessary in order to:
- Avoid forcing the static shape to be unknown just to mark a length-1 dimension as non-broadcastable
- Avoid accidentally changing the meaning of the graph when shape inference or rewrites refine a static length to 1
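To make the second point concrete, here is a toy model in plain Python. It is not PyTensor code; `coupled_type` and `decoupled_type` are hypothetical names. Under a coupled rule the broadcastable flags are derived from the static shape, so refining an unknown length to 1 silently makes the dimension broadcastable. Storing the flags separately keeps the original meaning.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def coupled_type(shape: Shape) -> Tuple[Shape, Tuple[bool, ...]]:
    # Coupled rule: a dimension is broadcastable iff its static length is 1.
    return shape, tuple(s == 1 for s in shape)


def decoupled_type(shape: Shape, broadcastable: Tuple[bool, ...]) -> Tuple[Shape, Tuple[bool, ...]]:
    # Proposed rule (assumption): flags are stored independently of the shape,
    # but a broadcastable dimension may only have static length 1 or unknown.
    if len(shape) != len(broadcastable):
        raise ValueError("shape and broadcastable must have the same length")
    for s, b in zip(shape, broadcastable):
        if b and s not in (None, 1):
            raise ValueError(f"broadcastable dimension cannot have length {s}")
    return shape, broadcastable


# A non-broadcastable vector of unknown length...
before = coupled_type((None,))
# ...after a rewrite infers that its length is 1, the coupled rule flips the flag.
after = coupled_type((1,))
print(before[1], after[1])  # (False,) (True,)

# With decoupled flags, the same refinement keeps the dimension non-broadcastable.
refined = decoupled_type((1,), (False,))
print(refined[1])  # (False,)
```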
Affected Ops (any Op that broadcasts existing dimensions):
- [ ] Elemwise
- [ ] Alloc
- [ ] GEMM Ops
- [ ] Unbroadcast
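Each of these Ops would need a runtime rule that only stretches dimensions flagged as broadcastable, which is stricter than NumPy (NumPy broadcasts any length-1 dimension). The sketch below uses a hypothetical helper, `broadcast_runtime_shapes`, and assumes all inputs already have the same number of dimensions.

```python
from typing import Sequence, Tuple


def broadcast_runtime_shapes(
    shapes: Sequence[Tuple[int, ...]],
    broadcastable: Sequence[Tuple[bool, ...]],
) -> Tuple[int, ...]:
    """Hypothetical runtime check: only dims flagged broadcastable may be stretched.

    Assumes every input already has the same number of dimensions.
    """
    out = []
    for d in range(len(shapes[0])):
        for shape, bc in zip(shapes, broadcastable):
            # A dimension flagged broadcastable must really have length 1.
            if bc[d] and shape[d] != 1:
                raise ValueError(f"broadcastable dim {d} has length {shape[d]}")
        # Lengths of the inputs that are NOT allowed to broadcast along d.
        fixed = {shape[d] for shape, bc in zip(shapes, broadcastable) if not bc[d]}
        if len(fixed) > 1:
            raise ValueError(f"incompatible lengths {sorted(fixed)} along dim {d}")
        # If every input is broadcastable along d, the output length is 1.
        out.append(fixed.pop() if fixed else 1)
    return tuple(out)


print(broadcast_runtime_shapes([(1,), (3,)], [(True,), (False,)]))  # (3,)
print(broadcast_runtime_shapes([(1,), (1,)], [(False,), (False,)]))  # (1,)
# A length-1 but non-broadcastable input is NOT stretched to length 3:
# broadcast_runtime_shapes([(1,), (3,)], [(False,), (False,)]) raises ValueError
```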
This will require reintroducing `Rebroadcast`, which could toggle broadcastable flags directly, independently of the static shape information gained from `SpecifyShape`. It would probably be better named `SpecifyBroadcastable`.
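A possible shape for such an Op, sketched on a toy type rather than the real `TensorType`: `ToyTensorType` and `specify_broadcastable` are hypothetical names, and the rule that a dimension with a known static length other than 1 cannot be marked broadcastable is an assumption. The static shape is left untouched, so shape refinement stays the job of `SpecifyShape`.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyTensorType:
    # Hypothetical stand-in for a TensorType with independent broadcastable flags.
    shape: Tuple[Optional[int], ...]
    broadcastable: Tuple[bool, ...]


def specify_broadcastable(t: ToyTensorType, *axes: int, value: bool = True) -> ToyTensorType:
    """Hypothetical SpecifyBroadcastable: set the flag on `axes`, keep the static shape."""
    flags = list(t.broadcastable)
    for ax in axes:
        # Assumption: only dims of static length 1 or unknown may become broadcastable.
        if value and t.shape[ax] not in (None, 1):
            raise ValueError(f"axis {ax} has static length {t.shape[ax]}")
        flags[ax] = value
    return replace(t, broadcastable=tuple(flags))


t = ToyTensorType(shape=(1, None), broadcastable=(False, False))
t2 = specify_broadcastable(t, 0)
print(t2)  # ToyTensorType(shape=(1, None), broadcastable=(True, False))
print(specify_broadcastable(t2, 0, value=False) == t)  # True
```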
Elemwise outputs will probably have to be non-broadcastable along a dimension as long as at least one input is non-broadcastable along that same dimension.
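The rule above could be expressed as static type inference roughly as follows. This is a sketch, not PyTensor's implementation; `elemwise_out_type` is a hypothetical helper, and it assumes all inputs have the same number of dimensions. The output is broadcastable along a dimension only if every input is, and its static length comes from the non-broadcastable inputs whose length is known.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


def elemwise_out_type(
    shapes: Sequence[Shape],
    broadcastable: Sequence[Tuple[bool, ...]],
) -> Tuple[Shape, Tuple[bool, ...]]:
    """Hypothetical static inference of an Elemwise output type.

    Assumes all inputs already have the same number of dimensions.
    """
    out_shape, out_bc = [], []
    for d in range(len(shapes[0])):
        # The output can broadcast along d only if every input can;
        # a single non-broadcastable input makes the output non-broadcastable.
        bc_d = all(bc[d] for bc in broadcastable)
        # Known static lengths of the non-broadcastable inputs along d.
        known = {
            s[d] for s, bc in zip(shapes, broadcastable) if not bc[d] and s[d] is not None
        }
        if len(known) > 1:
            raise ValueError(f"incompatible static lengths {sorted(known)} along dim {d}")
        if known:
            out_shape.append(known.pop())
        else:
            # All inputs broadcastable: length 1. Otherwise the length is unknown.
            out_shape.append(1 if bc_d else None)
        out_bc.append(bc_d)
    return tuple(out_shape), tuple(out_bc)


# y = x + x with x of static shape (1,) and broadcastable (False,)
print(elemwise_out_type([(1,), (1,)], [(False,), (False,)]))  # ((1,), (False,))
# An unknown-length vector combined with a broadcastable length-1 vector
print(elemwise_out_type([(None,), (1,)], [(False,), (True,)]))  # ((None,), (False,))
```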
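```python
import pytensor.tensor as pt

# Proposed API: static shape (1,), but explicitly NOT broadcastable
x = pt.vector(shape=(1,), broadcastable=(False,))
y = x + x
assert y.type.broadcastable == (False,)
```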