
Generic ParamSpec in subclass definitions

Open alanhdu opened this issue 2 years ago • 2 comments

Given some base class that is generic over a param-spec, I'd like to be able to define the param-spec using a subclass method implementation. Something like:

from typing import Callable, Generic, ParamSpec, Tuple, TypeVar

T = TypeVar("T")
P = ParamSpec("P")

class Base(Generic[P, T]):
    func: Callable[P, T]
    # for instance
    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Tuple[T]:
        return (self.func(*args, **kwargs), )

class Subclass(Base):  # goal: have P and T inferred from func below
    def func(self, x: int, *, y: str) -> bytes: ...

That is, the base class is generic over some function definition, and the subclass implements that function as a method. I'm not sure how common this pattern is, but I see it a fair amount when the class gets more complicated and you can't just use a decorator (torch.nn.Module.forward is a big example of this). Is there some other way of specifying "infer the parameters from this method implementation"? Would this require an update to the specification, or is it "just" a feature request to treat the method implementation like an assignment statement?

alanhdu, May 25 '23 01:05

As with any type parameter, you can explicitly specify the type arguments for P and T in the Subclass class definition. However, the syntax for ParamSpec allows only positional-only parameters in the specialization.

from typing import Generic, ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")

class Base(Generic[P, T]):
    def func(self, *args: P.args, **kwargs: P.kwargs) -> T:
        ...

    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> tuple[T]:
        return (self.func(*args, **kwargs),)

class Subclass(Base[[int, str], bytes]):
    def func(self, x: int, y: str, /) -> bytes:
        ...

There isn't a way to specialize the ParamSpec with a signature that includes keyword arguments, as in your code sample. You could instead specify ..., which is the ParamSpec equivalent of Any:

class Subclass(Base[..., bytes]):
    def func(self, x: int, *, y: str) -> bytes:
        ...

erictraut, May 25 '23 03:05

There isn't a way to specialize the ParamSpec with a signature that includes keyword arguments

I've also often wished for a way to specify keyword arguments in a concrete ParamSpec, in particular for torch.nn.Module.forward:

from abc import abstractmethod
from typing import Generic, Optional, TypeVar, final
from typing_extensions import ParamSpec

import torch
from torch import Tensor, nn

P = ParamSpec("P")
T = TypeVar("T", covariant=True)


class BaseModule(nn.Module, Generic[P, T]):
    @abstractmethod
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
        raise NotImplementedError()

    @final
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # do other stuff
        return self.forward(*args, **kwargs)


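class Linear(BaseModule[[Tensor], Tensor]):
    """Layer with simple signature."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        # Parameters referenced by forward (shapes are illustrative).
        self.weights = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, input: Tensor) -> Tensor:
        return input @ self.weights + self.bias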


class Transformer(BaseModule[[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tensor]):
    """Complicated signature with optional arguments."""

    def forward(
        self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None
    ) -> Tensor:
        ...


tf = Transformer()
tf(torch.tensor([1, 1]), torch.tensor([1, 1]))  # type error: the explicit ParamSpec requires all four arguments, its defaults are lost
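The type error arises because the list [Tensor, Tensor, Optional[Tensor], Optional[Tensor]] keeps only the four types, while forward's real signature also carries parameter names and defaults. A torch-free sketch using inspect.signature to show what the explicit specialization drops; float stands in for Tensor as an assumption so the snippet runs without PyTorch, and the forward body is hypothetical:

```python
import inspect
from typing import Optional


class Transformer:
    # float stands in for torch.Tensor so the sketch runs without PyTorch.
    def forward(
        self,
        src: float,
        tgt: float,
        src_mask: Optional[float] = None,
        tgt_mask: Optional[float] = None,
    ) -> float:
        # Hypothetical body, only for illustration.
        return src + tgt


# Skip self: the ParamSpec describes only the remaining parameters.
params = list(inspect.signature(Transformer.forward).parameters.values())[1:]

# What the explicit list specialization keeps: the annotations, in order.
as_list = [p.annotation for p in params]

# What it cannot express: the names, and which parameters have defaults.
names = [p.name for p in params]
with_defaults = [p.name for p in params if p.default is not inspect.Parameter.empty]

print(names)  # ['src', 'tgt', 'src_mask', 'tgt_mask']
print(with_defaults)  # ['src_mask', 'tgt_mask']
```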

Ideally, I could just paste the signature in place of the ParamSpec (proposed syntax, not currently valid Python):

class Transformer(BaseModule[(src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = ..., tgt_mask: Optional[Tensor] = ...), Tensor]): ...

tmke8, Jun 16 '23 07:06