Incorrect / inconvenient type annotations inside orbax.checkpoint.checkpoint_args.register_with_handler

Open · manulari opened this issue 9 months ago · 0 comments

Example code from the documentation, of the form

```python
mngr.save(0, args=ocp.args.PyTreeSave(tree))
```

triggers a type error, "Expected 0 positional arguments" (on the call to PyTreeSave), in VSCodium with Pyright.

After experimenting with variants of the code, I'm fairly convinced this is caused by the following type annotation in orbax.checkpoint.checkpoint_args.register_with_handler, which decorates PyTreeSave:

```python
def register_with_handler(
    handler_cls: Type[CheckpointHandler],
    for_save: bool = False,
    for_restore: bool = False,
):
  ...
  def decorator(cls: Type[CheckpointArgs]): # <-- here!
    ...
    return cls

  return decorator
```

I think that Pyright infers the decorator's return type from `return cls`, so in its eyes the decorator is equivalent to:

```python
# [...]
def decorator(cls: Type[CheckpointArgs]) -> Type[CheckpointArgs]:
  ...
  return cls
# [...]
```
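The same pattern can be reproduced without orbax. The following is a minimal sketch in which `Handler`, `Base`, `PairArgs`, and `register_untyped_return` are hypothetical stand-ins for CheckpointHandler, CheckpointArgs, PyTreeSave, and register_with_handler; only the shape of the annotations is meant to match. At runtime the decorator hands back the very class it received, so construction succeeds; the complaint comes only from static checking.

```python
import dataclasses
from typing import Any, Type


class Handler:
    """Hypothetical stand-in for CheckpointHandler."""


@dataclasses.dataclass
class Base:
    """Hypothetical stand-in for CheckpointArgs: declares no fields."""


def register_untyped_return(handler_cls: Type[Handler]):
    # handler_cls would be recorded in a registry in real code; omitted here.
    # The parameter is annotated as Type[Base] and the return type is left to
    # inference, so a static checker sees the decorated class as Type[Base].
    def decorator(cls: Type[Base]):
        return cls

    return decorator


@register_untyped_return(Handler)
@dataclasses.dataclass
class PairArgs(Base):
    """Hypothetical stand-in for PyTreeSave."""

    item: Any


# Works at runtime, but a checker that behaves as described in this issue
# treats PairArgs as Base, whose constructor takes no positional arguments.
args = PairArgs({"a": 1})
print(type(args).__name__, args.item)
```

Running it prints `PairArgs {'a': 1}`, which suggests the problem lies purely in the annotations rather than in runtime behavior.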

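As a result, the class returned by the decorator is seen as CheckpointArgs, when in reality it is the subclass of CheckpointArgs that was passed in. Since CheckpointArgs apparently declares no fields of its own, its constructor accepts no positional arguments, which would explain the error. My impression is that this can be fixed either by omitting the type annotation entirely, or by annotating the decorator generically with a bound TypeVar, as follows: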

```python
from typing import Type, TypeVar

CA = TypeVar('CA', bound=CheckpointArgs)

def register_with_handler(
    handler_cls: Type[CheckpointHandler],
    for_save: bool = False,
    for_restore: bool = False,
):
  ...

  def decorator(cls: Type[CA]) -> Type[CA]:
    ...
    return cls

  return decorator
```

manulari, Mar 18 '25 12:03