Generic specialization?
I was wondering whether it is possible to redefine what a particular generic means for a particular type (and possibly for its subclasses or superclasses, depending on covariance/contravariance).
My use case comes from the scikit-learn API. In this API, each "estimator" class has two sets of attributes: those that are passed to __init__ and those that are computed by calling .fit() to fit the model. By convention, the latter end with an underscore.
Also, the fit method returns self.
In addition, some methods, like predict, are only allowed after fitting.
Currently, code that uses this library looks like this:
my_estimator = MyEstimator(param1, param2)
my_estimator.fit(X_train, y_train)
# Now it is safe to access fit attributes and call predict, score, etc.
print(my_estimator.fitted_attr1_)
print(my_estimator.predict(X_test))
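To make the convention concrete, here is a minimal sketch of a hypothetical MyEstimator (a toy invented for illustration, not a scikit-learn class): the __init__ parameter has no underscore, the fit attribute mean_ does, fit returns self, and predict refuses to run before fitting. The self-typed TypeVar _E is an assumption standing in for -> Self on Pythons older than 3.11. Note that the guard in predict only fires at runtime, which is exactly the check this proposal wants to move to the type checker.

```python
from __future__ import annotations

from typing import TypeVar

# Hypothetical self-type TypeVar, equivalent to `-> Self` on Python 3.11+.
_E = TypeVar("_E", bound="MyEstimator")


class MyEstimator:
    """Hypothetical estimator following the scikit-learn naming conventions."""

    mean_: float  # fit attribute: trailing underscore, only set by fit()

    def __init__(self, shift: float = 0.0) -> None:
        self.shift = shift  # __init__ parameter: no trailing underscore

    def fit(self: _E, X: list[list[float]], y: list[float]) -> _E:
        self.mean_ = sum(y) / len(y) + self.shift
        return self  # fit returns self, so calls can be chained

    def predict(self, X: list[list[float]]) -> list[float]:
        # Runtime-only guard: a type checker happily accepts predict()
        # on an estimator that was never fitted.
        if not hasattr(self, "mean_"):
            raise RuntimeError("MyEstimator is not fitted yet; call fit() first")
        return [self.mean_ for _ in X]
```

With this sketch, MyEstimator(shift=1.0).fit([[0.0], [1.0]], [2.0, 4.0]).predict([[0.0]]) returns [4.0], while MyEstimator().predict([[0.0]]) raises RuntimeError only when it is executed.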
The idea was to let Mypy (or another static analyzer) check these invariants using additional types. Instead of typing fit as:
def fit(self, X: ..., y: ...) -> Self
we could type it as
def fit(self, X: ..., y: ...) -> Fitted[Self]
We would then need a way to:
- Define that Fitted[T] is a subclass of T (similar to #802).
- Define the particular fit attributes of Fitted[T] for a particular T.
- Define that some methods, such as predict, can only be used with a Fitted[T] object, and not with a T object.
- Define that Fitted[Fitted[T]] == Fitted[T].
Then, only a small change would be needed in the previous code to allow type checkers to detect whether the invariants have been broken:
my_estimator = MyEstimator(param1, param2)
my_estimator = my_estimator.fit(X_train, y_train) # Line changed
# Now it is safe to access fit attributes and call predict, score, etc.
print(my_estimator.fitted_attr1_)
print(my_estimator.predict(X_test))
This is only a possibility. Alternatives include:
- Defining a subclass just for the type checker inside an if TYPE_CHECKING: block. This works for the basic usage illustrated here, but not in more generic cases, e.g. typing a function that accepts a fitted estimator of any type. It also creates a parallel class hierarchy, which every subclass must mirror as well.
- Just typing the whole class as usual and not letting type checkers verify these invariants.
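A minimal sketch of the if TYPE_CHECKING: alternative, using hypothetical MyEstimator and FittedMyEstimator classes (neither exists in scikit-learn). FittedMyEstimator is only ever seen by the type checker, and predict is hidden from the checker on the unfitted class. This relies on the assumption that mypy and pyright treat TYPE_CHECKING as True and therefore skip the if not TYPE_CHECKING: branch inside the class body; I have not verified that against every checker version.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, cast


class MyEstimator:
    """Hypothetical estimator; checkers only see predict() on the fitted type."""

    def __init__(self, shift: float = 0.0) -> None:
        self.shift = shift

    def fit(self, X: list[list[float]], y: list[float]) -> FittedMyEstimator:
        # Same object at runtime; the cast only changes the static type,
        # so mean_ is declared on FittedMyEstimator, not on MyEstimator.
        fitted = cast("FittedMyEstimator", self)
        fitted.mean_ = sum(y) / len(y) + self.shift
        return fitted

    if not TYPE_CHECKING:
        # Runtime implementation; assumed to be skipped by type checkers.
        def predict(self, X):
            return [self.mean_ for _ in X]


if TYPE_CHECKING:
    # Parallel class that exists only for the type checker: it declares
    # the fit attributes and the methods allowed after fitting.
    class FittedMyEstimator(MyEstimator):
        mean_: float

        def predict(self, X: list[list[float]]) -> list[float]: ...
```

The drawbacks mentioned above are visible here: FittedMyEstimator cannot be used at runtime (e.g. in isinstance), every estimator subclass needs its own Fitted... twin, and there is no single type meaning "a fitted estimator of any kind".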
However, I think that adding this flexibility to the type system could also help in other cases.
Hmm, it's possible to get very close to this today.
from __future__ import annotations
from typing import Generic, TypeVar

Fitted = TypeVar('Fitted', bound=bool)

class Estimator(Generic[Fitted]):
    def __new__(cls) -> Estimator[False]:
        ...
    def fit(self, X: ..., y: ...) -> Estimator[True]:
        ...
    def predict(self: Estimator[True], X: ...) -> ...:
        ...
This code works today and is valid in the current type system. The one missing piece is that here I wrote Estimator[True]/Estimator[False]. Ideally you would want Self[True] and Self[False], but that is related to Higher Kinded Types and this comment.
@hmc-cs-mdrissi that's a great idea, though both pyright and mypy partially disagree with its validity. I believe the correct way would be to use Literal[True] and Literal[False]. What do you think?
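Following the Literal[True]/Literal[False] suggestion, here is a minimal sketch of the flag-based approach applied to a hypothetical MeanEstimator (a toy model invented for illustration, not a scikit-learn class). The casts are needed because the runtime object never changes; only its static type does. I have not checked this exact snippet against every mypy/pyright version, so the static behaviour described in the comments is expected rather than verified.

```python
from __future__ import annotations

from typing import Generic, Literal, TypeVar, cast

# Type-level flag: Literal[False] = not fitted yet, Literal[True] = fitted.
FittedT = TypeVar("FittedT", bound=bool)


class MeanEstimator(Generic[FittedT]):
    """Hypothetical toy estimator: predicts the mean of the training targets."""

    mean_: float  # fit attribute, set by fit()

    def __new__(cls) -> MeanEstimator[Literal[False]]:
        # A freshly constructed estimator is statically "not fitted".
        return cast("MeanEstimator[Literal[False]]", super().__new__(cls))

    def fit(self, X: list[list[float]], y: list[float]) -> MeanEstimator[Literal[True]]:
        self.mean_ = sum(y) / len(y)
        # Same object at runtime; only the static type changes.
        return cast("MeanEstimator[Literal[True]]", self)

    def predict(self: MeanEstimator[Literal[True]], X: list[list[float]]) -> list[float]:
        # Checkers should reject this call on MeanEstimator[Literal[False]].
        return [self.mean_ for _ in X]


est = MeanEstimator()                             # MeanEstimator[Literal[False]]
fitted = est.fit([[0.0], [1.0]], [2.0, 4.0])     # MeanEstimator[Literal[True]]
predictions = fitted.predict([[5.0], [6.0]])      # OK for the checker
```

Because est and fitted are the same object, est.predict(...) would still run at runtime; the protection is purely static. A function that should only accept fitted estimators can be annotated with MeanEstimator[Literal[True]], but, as the comment about Self[True] points out, this does not yet generalize to "any fitted estimator" without higher-kinded types.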