
Bug: typing issue due to `__getattribute__`

Open • nstarman opened this issue 1 year ago • 4 comments

Discovered in https://github.com/GalacticDynamics/galax/pull/377: when jaxtyping's run-time type-checking is turned on, `Module.__getattribute__` (defined on the metaclass `_ModuleMeta`, per the traceback) is not set up to allow modules to be `Generic`.

The traceback looks like:

```
../../python3.11/typing.py:1834: in __class_getitem__
    for param in cls.__parameters__:
        cls        = <class 'ParametricClass'>
        params     = (~T,)
../../python3.11/site-packages/equinox/_module.py:582: in __getattribute__
    value = super().__getattribute__(item)
E   AttributeError: type object 'ParametricClass' has no attribute '__parameters__'
        __class__  = <class 'equinox._module._ModuleMeta'>
        cls        = <class 'ParametricClass'>
        item       = '__parameters__'
```

I think `__parameters__` might need to be special-cased.
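A minimal, standard-library-only sketch of what such special-casing could look like. `SketchMeta`, `Box`, and `custom_lookups` are hypothetical names, and the custom branch is a stand-in, not Equinox's actual `_ModuleMeta.__getattribute__`: lookups of `__parameters__` go straight to `type.__getattribute__`, so whatever extra logic the metaclass applies never sees them.

```python
from typing import Generic, TypeVar

# Names of attributes that reached the custom (non-special-cased) branch.
custom_lookups: list[str] = []


class SketchMeta(type):
    """Hypothetical metaclass with a custom __getattribute__ (not Equinox's code)."""

    def __getattribute__(cls, item):
        if item == "__parameters__":
            # Special case: read typing's bookkeeping attribute with plain
            # `type` semantics, bypassing the custom branch below.
            return type.__getattribute__(cls, item)
        custom_lookups.append(item)
        # Stand-in for any extra processing a real metaclass might apply.
        return super().__getattribute__(item)


T = TypeVar("T")


class Box(Generic[T], metaclass=SketchMeta):
    pass


# Subscripting runs typing's __class_getitem__, which reads Box.__parameters__.
alias = Box[int]
```

With the special case in place, subscription succeeds and `Box.__parameters__` is `(T,)` without ever passing through the custom branch.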

nstarman • Jul 04 '24 17:07

Do you have a MWE?

(For what it's worth I use generics successfully with Equinox elsewhere.)

patrick-kidger • Jul 04 '24 17:07

> Do you have a MWE?

I'll try to make one.

> (For what it's worth I use generics successfully with Equinox elsewhere.)

Do you have jaxtyping + beartype on? Beartype appears in the traceback when it calls `typing._generic_class_getitem`, which is where the failed `__parameters__` attribute lookup originates.
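To see that this lookup happens inside `typing`'s own subscription machinery, here is a stdlib-only sketch that records every attribute looked up on a `Generic` class through its metaclass. `TracingMeta` and `seen` are hypothetical names, not part of Equinox or beartype; a type checker only triggers the subscription, while the `__parameters__` read itself comes from `typing`.

```python
from typing import Generic, TypeVar

# Every attribute name looked up on the class through its metaclass.
seen: list[str] = []


class TracingMeta(type):
    """Hypothetical metaclass that only records class-attribute lookups."""

    def __getattribute__(cls, item):
        seen.append(item)
        return super().__getattribute__(item)


T = TypeVar("T")


class Parametric(Generic[T], metaclass=TracingMeta):
    pass


seen.clear()  # discard lookups made while the class was being created
alias = Parametric[int]  # the same kind of subscription an annotation performs
```

After the subscription, `seen` contains `"__parameters__"`, so any metaclass `__getattribute__` that fails on that name breaks `Parametric[...]` regardless of which checker asked for it.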

nstarman • Jul 08 '24 16:07

Cheers! FWIW I do often also combine Equinox + jaxtyping + beartype. Admittedly that is now a more complicated stack (beartype especially), so I'm definitely willing to believe something goes wrong :D

patrick-kidger • Jul 08 '24 21:07

Hmm. It's challenging to reproduce the failure I'm seeing in https://github.com/GalacticDynamics/galax/pull/377. The obvious minimal example doesn't raise the same error.

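```python
from typing import Generic, TypeVar

import equinox as eqx
from beartype import beartype as typechecker
from jaxtyping import jaxtyped

# Alternatively, check with typeguard instead of beartype:
# from typeguard import typechecked as typechecker

T = TypeVar("T")


@jaxtyped(typechecker=typechecker)
class Parametric(eqx.Module, Generic[T]):
    value: T


# Subscripting `Parametric[T]` in these annotations goes through
# `Generic.__class_getitem__`, which reads `Parametric.__parameters__`.
@jaxtyped(typechecker=typechecker)
def function(parametric: Parametric[T]) -> Parametric[T]:
    return parametric
```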

nstarman • Jul 10 '24 18:07