
[FEEDBACK] Model Registration beta API

datumbox opened this issue

🚀 Feedback Request

This issue is dedicated to collecting community feedback on the Model Registration API. Please review the dedicated RFC and blog post, where we describe the API in detail and give an overview of its features.

We would love to get your thoughts, comments and input in order to finalize the API and include it in the upcoming release of TorchVision.

datumbox, Aug 03 '22 14:08

It would be great if list_models could list only the models matching a regex, or at least support wildcard searches like list_models("resnet*"). timm has this functionality and it really boosts productivity.
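As a rough illustration of the requested wildcard behaviour, the sketch below filters a list of model names with Python's standard fnmatch module. filter_model_names is a hypothetical helper, and the names are hard-coded rather than read from torchvision; the point is only the matching semantics such a filter could use.

```python
import fnmatch
from typing import Iterable, List


def filter_model_names(names: Iterable[str], pattern: str = "*") -> List[str]:
    """Keep the names that match a Unix shell-style wildcard such as "resnet*"."""
    # fnmatchcase (rather than fnmatch) keeps matching case-sensitive on every OS
    return [name for name in names if fnmatch.fnmatchcase(name, pattern)]


# Hard-coded sample names for illustration only (hypothetical, not a registry query)
names = ["resnet18", "resnet50", "resnext50_32x4d", "vit_b_16", "wide_resnet50_2"]

print(filter_model_names(names, "resnet*"))   # ['resnet18', 'resnet50']
print(filter_model_names(names, "*resnet*"))  # ['resnet18', 'resnet50', 'wide_resnet50_2']
```

Shell-style wildcards are arguably friendlier than raw regexes here, since "resnet*" then means exactly what it looks like.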

dataplayer12, Aug 19 '22 00:08

I agree with @dataplayer12. Whilst going through some of the prototype tests, I noticed that the fetch-all behaviour might cause issues with the constructors of models we may add in the future that do not follow the same initialisation scheme: https://github.com/pytorch/vision/blob/9c3e2bf46bc49997679785d76b7d0a9fea0223c7/test/test_prototype_models.py#L8-L19

Furthermore, this can get quite tricky when dealing with models that do not have the same output shapes or number of outputs, even though they "solve" the same task.

https://github.com/pytorch/vision/blob/9c3e2bf46bc49997679785d76b7d0a9fea0223c7/test/test_prototype_models.py#L28-L35

The backwards-compatible fix I have in mind is simple and non-intrusive.

We could change find_model to something like:

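```python
import re
from typing import Callable, Optional, TypeVar

import torch.nn as nn
from torchvision.models._api import BUILTIN_MODELS  # private registry of built-in builders

M = TypeVar("M", bound=nn.Module)


def find_model(name: str, pattern: str = ".*") -> Optional[Callable[..., M]]:
    name = name.lower()
    try:
        fn = BUILTIN_MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}")
    # Return None when the name does not match the regex pattern
    if not re.match(pattern, name):
        return None
    return fn
```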

Then we could change list_model_fns to something like:

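```python
from typing import List

from torchvision.models import list_models


def list_model_fns(module, pattern: str = ".*") -> List[Callable[..., M]]:
    # A bare "*" is not a valid regex, so the match-everything default is ".*"
    model_fns = [find_model(name, pattern) for name in list_models(module)]
    return [fn for fn in model_fns if fn is not None]
```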

Besides giving users the option to select only a specific family of models, I believe this could ease the developer experience when writing tests or other utilities, while keeping the same API.

The alternative, in terms of developer experience, would be to pass each model class individually through the function arguments or a decorator whenever we cannot assume that all models in a module behave in exactly the same way.
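Because the proposal above hands pattern to re.match, the pattern is read as a regular expression rather than a shell-style wildcard, which changes what something like "resnet*" selects. The standalone check below uses only the standard library and hard-coded sample names (not a real registry query) to illustrate the difference.

```python
import fnmatch
import re

# Hard-coded sample names for illustration
names = ["resnet50", "resnext50_32x4d", "wide_resnet50_2"]

# As a regex, "resnet*" means "resne" followed by zero or more "t",
# and re.match only anchors at the start of the string
regex_hits = [n for n in names if re.match("resnet*", n)]

# As a shell-style wildcard, "resnet*" means "resnet" followed by anything
glob_hits = [n for n in names if fnmatch.fnmatchcase(n, "resnet*")]

print(regex_hits)  # ['resnet50', 'resnext50_32x4d']
print(glob_hits)   # ['resnet50']

# A bare "*" is not a valid regex, so it cannot serve as a match-everything default
try:
    re.compile("*")
except re.error as exc:
    print("invalid regex:", exc)
```

If regex semantics are kept, anchoring with re.fullmatch avoids the prefix-only surprise, at the cost of spelling out ".*" explicitly (e.g. "resnet.*").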

TeodorPoncu, Aug 19 '22 10:08

We're trying to adopt the new API in TorchGeo, but it isn't clear how the registration API works for weights that are not built into torchvision. We define our own WeightsEnums, but torchvision.models.list_models doesn't know anything about them, and list_models(module=torchgeo.models) doesn't work. According to the blog post:

The model registration methods are kept private on purpose as we currently focus only on supporting the built-in models of TorchVision.

So it's possible this is by design. I guess I'll just copy-n-paste all the code for now and wait for the methods to become public...

adamjstewart, Jan 17 '23 23:01

Thanks for the feedback, @adamjstewart.

The registration methods are private right now because they weren't intended to work for external packages. What kind of workflow would you like to enable? It seems like it would work like this for torchgeo users:

```python
import torchgeo.models
from torchvision.models import list_models

list_models(module=torchgeo.models)
```

which IMHO seems awkward; torchgeo users probably just want to use something like torchgeo.models.list_models, and the fact that it (may) rely on torchvision should just be an implementation detail, not something exposed to users.

IIRC from the design stage, we introduced the module= parameter because some models have the same name in the torchvision.models and torchvision.models.quantized namespaces; we had to introduce module to disambiguate, but it's probably not something we would have done otherwise. It may give the impression that we intend to support arbitrary packages, but that wasn't the original intention.

We're still open to making those public if we can find a nice/easy/useful way to do so, but for now I think a good old copy-n-paste is your best strat :)
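To make the copy-n-paste route a bit more concrete, here is a minimal, self-contained sketch of a package-local registry a downstream library could keep. register_model, list_models and get_model are defined locally for illustration; they are not torchvision's private implementation, and the toy builders return plain dicts (with a placeholder weights value) so the example runs without torch.

```python
import fnmatch
from typing import Any, Callable, Dict, List, Optional

# Package-local registry (hypothetical): maps lowercase names to builder functions
_MODELS: Dict[str, Callable[..., Any]] = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that records a builder under `name` (defaults to the function's name)."""

    def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
        key = (name if name is not None else fn.__name__).lower()
        if key in _MODELS:
            raise ValueError(f"A model is already registered under {key!r}")
        _MODELS[key] = fn
        return fn

    return wrapper


def list_models(pattern: str = "*") -> List[str]:
    """Sorted names of registered builders matching a shell-style wildcard."""
    return sorted(k for k in _MODELS if fnmatch.fnmatchcase(k, pattern))


def get_model(name: str, **kwargs: Any) -> Any:
    """Look up a builder by (case-insensitive) name and call it with kwargs."""
    try:
        fn = _MODELS[name.lower()]
    except KeyError:
        raise ValueError(f"Unknown model {name}") from None
    return fn(**kwargs)


# Toy builders standing in for real model constructors
@register_model()
def resnet18(weights: Optional[str] = None) -> Dict[str, Any]:
    return {"arch": "resnet18", "weights": weights}


@register_model(name="vit_small_patch16_224")
def _vit_small(weights: Optional[str] = None) -> Dict[str, Any]:
    return {"arch": "vit_small_patch16_224", "weights": weights}


print(list_models())           # ['resnet18', 'vit_small_patch16_224']
print(list_models("resnet*"))  # ['resnet18']
print(get_model("ResNet18", weights="placeholder"))
```

Keeping the registry inside the downstream package also fits the point above that torchgeo users would rather call something like torchgeo.models.list_models directly, with any reliance on torchvision staying an implementation detail.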

NicolasHug, Jan 25 '23 10:01

It seems like it would work like this for torchgeo users:

```python
import torchgeo.models
from torchvision.models import list_models

list_models(module=torchgeo.models)
```

I would love it if that syntax worked, but it doesn't:

```pycon
>>> import torchgeo.models
>>> from torchvision.models import list_models
>>> list_models(module=torchgeo.models)
[]
```
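A plausible explanation for the empty list, consistent with the blog post's note that only TorchVision's built-in models are supported: module= narrows down torchvision's own registry rather than scanning the given package, so builders that were never registered there cannot show up. The toy sketch below mimics that behaviour with a hypothetical register helper, fake modules created via types.ModuleType, and a filter on each builder's __module__; it is an illustration under those assumptions, not torchvision's actual code.

```python
import types
from typing import Callable, Dict, List, Optional

# Toy stand-in for a library's built-in registry (hypothetical)
REGISTRY: Dict[str, Callable] = {}


def register(name: str, module_name: str) -> Callable[[Callable], Callable]:
    """Register a builder and pretend it was defined in `module_name`."""

    def wrapper(fn: Callable) -> Callable:
        fn.__module__ = module_name  # simulate where the builder lives
        REGISTRY[name] = fn
        return fn

    return wrapper


def list_models(module: Optional[types.ModuleType] = None) -> List[str]:
    """Registered names, optionally restricted to builders defined directly under `module`."""
    # "pkg.models.resnet" -> "pkg.models", so "pkg.models.quantized.*" stays separate
    return sorted(
        name
        for name, fn in REGISTRY.items()
        if module is None or fn.__module__.rsplit(".", 1)[0] == module.__name__
    )


@register("resnet50", "pkg.models.resnet")
def resnet50():
    return "float model"


@register("quantized_resnet50", "pkg.models.quantized.resnet")
def quantized_resnet50():
    return "quantized model"


models = types.ModuleType("pkg.models")
quantized = types.ModuleType("pkg.models.quantized")
external = types.ModuleType("otherpkg.models")  # its builders were never registered

print(list_models())           # ['quantized_resnet50', 'resnet50']
print(list_models(models))     # ['resnet50']
print(list_models(quantized))  # ['quantized_resnet50']
print(list_models(external))   # []
```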

adamjstewart, Jan 25 '23 16:01