Named tensors with typed spaces
I took the branch from #407 and added a pytensor.xtensor.spaces module that defines types to distinguish between "unordered spaces" (BaseSpace) and "ordered spaces" (OrderedSpace).
BaseSpace and OrderedSpace behave much like sets and tuples, respectively, but leave out operations that would break their interpretation as dims.
One idea here is to apply the mathematical operations not only to the variables, but also to their spaces.
For example:
```python
# Addition between two variables uses bilateral broadcasting
Space(["a", "b"]) + Space({"c"}) -> Space({"a", "b", "c"})
```
This union-of-dims behavior matches broadcasting in xarray:
```python
import xarray

a = xarray.DataArray([[1, 2, 3]], dims=["a", "b"])
b = xarray.DataArray([1, 2, 3, 4], dims=["c"])
assert set((b + a).dims) == {"a", "b", "c"}
```
However, xarray.DataArray.dims is a tuple, and the dims of the result depend on the operand order, so commutativity does not hold at the dims level:
```python
assert (a + b).dims == (b + a).dims  # AssertionError
```
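If I read xarray's broadcasting correctly, the result dims follow the order in which they first appear across the operands, hence the failure:

```python
(a + b).dims  # ('a', 'b', 'c')
(b + a).dims  # ('c', 'a', 'b')
```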
In contrast, with this PR the resulting dims become an unordered space, so the resulting XTensorTypes are equal:
```python
import pytensor.xtensor as ptx  # module added on this branch

xa = ptx.as_xtensor(a)
xb = ptx.as_xtensor(b)
xc = xa + xb
xa.type  # XTensorType(int32, OrderedSpace('a', 'b'), (1, 3))
xb.type  # XTensorType(int32, OrderedSpace('c'), (4,))
xc.type  # XTensorType(float64, Space{'c', 'a', 'b'}, (None, None, None))
assert (xa + xb).type == (xb + xa).type
```
So far this is just basic math, but we could introduce XOps with an XOp.infer_space method that implements the broadcasting rule for any operation:
```python
from pytensor.graph.op import Op
from pytensor.xtensor.spaces import BaseSpace, Space


class XOp(Op):
    def infer_space(self, fgraph, node, input_spaces) -> BaseSpace:
        raise NotImplementedError()


class SumOverTime(XOp):
    def infer_space(self, fgraph, node, input_spaces) -> BaseSpace:
        [s] = input_spaces
        if "time" not in s:
            raise ValueError("No time dim to sum over.")
        return Space(s - {"time"})
```
Similarly, this should allow us to implement dot products that require OrderedSpace inputs and produce an OrderedSpace output, or a specify_dimorder XOp that orders a BaseSpace into an OrderedSpace.
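For illustration, such a specify_dimorder rule could be a thin infer_space on top of the XOp class above; the class name and validation here are just my sketch, assuming the spaces module exports OrderedSpace:

```python
from pytensor.xtensor.spaces import OrderedSpace  # assumed export from this branch


class SpecifyDimOrder(XOp):
    """Hypothetical XOp that fixes the dim order of its input."""

    def __init__(self, dims: tuple[str, ...]):
        self.dims = dims

    def infer_space(self, fgraph, node, input_spaces) -> OrderedSpace:
        [s] = input_spaces
        if set(s) != set(self.dims):
            raise ValueError(f"Requested order {self.dims} does not match input space {s}.")
        # Same dims as the input, but now with an explicit order.
        return OrderedSpace(self.dims)
```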
Looking at the (None, None, None) shape from the code block above, I wonder if we should type XTensorType.shape as a Mapping[DimLike, int | ScalarVariable] 🤔
PyTensor variables shouldn't show up in the attributes of Variable types.
I had a few more thoughts on this and found that, also for unordered spaces, we need to know which index a dimension has in the underlying array. With that information one can then index into the shape as well.
Now the question is where this information should be kept. Either the XTensorType keeps it, ~~or we don't actually make the BaseSpace unordered~~ no, then the space math would not work.
Maybe it's enough to keep an is_ordered: bool and a dims: tuple? A .space property could create the corresponding BaseSpace/OrderedSpace if needed 🤔
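For what it's worth, a rough sketch of that last idea, stripped down to the attribute layout (the dim_shape helper and the exact signature are my own, not part of this PR):

```python
from pytensor.xtensor.spaces import BaseSpace, OrderedSpace  # assumed exports from this branch


class XTensorType:
    """Sketch: keep positional dims plus an is_ordered flag on the type."""

    def __init__(self, dtype: str, dims: tuple[str, ...], shape: tuple[int | None, ...], is_ordered: bool):
        self.dtype = dtype
        self.dims = dims            # positional order of dims in the underlying array
        self.shape = shape
        self.is_ordered = is_ordered

    @property
    def space(self) -> BaseSpace:
        # Recreate the typed space on demand.
        return OrderedSpace(self.dims) if self.is_ordered else BaseSpace(self.dims)

    def dim_shape(self, dim: str) -> int | None:
        # Because each dim's positional index is known, shape stays indexable by dim name.
        return self.shape[self.dims.index(dim)]
```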