Use Python's typing instead of attaching `Type` objects to `Variable`s

Open brandonwillard opened this issue 5 years ago • 8 comments

Instead of creating "type" classes (i.e. Types) and manually attaching them to Aesara objects (i.e. via Variable.type), we should actually use Python's built-in type/class inheritance. I'm not immediately aware of any required functionality in the current Aesara type system that wouldn't allow this.

Here's an illustration of the current situation:

import aesara.tensor as at


x = at.vector()

We created a simple Aesara vector, x; now, let's inspect its (Python) types:

>>> type(x).mro()

[aesara.tensor.var.TensorVariable,
 aesara.tensor.var._tensor_py_operators,
 aesara.gof.graph.Variable,
 aesara.gof.graph.Node,
 aesara.gof.utils.object2,
 object]

Oddly, there's also a separate Variable.type field/member/property introduced by inheriting from Variable:

>>> x.type

TensorType(float64, vector)

>>> x.type.Variable

aesara.tensor.var.TensorVariable

>>> x.type.Constant

aesara.tensor.var.TensorConstant

As we can see, this extra Type object holds information about the variable's dtype and broadcastable dimensions (i.e. the vector means x.type.broadcastable equals [False]/[True]). From the latter two properties, it's also clear that the actual Python type (i.e. TensorVariable) is directly associated with the Aesara type (i.e. TensorType) as illustrated by the value of TensorType.Variable.

This leads to the question: why doesn't TensorVariable just inherit from its x.type type/class (i.e. some TensorType)?

If anything, the mechanics behind such inheritance might seem non-standard, as it could require some form of automatic type generation for sub-classes like TensorVariable, but—even so—that's pretty straightforward to do in a __new__ class method (e.g. we already do that in symbolic-pymc).

Regarding actual code changes, it seems like we would at least need to replace all the Variable.type invocations with standard type(...) calls, and/or with more direct calls to isinstance(x, TensorType) instead of things like x.type == TensorType(...).

We could also keep Variable.type and have it return type(self) just to make the initial transition easy.
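
As a rough illustration of that transition shim, here is a minimal sketch using toy stand-in classes rather than Aesara's real ones. The class names Float64Vector and the class-level dtype/ndim attributes are assumptions made only for this example.

```python
class Variable:
    """Toy stand-in for a graph variable; not Aesara's actual class."""

    @property
    def type(self):
        # Transitional shim: the "Aesara type" is simply the Python class.
        return type(self)


class TensorType(Variable):
    """Hypothetical base class for all tensor-typed variables."""

    dtype = None
    ndim = None


class Float64Vector(TensorType):
    """Hypothetical concrete type: float64 data, one dimension."""

    dtype = "float64"
    ndim = 1


x = Float64Vector()

# Old-style access keeps working, but now agrees with Python's own typing.
assert x.type is type(x)
# Checks become ordinary isinstance calls instead of Type-object comparisons.
assert isinstance(x, TensorType)
assert x.type.dtype == "float64"
```

Existing call sites that read x.type would keep working during the migration, while new code could use type(x) and isinstance directly.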

brandonwillard avatar Oct 28 '20 17:10 brandonwillard

@LegrandNico, since you're becoming Aesara's type expert, I would like to introduce you to this potentially revolutionary refactoring topic...

brandonwillard avatar Mar 12 '21 19:03 brandonwillard

The issue fixed in https://github.com/aesara-devs/aesara/pull/892 could possibly have been caught by Mypy's type checking if the approach described here was in place.

In that example, the problematic at.minimum(x, y)-like expression could've been noticed by Mypy, since at.minimum could be declared to have a (FloatLikeType, FloatLikeType) inputs signature, which would've failed when given the incompatible y: BoolLikeType that it was being given.
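
A small sketch of what such a signature could look like, assuming FloatLikeType and BoolLikeType are plain Python classes. The minimum function below is a toy stand-in and not Aesara's at.minimum.

```python
class TensorType:
    """Toy base class; not Aesara's real TensorType."""


class FloatLikeType(TensorType):
    """Hypothetical type for float-valued tensors."""


class BoolLikeType(TensorType):
    """Hypothetical type for boolean tensors."""


def minimum(x: FloatLikeType, y: FloatLikeType) -> FloatLikeType:
    """Stand-in for an elementwise minimum with a float-only signature."""
    return FloatLikeType()


a = FloatLikeType()
b = FloatLikeType()
mask = BoolLikeType()

result = minimum(a, b)  # matches the declared signature

# A static checker such as Mypy would be expected to reject the call below
# with an incompatible-argument-type error, before any test is run:
# minimum(a, mask)
```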

brandonwillard avatar Apr 07 '22 20:04 brandonwillard

It looks like we would need to do something exactly like NumPy is currently doing with its ndarray typing: i.e. use typing.Generic parameterized on dtype and shape information.

Here's a simple example:

import numpy as np
import typing
import typing_inspect


DataType = typing.TypeVar("DataType")


class TensorType(typing.Generic[DataType]):

    @classmethod
    @property
    def dtype(cls) -> DataType:
        generic_bases = typing_inspect.get_generic_bases(cls)
        type_base = next((b for b in generic_bases if typing_inspect.get_origin(b) == TensorType))
        dtype = typing_inspect.get_args(type_base)[0]
        return dtype


Float32TensorType = TensorType[np.float32]


assert Float32TensorType == TensorType[np.float32]


class Float32Variable(Float32TensorType):
    pass


scalar_float_inst = Float32Variable()

assert scalar_float_inst.dtype is np.float32

assert isinstance(scalar_float_inst, TensorType)
assert isinstance(scalar_float_inst, Float32Variable)

Unfortunately, checks based on the parameterized TensorTypes won't work, but it might not matter given that checks like the preceding Float32Variable one work:

# Fails
assert isinstance(scalar_float_inst, Float32TensorType)
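
One possible workaround, sketched here with only the standard library's typing.get_origin and typing.get_args, is a helper that emulates the failing check. The helper names dtype_of and isinstance_of_alias are hypothetical, and the builtin float stands in for np.float32 so the sketch has no third-party dependencies. It also avoids stacking classmethod on property, which newer Python versions no longer support.

```python
import typing

DataType = typing.TypeVar("DataType")


class TensorType(typing.Generic[DataType]):
    """Toy generic type mirroring the example above."""


class FloatVariable(TensorType[float]):
    """Hypothetical variable class; builtin float stands in for np.float32."""


def dtype_of(cls: type) -> typing.Any:
    """Return the argument that `cls` (or an ancestor) passed to TensorType[...]."""
    for klass in cls.__mro__:
        # __orig_bases__ records the parameterized bases used in the class statement.
        for base in klass.__dict__.get("__orig_bases__", ()):
            if typing.get_origin(base) is TensorType:
                return typing.get_args(base)[0]
    raise TypeError(f"{cls.__name__} does not parameterize TensorType")


def isinstance_of_alias(obj: object, alias: typing.Any) -> bool:
    """Emulate isinstance(obj, TensorType[dtype]) for a parameterized alias."""
    origin = typing.get_origin(alias)
    (dtype,) = typing.get_args(alias)
    return isinstance(obj, origin) and dtype_of(type(obj)) is dtype


x = FloatVariable()

# isinstance(x, TensorType[float]) itself raises TypeError, so use the helper.
assert dtype_of(FloatVariable) is float
assert isinstance_of_alias(x, TensorType[float])
assert not isinstance_of_alias(x, TensorType[int])
```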

brandonwillard avatar Apr 07 '22 23:04 brandonwillard

There might possibly be some useful typing examples in the array-api project:

  • https://github.com/data-apis/array-api/tree/main/spec/API_specification/signatures

dhirschfeld avatar Apr 08 '22 01:04 dhirschfeld

To clarify, adding this kind of parameterized typing will allow type checking (i.e. Mypy) to catch issues involving incompatible dtypes and (static) shapes. Currently, these kinds of issues require explicit unit tests to catch, and that's quite a burden on our testing requirements.

Also, it should be possible to add this on top of the existing codebase without making the changes implied by this issue.

In other words, we don't need to make the change from using TensorType instances (i.e. where TensorType is a standard class and we create instance of that class for each specific dtype and static shape/ndim) to explicit TensorType types (i.e. where TensorType is a "base" type and "instances" are new type-compatible objects specific to each dtype and static shape/ndim).

brandonwillard avatar Apr 08 '22 19:04 brandonwillard

Based on @brandonwillard's code example above, I developed a more complete example of what a TensorType based on Python's typing system could look like.

https://github.com/markusschmaus/Documents/blob/main/aesara_typing_demo.py

One key difference between his code and mine is that @brandonwillard used inspection to extract the dtype information from the GenericAlias. Since Python's typing system is still rapidly evolving, the runtime behavior of GenericAlias has changed in the past (isinstance no longer works with GenericAlias), and typing_inspect is still marked experimental, I decided to use a different approach. Since, as he suggested, we need to dynamically generate sub-classes, we can store the relevant information in these sub-classes when they are created, in a way that can be easily retrieved later.
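
To make the idea concrete, here is a minimal sketch of storing the static information on dynamically generated sub-classes. It is not the linked demo: the tensor_type factory, the generated class names, and the string dtypes are assumptions chosen for illustration.

```python
from functools import lru_cache


class TensorType:
    """Toy base class; dtype and ndim are filled in on generated sub-classes."""

    dtype: str
    ndim: int


@lru_cache(maxsize=None)
def tensor_type(dtype: str, ndim: int) -> type:
    """Create (or reuse) the sub-class for one dtype/ndim combination."""
    name = f"TensorType_{dtype}_{ndim}d"
    # The static information is stored directly on the new class,
    # so no typing inspection is needed to read it back later.
    return type(name, (TensorType,), {"dtype": dtype, "ndim": ndim})


Float64Vector = tensor_type("float64", 1)
x = Float64Vector()

assert tensor_type("float64", 1) is Float64Vector  # cached, so identity holds
assert isinstance(x, Float64Vector)
assert isinstance(x, TensorType)
assert x.dtype == "float64" and x.ndim == 1
```

Caching the factory keeps one class per dtype/ndim combination, so isinstance checks against the generated classes behave predictably.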

The demo passes mypy without errors and contains a few asserts which also succeed.

I tried to get TensorType((3, 7), np.float64) to work without mypy errors, but I believe this is currently impossible using __call__ (see https://github.com/python/mypy/issues/6721#issuecomment-486727328).

Based on this demo I believe that implementing this suggestion is possible and I would be interested in working on this. However this would be a major refactoring and only makes sense if there is support from existing aesara developers.

One aspect I would like to explore, when working on this, is the role of Variables as part of the shape specification and how they interact with broadcasting, and how they could be leveraged for aesara to support dimensions (see https://github.com/aesara-devs/aesara/discussions/1013).

markusschmaus avatar Sep 22 '22 13:09 markusschmaus

One key difference between his code and mine is that @brandonwillard used inspection to extract the dtype information from the GenericAlias. Since Python's typing system is still rapidly evolving, the runtime behavior of GenericAlias has changed in the past (isinstance no longer works with GenericAlias), and typing_inspect is still marked experimental, I decided to use a different approach. Since, as he suggested, we need to dynamically generate sub-classes, we can store the relevant information in these sub-classes when they are created, in a way that can be easily retrieved later.

Yes, that example was completely experimental; we definitely don't need to take an approach like that. As long as what we're currently calling Variables become instances of Types, and we can adequately track "static"/type-level information like dtype and (partial) shapes using those new Types, we're much better off.

Based on this demo I believe that implementing this suggestion is possible and I would be interested in working on this. However this would be a major refactoring and only makes sense if there is support from existing aesara developers.

I'm ready to make these changes sooner than later, so, if you want to work on this, we can start now. A draft PR is the next step.

One aspect I would like to explore, when working on this, is the role of Variables as part of the shape specification and how they interact with broadcasting, and how they could be leveraged for aesara to support dimensions (see #1013).

Be sure to read the comments here: https://github.com/aesara-devs/aesara/discussions/1013#discussioncomment-3016739. In other words, we shouldn't need anything at the type-level to support most kinds of dimension tracking. Regardless, these changes would help any related efforts a lot.

brandonwillard avatar Sep 22 '22 16:09 brandonwillard

I'm ready to make these changes sooner than later, so, if you want to work on this, we can start now. A draft PR is the next step

I guess this means a PR without any commits, as I don't have any yet, with a description of the intended changes. I will probably not get around to doing this today.

Be sure to read the comments here: #1013 (reply in thread). In other words, we shouldn't need anything at the type-level to support most kinds of dimension tracking. Regardless, these changes would help any related efforts a lot.

I looked into this, here are my main takeaways:

  • Even the simplest operations remove any information about any Variable in a shape, see code below
  • HasShape specifies a shape as Tuple[Optional[int], ...], this should be changed to allow (named) scalar integer variables
  • The key (killer) feature of dimensions, for example in Xarray, is their use for smarter broadcasting, the question is how exactly

from aesara import tensor as at

dim = at.iscalar(name="dim")

x = at.TensorType("float64", (dim,))("x")
y = at.TensorType("float64", (dim,))("y")

print(x.type.shape)         # (dim,)
print((x + y).type.shape)   # (None,)

markusschmaus avatar Sep 22 '22 17:09 markusschmaus

I guess this means a PR without any commits, as I don't have any yet, with a description of the intended changes. I will probably not get around to doing this today.

That's fine. The idea is that if you create the branch/PR then we can push to it, whereas if one of us creates it then you can't push to it.

  • Even the simplest operations remove any information about any Variable in a shape, see code below

You are—in part—seeing a lack of "shape inference" at the Op-level (specifically, in Elemwise.make_node), and that's entirely independent of https://github.com/aesara-devs/aesara/discussions/1013. We're always in the process of fixing those issues (when they can be fixed), and, in this exact case, recent updates produce the expected results:

from aesara import tensor as at


x = at.tensor("float64", shape=(2, None), name="x")
y = at.tensor("float64", shape=(2, 3), name="y")

print(x.type.shape)        # (2, None)
print((x + y).type.shape)  # (2, 3)

More importantly, a scalar Variable cannot actually be assigned as a static shape value. The only reason it appears to work is that the run-time type checking in TensorType.__init__ is extremely lax right now. This is just something we need to fix.
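
A stricter run-time check could look roughly like the sketch below. The validate_static_shape helper is hypothetical and not part of Aesara; it only accepts plain Python ints and None, which is an assumption made to keep the example dependency-free.

```python
from typing import Optional, Tuple

StaticShape = Tuple[Optional[int], ...]


def validate_static_shape(shape) -> StaticShape:
    """Return `shape` as a tuple, rejecting anything but ints and None."""
    checked = []
    for i, dim in enumerate(shape):
        # bool is a subclass of int, so it is excluded explicitly.
        if dim is not None and (isinstance(dim, bool) or not isinstance(dim, int)):
            raise TypeError(
                f"shape[{i}] must be an int or None, got {type(dim).__name__}"
            )
        checked.append(dim)
    return tuple(checked)


class FakeScalarVariable:
    """Stand-in for a symbolic scalar, used only to trigger the error."""


assert validate_static_shape([2, None]) == (2, None)

try:
    validate_static_shape((FakeScalarVariable(),))
except TypeError as err:
    message = str(err)
```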

As the thread states, explicit symbolic/non-symbolic scalars can be assigned to a dimension in a shape using specify_shape. Nothing is lost in that approach, and, as far as I can tell, https://github.com/aesara-devs/aesara/discussions/1013#discussion-4170082 is really just about a different/OO interface for this functionality—one with a very particular focus on naming/labeling.

N.B. My comment in that thread links to an example (i.e. https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1201647864) demonstrating how this interacts with our shape inference features, so check that out if you want to see some of what's going on under the hood.

  • HasShape specifies a shape as Tuple[Optional[int], ...], this should be changed to allow (named) scalar integer variables
  • The key (killer) feature of dimensions, for example in Xarray, is their use for smarter broadcasting, the question is how exactly

Just to be clear, we're not planning to add dimension labels to Aesara. The utility—outside of the implied high/user-level OO interface—of such a feature isn't very clear, and all the relevant functionality is already covered by specify_shape.

Aside from that, I'm not aware of any ways in which dimension labels could help with broadcasting in Aesara. If the idea is that it helps with some broadcasting-related things at the user-level, then fine, but we would need to determine how such a feature could be added/emulated without unnecessarily touching Aesara's internals.

Working on this issue will introduce you to a few of the relevant Aesara details, though, such that, if there is a non-intrusive way to add such a feature, you'll likely understand exactly how it could be done, or why it can't be done.

brandonwillard avatar Sep 22 '22 21:09 brandonwillard

Before I forget, here's some related (and interesting) work on Python typing for shapes: https://peps.python.org/pep-0646/

brandonwillard avatar Sep 22 '22 21:09 brandonwillard

Come to think of it, we could start using TypeVarTuple and Unpack from typing_extensions right now.
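
For reference, a minimal sketch of how those two constructs can describe a dtype plus a variable number of per-axis sizes. The Tensor class and the Matrix2x3 alias are hypothetical, and the import falls back to typing_extensions on interpreters whose typing module lacks these names.

```python
from typing import Generic, Literal, TypeVar, get_args, get_origin

try:  # available in the standard library on newer Python versions
    from typing import TypeVarTuple, Unpack
except ImportError:  # older interpreters: use the backport
    from typing_extensions import TypeVarTuple, Unpack

DType = TypeVar("DType")
Shape = TypeVarTuple("Shape")


class Tensor(Generic[DType, Unpack[Shape]]):
    """Hypothetical tensor type parameterized by a dtype and per-axis sizes."""


# A float matrix whose static shape (2, 3) is part of its type.
Matrix2x3 = Tensor[float, Literal[2], Literal[3]]

# The parameters stay available for inspection at run time.
assert get_origin(Matrix2x3) is Tensor
assert get_args(Matrix2x3)[0] is float
```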

brandonwillard avatar Sep 22 '22 21:09 brandonwillard

Working on this issue will introduce you to a few of the relevant Aesara details, though, such that, if there is a non-intrusive way to add such a feature, you'll likely understand exactly how it could be done, or why it can't be done.

That's exactly why I'm interested in working on this.

Come to think of it, we could start using TypeVarTuple and Unpack from typing_extensions right now.

Thanks, I wasn't aware that they were already available in typing_extensions.

markusschmaus avatar Sep 22 '22 22:09 markusschmaus

Xarray broadcasting works based on dimension names not position like Numpy. It would be a fundamental change but could be interesting to think how it could be built on top of Aesara (not necessarily in Aesara itself).

Regarding the comments it completely disambiguates broadcasting, because dimensions are always matched based on name, and broadcasted in the cases they are missing in one of the inputs. So there's never a "runtime" broadcasting you could not infer from the input types.

Would look something like:

x = matrix(dims=("a", "b"))
y = matrix(dims=("b", "c"))
# Under the hood this would require matching dims by name 
# and expanding non overlapping ones
# (at.expand_dims(x, -1) + at.expand_dims(y, 0))
(x + y).dims == ("a", "b", "c")
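
The name matching itself can be sketched in plain Python, independent of Aesara. The broadcast_dims helper below is hypothetical, and its ordering rule (dims in order of first appearance, first input before second) is an assumption. It also assumes the dims an input already has appear in the same relative order as in the output; otherwise a transpose would be needed as well.

```python
from typing import List, Sequence, Tuple


def broadcast_dims(
    x_dims: Sequence[str], y_dims: Sequence[str]
) -> Tuple[Tuple[str, ...], List[int], List[int]]:
    """Match dimensions by name.

    Returns the output dims plus, for each input, the output axes at which a
    length-1 axis would have to be inserted (e.g. with expand_dims).
    """
    # dict.fromkeys removes duplicates while preserving first-seen order.
    out_dims = tuple(dict.fromkeys([*x_dims, *y_dims]))
    x_missing = [i for i, d in enumerate(out_dims) if d not in x_dims]
    y_missing = [i for i, d in enumerate(out_dims) if d not in y_dims]
    return out_dims, x_missing, y_missing


out_dims, x_missing, y_missing = broadcast_dims(("a", "b"), ("b", "c"))
print(out_dims)   # ('a', 'b', 'c')
print(x_missing)  # [2]
print(y_missing)  # [0]
```

Because the result depends only on the dim names, the output dims are known from the input types alone.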

ricardoV94 avatar Sep 23 '22 05:09 ricardoV94

I strongly believe that having some support for dimensions and dimension based broadcasting inside Aesara is a good idea, and I want to be transparent that this is part of my motivation for working on this. But while dimensions are related to types, they are largely independent. The only question I see that is relevant for this work package is to what extent and how we want to support scalar integer variables in shapes, because right now doing so is a type error (HasShape specifies a shape as Tuple[Optional[int], ...]), and if we use the pattern from my demo and dynamically create classes representing these types, we probably need to enforce the type of shape at runtime.

So here is my plan of action:

  1. I need to get my environment fully set up, including jax and jaxlib, and create a draft PR
  2. Turn Type and all its sub-classes into meta-classes, such that their instances become classes themselves, making shape and dimension read-only
  3. Set up a proper class hierarchy for the dynamically created classes, introducing generic abstract base classes; using abc for this could be useful, but is probably not strictly necessary.
  4. Merge Variable and its sub-classes with these base classes, such that type(x) == x.type
  5. Add additional type hints where appropriate

After each of these steps I want Aesara to be in a valid state with all tests passing.
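
Steps 2 and 4 could be sketched roughly as follows. This is a toy model under stated assumptions, not Aesara code: TensorTypeMeta, the dtype/shape class keywords, and Float64Matrix are all hypothetical names.

```python
class TensorTypeMeta(type):
    """Hypothetical metaclass: its instances are classes carrying static info."""

    def __new__(mcls, name, bases, namespace, dtype=None, shape=None):
        cls = super().__new__(mcls, name, bases, namespace)
        cls._dtype = dtype
        cls._shape = shape
        return cls

    def __init__(cls, name, bases, namespace, dtype=None, shape=None):
        # Swallow the extra keywords before delegating to type.__init__.
        super().__init__(name, bases, namespace)

    # Properties defined on the metaclass have no setter, so they are
    # read-only when accessed on the generated classes.
    @property
    def dtype(cls):
        return cls._dtype

    @property
    def shape(cls):
        return cls._shape


class Variable:
    """Toy stand-in for Aesara's Variable."""

    @property
    def type(self):
        return type(self)


class Float64Matrix(
    Variable, metaclass=TensorTypeMeta, dtype="float64", shape=(None, None)
):
    pass


x = Float64Matrix()

assert type(x) == x.type
assert x.type.dtype == "float64"
assert x.type.shape == (None, None)

try:
    Float64Matrix.dtype = "int8"
    read_only = False
except AttributeError:
    read_only = True
assert read_only
```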

markusschmaus avatar Sep 23 '22 07:09 markusschmaus

scalar integer variables in shapes

What do you mean?

I strongly believe that having some support for dimensions and dimension based broadcasting inside Aesara is a good idea

Aesara tries to emulate Numpy API, so it would be a conflict. I think the best would be if you could create a subclass of TensorType like DimTensorType and create specific look-alike Ops for those types that follow dimension broadcasting rules. It's a good challenge for the flexibility of Aesara, but I do not think it should be a feature in Aesara itself, if only for the extra work of maintaining two whole API compatibilities (Numpy and Xarray).

These Ops can be thin wrappers around the Aesara ones that reuse the same machinery but shuffle dims around and create output types with the specified dims.

By the way @aseyboldt has also put some thought into this, as he is also a fan of Xarray dims.

ricardoV94 avatar Sep 23 '22 08:09 ricardoV94

scalar integer variables in shapes

What do you mean?

I mean this:

country_dim = at.iscalar("country_dim")  # or a constant, or a shared variable
x = at.specify_shape(at.vector("x"), shape=(country_dim,))

This is allowed according to the type hints of specify_shape, but it hands the shape over to the untyped TensorType.clone, which then hands it over to TensorType.__init__, which according to its type hints doesn't allow Variable, resulting in an x.type.shape that doesn't conform to the type Tuple[Optional[int], ...] specified in HasShape.

One obvious way to fix this is of course to change the type hints in TensorType.__init__ and HasShape to allow for Variable. However, this introduces some complexities in step 2 of my plan of action, since typing.Literal only works for "bools, ints, strs, bytes, and enum values", but not for objects of class Variable.

With the original idea of using typing_inspect this would be really hard to solve, but with the dynamic sub-class model this might not be such a big deal.

I will stop thinking about this for now and revisit this issue when I get to step 2.

markusschmaus avatar Sep 23 '22 10:09 markusschmaus

Here is the PR https://github.com/aesara-devs/aesara/pull/1207

markusschmaus avatar Sep 23 '22 12:09 markusschmaus

I strongly believe that having some support for dimensions and dimension based broadcasting inside Aesara is a good idea, and I want to be transparent that this is part of my motivation for working on this. But while dimensions are related to types, they are largely independent. The only question I see that is relevant for this work package is to what extent and how we want to support scalar integer variables in shapes, because right now doing so is a type error (HasShape specifies a shape as Tuple[Optional[int], ...]), and if we use the pattern from my demo and dynamically create classes representing these types, we probably need to enforce the type of shape at runtime.

@markusschmaus I totally agree that this would be great to have, I've wanted this for quite some time now. I've started working on a little proof of concept implementation of this. If you want to share ideas, we could maybe meet some time next week or so and I can show you what I have so far?

aseyboldt avatar Sep 23 '22 15:09 aseyboldt

FYI: This issue is not related to dimension labeling or XArray's particular style of broadcasting. @ricardoV94's statements above reflect my position on this matter, as well; otherwise, let's keep the discussions to their respective issues/Discussions.

brandonwillard avatar Sep 23 '22 15:09 brandonwillard