
Allow `TensorType(shape=(1,), broadcastable=(False,))`

Open ricardoV94 opened this issue 2 years ago • 1 comments

Description

This requires re-introducing broadcastable flags that are independent of the static shape. This seems necessary in order to:

  1. Not force static shape to be unknown
  2. Not change the meaning of the graph accidentally due to shape inference / rewrites
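
To make the two points above concrete, here is a minimal plain-Python model of the coupling this issue wants to remove. It assumes the broadcastable flag of each dimension is derived from the static shape as "length is known to be 1". The helper name `derived_broadcastable` is hypothetical and is not part of the PyTensor API.

```python
def derived_broadcastable(static_shape):
    """Model of the coupling: a dim is broadcastable iff its static length is 1.

    `None` stands for an unknown static length.
    """
    return tuple(dim == 1 for dim in static_shape)


# A known length of 1 always implies a broadcastable dimension.
print(derived_broadcastable((1,)))     # (True,)

# Under this coupling, the only way to get a non-broadcastable flag for a
# length-1 dimension is to hide the length (static shape becomes unknown).
print(derived_broadcastable((None,)))  # (False,)

# Mixed case: only the trailing dimension is broadcastable.
print(derived_broadcastable((3, 1)))   # (False, True)
```

In this model, any rewrite or shape inference step that discovers a length of 1 would also flip the flag to broadcastable, which is the accidental change of meaning described in point 2.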

Affected Ops (anything that performs broadcasting of existing dims):

  • [ ] Elemwise
  • [ ] Alloc
  • [ ] GEMM Ops
  • [ ] Unbroadcast

This will require re-introducing Rebroadcast, which could toggle broadcastable flags directly, independently of the static shape information gained from SpecifyShape. It would probably be better named SpecifyBroadcastable.

Elemwise outputs will probably have to be unbroadcastable as long as at least one input is also unbroadcastable along the same dimension.

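import pytensor.tensor as pt

# Proposed behavior (the `broadcastable` argument is what this issue asks for):
x = pt.vector(shape=(1,), broadcastable=(False,))
y = x + x
assert y.type.broadcastable == (False,)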
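
The output rule suggested for Elemwise can be sketched in plain Python, without PyTensor. This is only an illustration of the rule as stated in this issue, assuming all inputs have already been aligned to the same number of dimensions. The function name `elemwise_out_broadcastable` is hypothetical.

```python
def elemwise_out_broadcastable(*input_flags):
    """Combine per-input broadcastable flags into the output flags.

    An output dimension stays broadcastable only if every input is
    broadcastable along it; a single unbroadcastable input makes the
    output unbroadcastable along that dimension.
    """
    if not input_flags:
        raise ValueError("at least one input is required")
    ndim = len(input_flags[0])
    if any(len(flags) != ndim for flags in input_flags):
        raise ValueError("inputs must have the same number of dimensions")
    # zip(*input_flags) groups the flags of all inputs dimension by dimension.
    return tuple(all(dim_flags) for dim_flags in zip(*input_flags))


# x + x with x unbroadcastable: the output is unbroadcastable too.
print(elemwise_out_broadcastable((False,), (False,)))            # (False,)

# One unbroadcastable input is enough.
print(elemwise_out_broadcastable((True,), (False,)))             # (False,)

# Flags are combined independently for each dimension.
print(elemwise_out_broadcastable((True, False), (True, True)))   # (True, False)
```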

ricardoV94, Aug 07 '23 07:08