[FEEDBACK] Model Registration beta API
🚀 Feedback Request
This issue is dedicated to collecting community feedback on the Model Registration API. Please review the dedicated RFC and blog post, where we describe the API in detail and give an overview of its features.
We would love to get your thoughts, comments and input so that we can finalize the API and include it in the next release of TorchVision.
It would be great if list_models could list only the models matching a regex, or at least support wildcard searches like list_models("resnet*"). timm has this functionality and it really boosts productivity.
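A minimal sketch of the requested filtering, assuming the model names are already available as a list of strings (for example, the output of list_models()). The helper filter_model_names is hypothetical and not part of TorchVision; it uses the standard-library fnmatch for shell-style wildcards and re.fullmatch for regular expressions.

```python
import fnmatch
import re
from typing import Iterable, List


def filter_model_names(names: Iterable[str], pattern: str, regex: bool = False) -> List[str]:
    """Hypothetical helper: keep only the model names matching `pattern`.

    With regex=False the pattern is a shell-style wildcard such as "resnet*";
    with regex=True it is a regular expression that must match the whole name.
    """
    if regex:
        compiled = re.compile(pattern)
        return sorted(n for n in names if compiled.fullmatch(n))
    return sorted(n for n in names if fnmatch.fnmatchcase(n, pattern))


# Stand-in for the output of torchvision.models.list_models()
names = ["resnet18", "resnet50", "resnext50_32x4d", "vgg16", "wide_resnet50_2"]
print(filter_model_names(names, "resnet*"))  # ['resnet18', 'resnet50']
print(filter_model_names(names, r"res(net|next)\d+.*", regex=True))  # ['resnet18', 'resnet50', 'resnext50_32x4d']
```

Using fnmatchcase keeps matching case-sensitive on every platform, whereas fnmatch.fnmatch normalizes case on case-insensitive systems such as Windows.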
I agree with @dataplayer12. While going through some tests in prototype, I noticed that the fetch-all behaviour might raise issues with the constructors of models we might add in the future that do not follow the same initialisation scheme: https://github.com/pytorch/vision/blob/9c3e2bf46bc49997679785d76b7d0a9fea0223c7/test/test_prototype_models.py#L8-L19
Furthermore, this can get quite tricky with models that do not have the same output shapes or number of outputs, even though they "solve" the same task.
https://github.com/pytorch/vision/blob/9c3e2bf46bc49997679785d76b7d0a9fea0223c7/test/test_prototype_models.py#L28-L35
The backward-compatible (BC) fix I see is simple and non-intrusive.
We could change find_model to something like:
import fnmatch
from typing import Callable, List, Optional

# M and BUILTIN_MODELS are TorchVision's private registry internals
def find_model(name: str, pattern: str = "*") -> Optional[Callable[..., M]]:
    name = name.lower()
    try:
        fn = BUILTIN_MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}")
    # Treat the pattern as a shell-style wildcard so the default "*" matches
    # every model (as a regex, re.match("*", ...) would raise re.error)
    if not fnmatch.fnmatch(name, pattern):
        return None
    return fn
Then we could change list_model_fns to something like:
def list_model_fns(module, pattern: str = "*") -> List[Callable[..., M]]:
    model_fns = [find_model(name, pattern) for name in list_models(module)]
    # Drop the models that find_model filtered out
    return [fn for fn in model_fns if fn is not None]
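To try the proposed behaviour without touching TorchVision internals, here is a self-contained sketch with a toy registry. BUILTIN_MODELS, the toy builders, and the names argument (standing in for list_models(module)) are illustrative stand-ins; the pattern handling is a shell-style wildcard via fnmatch.

```python
import fnmatch
from typing import Callable, Dict, List, Optional

# Toy registry standing in for TorchVision's private BUILTIN_MODELS mapping;
# real builders would return nn.Module instances, these return strings.
BUILTIN_MODELS: Dict[str, Callable[..., str]] = {
    "resnet18": lambda **kwargs: "ResNet-18",
    "resnet50": lambda **kwargs: "ResNet-50",
    "vgg16": lambda **kwargs: "VGG-16",
}


def find_model(name: str, pattern: str = "*") -> Optional[Callable[..., str]]:
    name = name.lower()
    try:
        fn = BUILTIN_MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}") from None
    # Known models that don't match the pattern yield None instead of raising
    return fn if fnmatch.fnmatchcase(name, pattern) else None


def list_model_fns(names: List[str], pattern: str = "*") -> List[Callable[..., str]]:
    # `names` plays the role of list_models(module) in the proposal
    fns = [find_model(name, pattern) for name in names]
    return [fn for fn in fns if fn is not None]


resnets = list_model_fns(sorted(BUILTIN_MODELS), "resnet*")
print([build() for build in resnets])  # ['ResNet-18', 'ResNet-50']
```

Unknown names still raise ValueError, so a typo is not silently filtered out; only known-but-unmatched models are skipped.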
Besides letting users select a specific family of models, I believe this would improve the developer experience when writing tests or other utilities, while keeping the same API.
The alternative, in terms of developer experience, would be to pass each model class individually in the function arguments or decorator, since we cannot assume that all models from a module behave in exactly the same way.
We're trying to adopt the new API in TorchGeo, but it isn't clear how the registration API works for weights that are not built into torchvision. We have our own WeightsEnums, but torchvision.models.list_models doesn't know anything about them, and list_models(module=torchgeo.models) doesn't work. According to the blog:
The model registration methods are kept private on purpose as we currently focus only on supporting the built-in models of TorchVision.
So it's possible this is by design. Guess I'll just wait for them to become public and copy-n-paste all the code for now...
Thanks for the feedback @adamjstewart.
The registrators are private right now because they weren't intended to work for external packages. What kind of workflow would you like to enable? It seems like it would work like this for torchgeo users:
from torchvision.models import list_models
list_models(module=torchgeo.models)
which IMHO seems awkward; torchgeo users probably just want to use something like torchgeo.models.list_models, and the fact that it (may) rely on torchvision should just be an implementation detail, not something exposed to users.
IIRC from the design stage, we introduced the module= parameter because some models have the same name in the torchvision.models and torchvision.models.quantized namespaces. We had to introduce module to disambiguate, but it's probably not something we would have done otherwise. It may give the impression that we intend to support arbitrary packages, but that wasn't the original intention.
We're still open to making those public if we can find a nice/easy/useful way to do so, but for now I think a good old copy-n-paste is your best strat :)
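For a downstream package such as TorchGeo, the copy-n-paste route amounts to keeping a small registry inside the package itself. The sketch below is a minimal, hypothetical version: register_model, list_models and get_model_builder are illustrative names in the package's own namespace, not TorchVision's private implementation, and the include glob filter is an assumption in the spirit of the wildcard request earlier in the thread.

```python
import fnmatch
from typing import Callable, Dict, List, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., object])

# Package-local registry (hypothetical; e.g. a private module inside the package)
_MODELS: Dict[str, Callable[..., object]] = {}


def register_model(name: Optional[str] = None) -> Callable[[F], F]:
    """Decorator that records a model builder under `name` (default: function name)."""
    def wrapper(fn: F) -> F:
        key = name or fn.__name__
        if key in _MODELS:
            raise ValueError(f"A model is already registered with name '{key}'")
        _MODELS[key] = fn
        return fn
    return wrapper


def list_models(include: Optional[str] = None) -> List[str]:
    """Return the sorted registered names, optionally filtered by a glob pattern."""
    names = sorted(_MODELS)
    if include is not None:
        names = [n for n in names if fnmatch.fnmatchcase(n, include)]
    return names


def get_model_builder(name: str) -> Callable[..., object]:
    try:
        return _MODELS[name.lower()]
    except KeyError:
        raise ValueError(f"Unknown model {name}") from None


# Toy builders; real ones would construct and return nn.Module instances
@register_model()
def resnet18(**kwargs):
    return "resnet18 builder called with " + repr(kwargs)


@register_model("vit_small_patch16_224")
def _vit_small(**kwargs):
    return "vit builder called with " + repr(kwargs)


print(list_models())           # ['resnet18', 'vit_small_patch16_224']
print(list_models("resnet*"))  # ['resnet18']
```

If a module like this were re-exported from the package's public models namespace, users could call that package's own list_models without knowing whether anything underneath depends on torchvision.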
It seems like it would work like this for torchgeo users:
from torchvision.models import list_models
list_models(module=torchgeo.models)
I would love it if that syntax worked, but it doesn't:
>>> import torchgeo.models
>>> from torchvision.models import list_models
>>> list_models(module=torchgeo.models)
[]