Any ways to support torch functions

Open ioangatop opened this issue 1 year ago • 3 comments

🚀 Feature request

Hi! I would like to use torch functions straight from the yaml file, for example:

class_path: SomeClass
init_args:
  process:
    class_path: torch.argmax
    init_args:
      dim: 1

However, I have not managed to get this working, because the type check fails, for example with the following:

from typing import Callable
import torch

class SomeClass:
  def __init__(self, process: Callable[..., torch.Tensor]) -> None:
    ...

The workaround is to wrap them in callable classes, but it would be great to support them directly, like the dot imports for the torch optimizers, so that I don't have to duplicate them or create a wrapper class.

ioangatop, Jul 08 '24 15:07

I think this would not be possible. If you run inspect.signature(torch.argmax), it just fails, and this could not even be fixed on the PyTorch side. The problem is that torch.argmax has multiple signatures (see help(torch.argmax)), which is something that native Python functions don't support.
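A minimal sketch of how one might check up front whether a callable can be introspected at all. The helper name `signature_or_none` is hypothetical and not part of jsonargparse; it relies only on the documented behaviour of `inspect.signature`, which raises `ValueError` when no signature can be provided and `TypeError` when the object type is not supported.

```python
import inspect
from typing import Any, Optional


def signature_or_none(obj: Any) -> Optional[inspect.Signature]:
    """Return the signature of obj, or None if it cannot be introspected."""
    try:
        return inspect.signature(obj)
    except (ValueError, TypeError):
        # ValueError: no signature could be provided for this callable.
        # TypeError: the object is not a supported callable.
        return None
```

Going by the description above, `signature_or_none(torch.argmax)` would be expected to return `None`, while an ordinary Python function returns a full `inspect.Signature`.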

I will keep this in mind in case a better idea comes up. For now, I think a wrapper class is the best option, possibly a single class that takes the torch function name, so that there is no need for one class per function.
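A single wrapper class of that kind could be sketched roughly as follows. The class name `FunctionWrapper` and its parameters `name`, `module` and `kwargs` are assumptions made for illustration; none of them come from jsonargparse or torch.

```python
import importlib
from typing import Any, Dict, Optional


class FunctionWrapper:
    """Hypothetical callable wrapper: looks up a function by name and binds kwargs."""

    def __init__(
        self,
        name: str,
        module: str = "torch",
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Import the module by name so this class has no hard dependency on torch.
        self.func = getattr(importlib.import_module(module), name)
        self.kwargs = dict(kwargs or {})

    def __call__(self, *args: Any, **extra: Any) -> Any:
        # Keyword arguments given at call time override the configured ones.
        return self.func(*args, **{**self.kwargs, **extra})
```

Because the class has an ordinary `__init__` signature, a config could then point `class_path` at the wrapper and pass `name: argmax` and `kwargs: {dim: 1}` under `init_args`, instead of pointing `class_path` at `torch.argmax` directly.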

mauvilsa, Jul 09 '24 06:07

Thanks for the fast response!

> Possibly a single class which gets the torch function name so that there is no need for one class for each function.

This is also what I did, but I also came up with a different, somewhat hacky idea: pass it as a dict and convert it later into a partial function. For example:

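import functools
from typing import Any, Callable, Dict

import torch
from jsonargparse._util import import_object  # private module, may change between releases

class SomeClass:
  def __init__(self, process: Callable[..., torch.Tensor] | Dict[str, Any]) -> None:
    # Accept either a ready callable or a {"class_path": ..., "init_args": ...} dict.
    self.process = self.parse(process) if isinstance(process, dict) else process

  def parse(self, item):
    # Import the function by dotted path and bind init_args as keyword arguments.
    return functools.partial(import_object(item["class_path"]), **item.get("init_args", {}))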

If you have any other ideas on how to improve it, please do let me know 🙏
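One possible variation on the dict approach is to avoid the private `jsonargparse._util` module and resolve the dotted path with the standard library instead. The helper names `import_by_path` and `to_partial` below are hypothetical, and the sketch assumes the path has the form `module.attribute`.

```python
import functools
import importlib
from typing import Any, Callable, Dict


def import_by_path(path: str) -> Any:
    """Import an object from a dotted path such as 'torch.argmax'."""
    # Split on the last dot: everything before it is treated as the module.
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def to_partial(item: Dict[str, Any]) -> Callable[..., Any]:
    """Turn a {'class_path': ..., 'init_args': ...} dict into a partial function."""
    func = import_by_path(item["class_path"])
    return functools.partial(func, **item.get("init_args", {}))
```

With torch installed, `to_partial({"class_path": "torch.argmax", "init_args": {"dim": 1}})` would give a callable that applies `torch.argmax` with `dim=1` to whatever tensor it receives.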

ioangatop, Jul 09 '24 11:07

I was wrong about this. I noticed that both mypy and pylance validate the types for torch.argmax just fine. So, jsonargparse should also be able to. One issue is that typeshed-client (the library used to get stub types) does not resolve these functions, so I created https://github.com/JelleZijlstra/typeshed_client/issues/117. The inspect.signature failing is also a problem which would need a separate fix.

mauvilsa, Mar 17 '25 06:03

@ioangatop, there is finally a potential fix for this in pull request #770. To resolve the parameters of functions like torch.argmax, you first need to call:

from jsonargparse import set_parsing_settings
set_parsing_settings(stubs_resolver_allow_py_files=True)

Please test it out.

mauvilsa, Sep 08 '25 05:09