
Generic specialization?

Open • vnmabus opened this issue 3 years ago • 2 comments

I was wondering whether it is possible to redefine what a particular generic means for a particular type (and perhaps for its subclasses or superclasses, depending on covariance/contravariance).

My use case comes from the scikit-learn API. In this API, each "estimator" class has two sets of attributes: those that are passed to __init__ and those that are computed after calling .fit() to fit the model. By convention, the latter end with a trailing underscore. In addition, the fit method returns self.

There are also methods, such as predict, that may only be called after fitting.

Currently, code that uses this library looks like this:

```python
my_estimator = MyEstimator(param1, param2)
my_estimator.fit(X_train, y_train)

# Now it is safe to access fitted attributes and call predict, score, etc.
print(my_estimator.fitted_attr1_)
print(my_estimator.predict(X_test))
```
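To make the convention concrete, here is a minimal sketch of an estimator that follows it. The class name MeanRegressor and the parameter shift are hypothetical, chosen only for illustration; the fitted check raises a plain RuntimeError, whereas scikit-learn itself uses its own NotFittedError for this.

```python
from typing import List, Sequence


class MeanRegressor:
    """Hypothetical estimator that predicts the mean of the training targets.

    Constructor parameters are stored unchanged, fitted attributes end
    with an underscore, and fit() returns self.
    """

    def __init__(self, shift: float = 0.0) -> None:
        # Constructor parameter: stored as-is, no computation here.
        self.shift = shift

    def fit(self, X: Sequence[Sequence[float]], y: Sequence[float]) -> "MeanRegressor":
        # Fitted attribute: only exists after fit() has been called.
        self.mean_ = sum(y) / len(y) + self.shift
        return self

    def predict(self, X: Sequence[Sequence[float]]) -> List[float]:
        # Only valid after fitting; this is checked at runtime, not by a type checker.
        if not hasattr(self, "mean_"):
            raise RuntimeError("This MeanRegressor instance is not fitted yet.")
        return [self.mean_ for _ in X]


est = MeanRegressor(shift=1.0)
est.fit([[0.0], [1.0]], [2.0, 4.0])
print(est.predict([[5.0], [6.0]]))  # [4.0, 4.0]
```

Calling predict before fit is only caught when the code runs; nothing in the annotations tells a type checker about the fitted state, which is the gap this issue is about.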

The idea was to allow mypy (or another type checker) to check these invariants using additional types. Instead of typing fit as:

```python
def fit(self, X: ..., y: ...) -> Self
```

we could type it as

```python
def fit(self, X: ..., y: ...) -> Fitted[Self]
```

We would then need a way to:

  • Define that Fitted[T] is a subclass of T (similar to #802).
  • Define the specific fitted attributes of Fitted[T] for a particular T.
  • Define that some methods, such as predict, can only be called on a Fitted[T] object, not on a plain T object.
  • Define that Fitted[Fitted[T]] == Fitted[T].

Then, only a small change would be needed in the previous code to allow type checkers to detect whether the invariants have been broken:

```python
my_estimator = MyEstimator(param1, param2)
my_estimator = my_estimator.fit(X_train, y_train)  # Line changed

# Now it is safe to access fitted attributes and call predict, score, etc.
print(my_estimator.fitted_attr1_)
print(my_estimator.predict(X_test))
```

This is only a possibility. Alternatives include:

  • Defining a subclass just for the type checker inside an if TYPE_CHECKING: block. This works for the basic usage illustrated here, but not in more generic cases, e.g. typing a function that accepts a fitted estimator of any type. It also creates a parallel class hierarchy that every subclass of the estimator has to mirror.
  • Simply typing the class as a whole, without letting type checkers verify these invariants.
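For comparison, the first alternative might look roughly like the sketch below. MyEstimator, FittedMyEstimator, param1 and fitted_attr1_ are hypothetical names, and this is only one possible way to arrange the checker-only class; it has not been verified against every type checker.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, List, Sequence, cast


class MyEstimator:
    def __init__(self, param1: float) -> None:
        self.param1 = param1

    def fit(self, X: Sequence[float], y: Sequence[float]) -> FittedMyEstimator:
        # Same object at runtime; the cast only changes how the checker sees it.
        fitted = cast("FittedMyEstimator", self)
        fitted.fitted_attr1_ = sum(y) / len(y)
        return fitted

    if not TYPE_CHECKING:
        # Defined for runtime use only, so checkers do not see predict()
        # on an unfitted MyEstimator.
        def predict(self, X):
            return [self.fitted_attr1_ + self.param1 for _ in X]


if TYPE_CHECKING:
    # Checker-only "fitted view" of MyEstimator: declares the fitted
    # attributes and the methods that require fitting.
    class FittedMyEstimator(MyEstimator):
        fitted_attr1_: float

        def predict(self, X: Sequence[float]) -> List[float]:
            raise NotImplementedError  # never executed; type information only
```

This shows the costs mentioned in the list: every estimator subclass needs its own Fitted… counterpart, and there is no single type that means "any fitted estimator".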

However, I think that adding this flexibility to the type system could help in other cases as well.

vnmabus • Sep 01 '22 05:09

Hmm, it's possible to get very close to this today.

```python
Fitted = TypeVar('Fitted', bound=bool)

class Estimator(Generic[Fitted]):
    def __new__(cls) -> Estimator[False]:
        ...

    def fit(self, X: ..., y: ...) -> Estimator[True]:
        ...

    def predict(self: Estimator[True], X: ...) -> ...:
        ...
```

This code works today and is valid in the current type system. The one missing piece is that here I wrote Estimator[True]/Estimator[False]. Ideally you want Self[True] and Self[False], but that requires higher-kinded types and is related to this comment.

hmc-cs-mdrissi • Sep 03 '22 17:09

@hmc-cs-mdrissi, that's a great idea, though both pyright and mypy partially disagree with its validity. I believe the correct way would be to use Literal[True] and Literal[False]. What do you think?

zmievsa • Dec 17 '22 11:12
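For reference, a minimal sketch of the TypeVar-flag approach with the Literal[True]/Literal[False] suggestion applied. The names IsFitted, shift and mean_ are hypothetical. Pinning the unfitted state through an annotated self on __init__ (instead of __new__) and the cast in fit are choices made for this sketch, not part of either comment; checkers such as mypy and pyright are expected to reject predict on an unfitted estimator, but exact behaviour may vary between versions.

```python
from __future__ import annotations

from typing import Generic, List, Literal, Sequence, TypeVar, cast

# Type-level flag recording whether the estimator has been fitted.
# Only Literal[False] / Literal[True] are meant to be used as arguments.
IsFitted = TypeVar("IsFitted", bound=bool)


class Estimator(Generic[IsFitted]):
    mean_: float  # fitted attribute, set by fit()

    def __init__(self: Estimator[Literal[False]], shift: float = 0.0) -> None:
        # Annotating self pins a freshly constructed estimator to the unfitted state.
        self.shift = shift

    def fit(self, X: Sequence[float], y: Sequence[float]) -> Estimator[Literal[True]]:
        self.mean_ = sum(y) / len(y) + self.shift
        # Same object at runtime; the cast only moves it to the fitted state.
        return cast("Estimator[Literal[True]]", self)

    def predict(self: Estimator[Literal[True]], X: Sequence[float]) -> List[float]:
        # The self annotation lets checkers reject calls on unfitted estimators.
        return [self.mean_ for _ in X]


model = Estimator(shift=1.0).fit([0.0, 1.0], [2.0, 4.0])
print(model.predict([5.0, 6.0]))  # [4.0, 4.0]
# Estimator().predict([5.0])  # a type checker should flag this call
```

The limitation from the earlier comment still applies: fit on a subclass of Estimator returns Estimator[Literal[True]], not the subclass, because there is no way to write Self[Literal[True]] without higher-kinded types.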