Add simplified model manager install API to InvocationContext
Summary
This adds two model manager-related methods to the InvocationContext uniform API. They are accessible via context.models.*:
`load_and_cache_model(source: Path | str | AnyHttpUrl, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None) -> LoadedModel`
Load the model located at the indicated path, URL or repo_id.
This will download the model from the indicated location, cache it locally, and load it into the model manager RAM cache if needed. If the optional loader argument is provided, it will be invoked to load the model into memory. Otherwise the method calls safetensors.torch.load_file() or torch.load() (with a pickle scan), as appropriate to the file suffix. Diffusers models are supported via HuggingFace repo_ids.
Be aware that the LoadedModel object will have a config attribute of None.
Here is an example of usage:
```python
def invoke(self, context: InvocationContext) -> ImageOutput:
    model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
    loadnet = context.models.load_and_cache_model(model_url)
    with loadnet as loadnet_model:
        upscaler = RealESRGAN(loadnet=loadnet_model, ...)
```
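For cases where the default safetensors/torch loading isn't appropriate, a custom loader can be passed. The sketch below is illustrative only: the `unwrap_ema` helper and the "params_ema" checkpoint layout are assumptions, not part of this PR, and `model_url` is the same URL as in the example above.

```python
from pathlib import Path
from typing import Dict

import torch
from torch import Tensor


def unwrap_ema(path: Path) -> Dict[str, Tensor]:
    # Hypothetical loader: some checkpoints nest their weights under a
    # "params_ema" key; fall back to the raw dict otherwise.
    checkpoint = torch.load(path, map_location="cpu")
    return checkpoint.get("params_ema", checkpoint)


def invoke(self, context: InvocationContext) -> ImageOutput:
    # model_url as in the example above
    loaded = context.models.load_and_cache_model(model_url, loader=unwrap_ema)
    with loaded as state_dict:
        upscaler = RealESRGAN(loadnet=state_dict, ...)
```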
`download_and_cache_model(source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: Optional[int] = 0) -> Path`
Download the model file located at source to the models cache and return its Path.
This will check models/.download_cache for the desired model file and download it from the indicated source if not already present. The local Path to the downloaded file is then returned.
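A minimal usage sketch (same ESRGAN URL as in the earlier example; since only a Path is returned, loading the file is up to the node):

```python
import torch


def invoke(self, context: InvocationContext) -> ImageOutput:
    model_path = context.models.download_and_cache_model(
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    )
    # model_path points at a file inside models/.download_cache; load it
    # however the node needs to.
    loadnet = torch.load(model_path, map_location="cpu")
```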
Other Changes
This PR performs a migration that renames models/.cache to models/.convert_cache, and moves previously-downloaded ESRGAN, openpose, DepthAnything and LaMa inpaint models from the models/core directory into models/.download_cache.
There are a number of legacy model files in models/core, such as GFPGAN, which are no longer used. This PR deletes them and tidies up the models/core directory.
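Roughly, the on-disk effect of the migration looks like the sketch below. This is illustrative only, not the actual migration script, and the glob patterns are hypothetical.

```python
import shutil
from pathlib import Path

models_root = Path("models")

# Rename the conversion cache.
old_cache = models_root / ".cache"
if old_cache.exists():
    old_cache.rename(models_root / ".convert_cache")

# Move previously-downloaded utility models out of models/core into the
# download cache; leftover legacy files (e.g. GFPGAN) are deleted by the
# real migration.
download_cache = models_root / ".download_cache"
download_cache.mkdir(exist_ok=True)
for pattern in ("*RealESRGAN*", "*lama*", "*depth_anything*"):  # hypothetical patterns
    for src in (models_root / "core").rglob(pattern):
        if src.is_file():
            shutil.move(str(src), str(download_cache / src.name))
```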
Related Issues / Discussions
I have systematically replaced all the calls to download_with_progress_bar(). This function is no longer used elsewhere and has been removed.
QA Instructions
I have added unit tests for the new calls. You may test that the load_and_cache_model() call is working by running the upscaler within the web app. On the first try, you will see the model file being downloaded into the models/.download_cache directory. On subsequent tries, the model will either load from RAM (if it hasn't been displaced) or will be loaded from the filesystem.
Merge Plan
Squash merge when approved.
Checklist
- [X] The PR has a short but descriptive title, suitable for a changelog
- [X] Tests added / updated (if applicable)
- [X] Documentation added / updated (if applicable)
I have added a migration script that tidies up the models/core directory and removes unused models such as GFPGAN. In addition, I have renamed models/.cache to models/.convert_cache to distinguish it from models/.download_cache, the directory into which just-in-time models are downloaded. While the size of models/.convert_cache is capped so that less-used models are cleared periodically, files in models/.download_cache are not removed unless the user does so manually.
@psychedelicious @RyanJDick I think I've responded to all comments and suggestions. Thanks for the reviews!
I'll do a full review later, we still have too much coupling between the context/model manager and some of the utility classes.
I've just pushed changes that lift model loading out of the LaMa and DepthAnything classes, but not DWOpenPose - I'm not familiar with how ONNX loads models and ran out of time for it.
There just doesn't seem to be a good way for utility classes to access the model manager service in the running app without going through the context. One out would be to have the model manager set a singleton in the way that the configuration service does, thereby allowing the utility classes to make a call to get_model_manager(). Of course, there would have to be a bit of setup in the event that the model manager hadn't already been initialized.
> There just doesn't seem to be a good way for utility classes to access the model manager service in the running app without going through the context. One out would be to have the model manager set a singleton in the way that the configuration service does, thereby allowing the utility classes to make a call to get_model_manager(). Of course, there would have to be a bit of setup in the event that the model manager hadn't already been initialized.
Do my changes in these two commits not do this?
- feat(backend): lift managed model loading out of lama class
- feat(backend): lift managed model loading out of depthanything class
> Do my changes in these two commits not do this? feat(backend): lift managed model loading out of lama class feat(backend): lift managed model loading out of depthanything class
Sure. Those look good.
@psychedelicious I think all the issues are now addressed. Ok to approve and merge?
Sorry, no we need to do the same pattern for that last processor, DWOpenPose - it uses ONNX models and I'm not familiar with how they work. I can take care of that but won't be til next week.
> Sorry, no we need to do the same pattern for that last processor, DWOpenPose - it uses ONNX models and I'm not familiar with how they work. I can take care of that but won't be til next week.
I'll take a look at it Friday.
I've refactored DWOpenPose using the same pattern as in the other backend image processors. I also added some of the missing type hints so there are fewer red squigglies.

I noticed that there is a problem with the pip dependencies. If the onnxruntime package is installed, then even if onnxruntime-gpu is installed as well, the ONNX runtime won't use the GPU (see https://github.com/microsoft/onnxruntime/issues/7748). You have to remove onnxruntime and then install onnxruntime-gpu. I don't think pyproject.toml provides a way for an optional dependency to remove a default dependency. Is there a workaround?
> I noticed that there is a problem with the pip dependencies. If the onnxruntime package is installed, then even if onnxruntime-gpu is installed as well, the ONNX runtime won't use the GPU (see https://github.com/microsoft/onnxruntime/issues/7748). You have to remove onnxruntime and then install onnxruntime-gpu. I don't think pyproject.toml provides a way for an optional dependency to remove a default dependency. Is there a workaround?
I think we'd need to just update the installer script with special handling to uninstall those packages if they are already installed. It's probably time to revise our optional dependency lists. I think "cuda" and "cpu" make sense to be the only two user-facing options. "xformers" is extraneous now (torch's native SDP implementation is just as fast), so it could be removed.
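If we go that route, the installer-side handling might look something like this sketch (not the actual installer code; assumes the user selected the CUDA option):

```python
import subprocess
import sys


def ensure_onnxruntime_gpu() -> None:
    pip = [sys.executable, "-m", "pip"]
    # pip won't remove the CPU-only wheel on its own, and its presence
    # prevents onnxruntime-gpu from using the GPU, so uninstall it first.
    subprocess.run(pip + ["uninstall", "-y", "onnxruntime"], check=False)
    subprocess.run(pip + ["install", "onnxruntime-gpu"], check=True)
```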
Thanks for cleaning up the pose detector. It would be nice to use the model context so we get memory mgmt, but that is a future task.
I had some feedback from earlier about the public API that I think was lost:
- Token: When would a node reasonably provide an API token? We support regex-matched tokens in the config file. I don't think this should be in the invocation API.
- Timeout: Similarly, when could a node possibly be able to make a good determination of the timeout for a download? It doesn't know the user's internet connection speed. It's a user setting - could be globally set in the config file and apply to all downloads.
If both of those args are removed, then load_ckpt_from_path and load_ckpt_from_url look very similar. I think this is maybe what @RyanJDick was suggesting with a single load_custom_model method.
Also, will these methods work for diffusers models? If so, "ckpt" probably doesn't need to be in the name.
> Thanks for cleaning up the pose detector. It would be nice to use the model context so we get memory mgmt, but that is a future task.
The onnxruntime model loading architecture seems to be very different from what the model manager expects. In particular, the onnxruntime.InferenceSession() constructor doesn't seem to provide any way to accept a model that has already been read into RAM or VRAM. The closest I can figure is that you can pass the constructor an IOBytes object containing a serialized version of the model in memory. This will require some architectural changes in the model manager that should be its own PR.
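For reference, InferenceSession() does accept the serialized model as raw bytes rather than a file path, so something along these lines could work once the cache can hand back the bytes; the path below is hypothetical, and how this fits the model manager's cache contract is the open question.

```python
from pathlib import Path

import onnxruntime as ort

# Hypothetical cached location of the DWOpenPose detector.
model_bytes = Path("models/.download_cache/dw-ll_ucoco_384.onnx").read_bytes()
session = ort.InferenceSession(
    model_bytes,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
```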
> Token: When would a node reasonably provide an API token? We support regex-matched tokens in the config file. I don't think this should be in the invocation API.
Right now the regex token handling is done in a part of the install manager that is not called by the simplified API. I'll move this code into the core download() routine so that tokens are picked up whenever a URL is requested.
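Something like this sketch, where the download routine looks up a token by matching the URL against the configured regexes (the field names here are assumptions, not the actual config schema):

```python
import re
from typing import Optional


def token_for_url(url: str, remote_api_tokens: list[dict[str, str]]) -> Optional[str]:
    # Return the first configured token whose regex matches the URL.
    for entry in remote_api_tokens:
        if re.search(entry["url_regex"], url):
            return entry["token"]
    return None


# e.g. inside download():
#   token = token_for_url(source_url, config.remote_api_tokens or [])
#   headers = {"Authorization": f"Bearer {token}"} if token else {}
```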
> Timeout: Similarly, when could a node possibly be able to make a good determination of the timeout for a download? It doesn't know the user's internet connection speed. It's a user setting - could be globally set in the config file and apply to all downloads.
I think you're saying this should be a global config option and I agree with that. Can we get the config migration code in so that I have a clean way of updating the config?
> Also, will these methods work for diffusers models? If so, "ckpt" probably doesn't need to be in the name.
Not currently. It only works with checkpoints. I'd planned to add diffusers support later, but I guess I should do that now. Converting to draft.
Probably doesn't make sense to spend time on the onnx loading. This is the only model that uses it.
> Right now the regex token handling is done in a part of the install manager that is not called by the simplified API. I'll move this code into the core download() routine so that tokens are picked up whenever a URL is requested.
Sounds good.
> I think you're saying this should be a global config option and I agree with that. Can we get the config migration code in so that I have a clean way of updating the config?
I don't think any migration is necessary - just add a sensible default value, maybe it should be 0 (no timeout). I'll check back in on the config migration PR this week.
> Not currently. It only works with checkpoints. I'd planned to add diffusers support later, but I guess I should do that now. Converting to draft.
Ok, thanks.
> The onnxruntime model loading architecture seems to be very different from what the model manager expects. In particular, the onnxruntime.InferenceSession() constructor doesn't seem to provide any way to accept a model that has already been read into RAM or VRAM. The closest I can figure is that you can pass the constructor an IOBytes object containing a serialized version of the model in memory. This will require some architectural changes in the model manager that should be its own PR.
I've played with this a bit. It is easy to load the openpose ONNX sessions into the RAM cache and they will run happily under the existing MM cache system. However, ONNX sessions do their own internal VRAM/CUDA management, and so I found that for as long as the session object is in RAM, it holds on to a substantial chunk of VRAM (1.7GB). The openpose session is only used during conversion of an image into a pose map, and I think it's better to have slow disk-based loading of the openpose session than to silently consume a chunk of VRAM that interferes with later generation.
@psychedelicious This is ready for your review now. There are now just two calls, load_and_cache_model() and download_and_cache_model(), which return a LoadedModel and a locally cached Path respectively. In addition, the model source can now be a URL, a local Path, or a repo_id. Support for the latter involved refactoring the way that multifile downloads work.
@psychedelicious I just updated the whole thing to work properly with the new (and very nice) Pydantic-based events. I've also added a new migration. Please review when you can. I'm having to resolve merge conflicts fairly regularly!
I removed a number of unnecessary changes in invocation_context.py, mostly extraneous type annotations. If mypy is complaining about these, then that's a mypy problem, because all the methods are annotated correctly.
I also moved load_model_from_url from the main model manager class into the invocation context.
@psychedelicious I've addressed the remaining issues you raised. Thanks for a thorough review.
> I removed a number of unnecessary changes in invocation_context.py, mostly extraneous type annotations. If mypy is complaining about these, then that's a mypy problem, because all the methods are annotated correctly. I also moved load_model_from_url from the main model manager class into the invocation context.
Yes, mypy is having trouble tracking the return type of several methods. I haven't figured out what causes the problem and don't want to add a # type: ignore. But maybe I should, since I'm not ready to switch to pyright.
> Yes, mypy is having trouble tracking the return type of several methods. I haven't figured out what causes the problem and don't want to add a # type: ignore. But maybe I should, since I'm not ready to switch to pyright.
We shouldn't add # type: ignore; that will stop all type checkers from doing anything - including pyright. The places where you made code quality concessions to satisfy mypy involve very straightforward types - either your mypy config is borked or mypy itself is borked. FWIW, I've found pyright to be much faster, more thorough and more correct than mypy.
@RyanJDick Would you mind doing one last review of this PR?
> Yes, mypy is having trouble tracking the return type of several methods. I haven't figured out what causes the problem and don't want to add a # type: ignore. But maybe I should, since I'm not ready to switch to pyright.
> We shouldn't add # type: ignore; that will stop all type checkers from doing anything - including pyright. The places where you made code quality concessions to satisfy mypy involve very straightforward types - either your mypy config is borked or mypy itself is borked. FWIW, I've found pyright to be much faster, more thorough and more correct than mypy.
You've convinced me. I've switched to pyright!
> @RyanJDick Would you mind doing one last review of this PR?
Looks like 43/44 files have changed since I last looked 😅 . I'll plan to spend a chunk of time on this tomorrow.
@RyanJDick You can narrow that down to reviewing invocation_context.py, which changes the public API and is more important to get right the first time. Thanks.
@RyanJDick I've fixed the issues you identified.