orbax
Incorrect / inconvenient type annotations inside orbax.checkpoint.checkpoint_args.register_with_handler
Example code from the documentation of the form
mngr.save(0, args=ocp.args.PyTreeSave(tree))
produces the type error "Expected 0 positional arguments" (for the call to PyTreeSave) in VSCodium with Pyright.
After experimenting with variants of the code, I'm fairly convinced this is caused by the following type annotation in orbax.checkpoint.checkpoint_args.register_with_handler (which PyTreeSave is decorated with):
def register_with_handler(
    handler_cls: Type[CheckpointHandler],
    for_save: bool = False,
    for_restore: bool = False,
):
    ...
    def decorator(cls: Type[CheckpointArgs]):  # <-- here!
        ...
        return cls
    return decorator
Because Pyright infers the missing return type from `return cls`, I believe that, as far as it is concerned, this is equivalent to:
[...]
    def decorator(cls: Type[CheckpointArgs]) -> Type[CheckpointArgs]:
        ...
        return cls
[...]
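So the decorated class is seen as CheckpointArgs, when in fact it is the subclass of CheckpointArgs that was passed in. Since CheckpointArgs itself takes no constructor arguments, Pyright concludes that PyTreeSave(tree) passes one positional argument too many. My impression is that this can be fixed either by omitting the type annotation entirely, or by annotating it with a bound TypeVar as follows: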
from typing import Type, TypeVar

CA = TypeVar('CA', bound=CheckpointArgs)

def register_with_handler(
    handler_cls: Type[CheckpointHandler],
    for_save: bool = False,
    for_restore: bool = False,
):
    ...
    def decorator(cls: Type[CA]) -> Type[CA]:
        ...
        return cls
    return decorator