Enumeration does not work with mypy
```python
import equinox as eqx

class Error(eqx.Enumeration):
    error = 'error'
    success = 'success'

def f(x: Error) -> Error:
    return x

f(Error.error)
```
gives, under mypy:
```
<cell>11: error: Argument 1 to "f" has incompatible type "str"; expected "Error"  [arg-type]
```
Hmm, this is unfortunate. So based on a bit of digging, it seems that there are (at least) two distinct ways to declare the type of the members of an enum.Enum:
(a) subclass that type at the same time, e.g. class MyEnum(SomeClass, enum.Enum) (the mixin type has to come before enum.Enum in the bases)
(b) add a type annotation to enum.Enum.__class__.__getitem__
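A quick standard-library check of (a), using str as a stand-in for SomeClass (an assumption for illustration; Equinox itself is not involved here):

```python
import enum

# (a): mix the member type into the enum. The mixin must come *before*
# enum.Enum in the bases.
class MyEnum(str, enum.Enum):
    x = "x"
    y = "y"

# Every member is an instance of both the enum and the mixed-in type.
assert isinstance(MyEnum.x, MyEnum)
assert isinstance(MyEnum.x, str)
assert MyEnum.x == "x"  # compares equal to a plain str

# Reversing the bases is rejected when the class is created.
try:
    class Reversed(enum.Enum, str):
        x = "x"
except TypeError:
    reversed_rejected = True
else:
    reversed_rejected = False
```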
Whilst the first is the common, well-advertised one, the latter seems to exist because MyEnum["x"] is equivalent to MyEnum.x. I had never actually realised that before now; enumerations are crammed full of equivalent ways of manipulating them.
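For reference, the equivalence is easy to check with the standard library. MyEnum["x"] goes through the metaclass's __getitem__, which is the method (b) refers to; the member name and value below are arbitrary:

```python
import enum

class MyEnum(enum.Enum):
    x = "x-value"

# Three equivalent ways to get the same member object:
by_attribute = MyEnum.x       # attribute access
by_name = MyEnum["x"]         # metaclass __getitem__: lookup by *name*
by_value = MyEnum("x-value")  # call: lookup by *value*

assert by_attribute is by_name is by_value
```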
Unfortunately, we separately allow the [...] syntax for looking up the string message associated with a member of an Equinox enumeration. (That syntax already existed on the original diffrax.RESULTS object, which was the inspiration for eqx.Enumeration before the latter existed.)
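To make the overlap concrete, here is a minimal sketch built on the standard enum module. The metaclass _MessageEnumMeta is hypothetical and is not Equinox's actual implementation; it just shows one __getitem__ serving both meanings, so any return annotation on that method has to cover both:

```python
import enum

class _MessageEnumMeta(enum.EnumMeta):
    """Hypothetical metaclass for illustration; NOT Equinox's implementation."""

    def __getitem__(cls, item):
        # Indexing with a member returns its message string...
        if isinstance(item, cls):
            return item.value
        # ...while indexing with a name keeps the standard Enum lookup.
        return super().__getitem__(item)

class Error(enum.Enum, metaclass=_MessageEnumMeta):
    error = "Something went wrong"
    success = ""

assert Error["error"] is Error.error                 # name -> member
assert Error[Error.error] == "Something went wrong"  # member -> message
```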
I suppose this never came up before because we use pyright, which presumably lets (a) take precedence over (b), whereas mypy seems to let (b) take precedence over (a).
I imagine we could resolve this by removing this line, at the cost of our actual [...] syntax no longer being compatible with static type checking. That's not a super common API, so I suppose that's acceptable. (Alternatively, it would be nice if mypy didn't let the more-obscure (b) override the more-common (a), but they're not technically wrong either way.)
WDYT?
I think making sure the most common case type-checks correctly is the right approach. How much pain would it cause users to remove this [...] syntax?
So I think removing this at runtime would be a somewhat annoying backward incompatibility, as it has been around for several years and is the only syntax eqx.Enumeration offers for the feature in question.
But what I think we could probably do is:
- keep the current API at runtime, but remove it from static-type-checking;
- add a new API for this functionality.
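One way the two steps above could look, sketched with the standard enum module. This is a hypothetical illustration, not the code in #1122: the metaclass and the message_of method are invented names. The runtime [...] lookup is defined under `if not TYPE_CHECKING`, so type checkers only see the standard EnumMeta.__getitem__, while a separately typed method carries the message lookup:

```python
import enum
from typing import TYPE_CHECKING

class _MessageEnumMeta(enum.EnumMeta):
    """Hypothetical sketch; not the implementation in #1122."""

    if not TYPE_CHECKING:
        # Runtime-only: static type checkers skip this block, so they see
        # the standard EnumMeta.__getitem__ (name -> member) signature.
        def __getitem__(cls, item):
            if isinstance(item, cls):
                return item.value
            return super().__getitem__(item)

    # Hypothetical new, statically typed API for the message lookup.
    def message_of(cls, item: enum.Enum) -> str:
        return item.value

class Error(enum.Enum, metaclass=_MessageEnumMeta):
    error = "Something went wrong"
    success = ""

assert Error[Error.error] == "Something went wrong"             # old syntax
assert Error.message_of(Error.error) == "Something went wrong"  # new API
```

The old syntax keeps working at runtime; only its static visibility changes, which is the trade-off described above.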
I've drafted something at #1122. Lmk how that looks / whether it solves your use-case? (I have not explicitly tested with mypy, as I'm not confident I can accurately reproduce your setup.)
I will take a look a bit later, but here's a simple repro Colab: https://colab.research.google.com/drive/1Xtl1kaP034VGJqUMed3sCezW9Zpc5TI5?usp=sharing