InvokeAI
Optimize RAM to VRAM transfer
Summary
This PR speeds up the model manager's system for moving models back and forth between RAM and VRAM. Instead of calling `model.to()` to accomplish the transfer, the model manager now stores a copy of the model's state dict in RAM. When the model needs to be moved into VRAM for inference, the manager makes a VRAM copy of the state dict and assigns it to the model using `load_state_dict()`. When inference is done, the model is cleared from VRAM by calling `load_state_dict()` with the CPU copy of the state dict.
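A minimal sketch of the scheme, assuming a PyTorch version that supports `load_state_dict(..., assign=True)` (names and structure are illustrative, not the actual model-manager code):

```python
import torch

# Stand-in for a UNet / text encoder held by the model cache.
model = torch.nn.Linear(8, 8)

# Keep a pristine CPU copy of the weights when the model enters the cache.
cpu_state_dict = {k: v.clone() for k, v in model.state_dict().items()}

if torch.cuda.is_available():
    # "Load": copy the state dict to VRAM and assign it to the model.
    cuda_state_dict = {k: v.to("cuda") for k, v in cpu_state_dict.items()}
    model.load_state_dict(cuda_state_dict, assign=True)
    # ... run inference on the GPU ...

    # "Unload": re-assign the CPU copy. No VRAM->RAM tensor copies are needed,
    # because the CPU tensors never left system RAM.
    model.load_state_dict(cpu_state_dict, assign=True)
```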
Benchmarking an SDXL model shows an improvement from 3 seconds to 0.81 seconds for a model load/unload cycle. Most of the improvement comes from the unload step, as shown in the table below:
| model | from | to | old (s) | new (s) |
| ----- | ---- | -- | ------- | ------- |
| unet | cpu | cuda:0 | 0.69 | 0.52 |
| text_encoder | cpu | cuda:0 | 0.15 | 0.09 |
| text_encoder_2 | cpu | cuda:0 | 0.16 | 0.14 |
| vae | cpu | cuda:0 | 0.02 | 0.02 |
| **LOAD TO CUDA TOTAL** | | | **1.02** | **0.77** |
| unet | cuda:0 | cpu | 1.45 | 0.03 |
| vae | cuda:0 | cpu | 0.09 | 0.00 |
| text_encoder | cuda:0 | cpu | 0.07 | 0.00 |
| text_encoder_2 | cuda:0 | cpu | 0.40 | 0.01 |
| **UNLOAD FROM CUDA TOTAL** | | | **2.01** | **0.04** |
Thanks to @RyanJDick for suggesting this load/unload scheme.
Related Issues / Discussions
QA Instructions
Change models a number of times. Monitor RAM and VRAM for memory leaks.
Merge Plan
Merge when approved
Checklist
- [X] The PR has a short but descriptive title, suitable for a changelog
- [X] Tests added / updated (if applicable)
- [X] Documentation added / updated (if applicable)
This is a huge speed up!!! Awesome. Will wait for @RyanJDick to take a look
I ran into a bug during testing. I hadn't thought about this before, but this approach breaks if a model is moved between devices while a patch is applied. I can trigger this by using a TI. The series of events is:
- The text encoder is loaded and registered with the model cache.
- We apply the TI to the text encoder while the text encoder is on the CPU. This patch creates a new tensor of token embeddings with a different shape.
- We attempt to move the text encoder to the GPU. This operation fails because the state_dict tensor sizes no longer match (a toy reproduction follows below).
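A toy reproduction of that failure, using a stand-in `nn.Embedding` rather than the real text encoder:

```python
import torch

# Cache the state dict while the module still has its original shape.
emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
cached = {k: v.clone() for k, v in emb.state_dict().items()}

# TI-style patch: swap in an embedding table with extra token rows.
emb.weight = torch.nn.Parameter(torch.zeros(12, 4))

# Restoring/transferring via the cached state dict now fails:
# RuntimeError: ... size mismatch for weight: copying a param with shape
# torch.Size([10, 4]) ..., the shape in current model is torch.Size([12, 4]).
emb.load_state_dict(cached)
```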
We could probably find a quick way to solve this particular problem, but it makes me worry about the risk of similar bugs. We need clear rules for how the model cache and model patching are intended to interact.
One approach would be to require that models are patched and unpatched during the span of a model cache lock. TIs are a little weird in that the patch is applied on the CPU before copying the model to GPU. We should look into whether we can just do all of this on the GPU. If not, we may have to consider splitting the concepts of model access locking and model device locking.
@RyanJDick I'm not all that familiar with model patching. Is patching done prior to every generation and then reversed? If so, the trick would be to refresh the cached state_dict whenever patching is done on a CPU-based model.
Patching (for LoRA or TI) is managed using context managers (applied on entry, and reversed on exit).
Examples:
- https://github.com/invoke-ai/InvokeAI/blob/e22248466364f6895ed55f4dacda482641a4af58/invokeai/app/invocations/compel.py#L179-L189
- https://github.com/invoke-ai/InvokeAI/blob/e22248466364f6895ed55f4dacda482641a4af58/invokeai/app/invocations/latent.py#L931-L938
Now that the model cache has the power to modify a model's weights (restore them to a previous state), we need clearer ownership semantics (i.e. who can modify a model, when can they modify it, and what guarantees do they have to offer?).
Designing this well would take more thought / effort than I can spend on it right now.
We might be able to take a shortcut to get this working now though. I think this might be achievable with some combination of:
- Store the model state_dict at the time that the model is moved to the device instead of at the time that the model is added to the cache.
- Make TI patching work with on-device models and switch the order of the context managers.
- Make this new optimized behavior configurable, i.e. something like `with model_info.on_device(allow_copy=True) as model:`
More investigation needed to figure out which of those makes the most sense.
@RyanJDick I finally got back to this after an interlude. It was a relatively minor fix to get all the model patching done after loading the model into the target device, and the code is cleaner too. I've tested LoRA, TI and clip skip, and they all seem to be working as expected. Seamless doesn't seem to do much of anything, either with this PR or on current main. Not sure what's up with that; I haven't used seamless in over a year.
BTW - seamless working fine for me. Tested SD1.5 and SDXL, all permutations of axes. I wonder if there is some interaction with other settings you were using?
Note: I tested this PR to see if it fixed #6375. It does not.
I understand the overall strategy, but I'm having trouble wrapping my head around the fix for models changing device. If I understand correctly, the solution is very simple - re-order the context managers. Can you ELI5 how changing the order of the context managers fixes this?
Also curious about this edge case - say we have two compel nodes:
- We execute compel node 1. At this time, the models are in VRAM.
- Time passes and we load other models, evicting the UNet and CLIP from VRAM. Maybe they are in RAM, maybe they aren't cached at all.
- We execute compel node 2. Is this a problem?
The context managers were reordered so that the context manager calls that lock the model in VRAM are executed before the patches are applied, and it is the locked model that is passed to the patchers. I also switched the relative order of the TI and LoRA patchers, but only because it made the code formatting easier to read. I tested both orders and got identical images.
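A simplified, self-contained illustration of that ordering, using stand-in context managers rather than the real InvokeAI APIs:

```python
from contextlib import contextmanager

@contextmanager
def lock_in_vram(model_name: str):
    # Stand-in for the model cache lock: moves the model to VRAM and pins it there.
    print(f"move {model_name} to VRAM and lock")
    try:
        yield model_name
    finally:
        print(f"unlock {model_name}; it may now be evicted from VRAM")

@contextmanager
def apply_patch(model, patch_name: str):
    # Stand-in for the LoRA/TI patchers: patch on entry, revert on exit.
    print(f"apply {patch_name} to the locked (VRAM) copy of {model}")
    try:
        yield model
    finally:
        print(f"revert {patch_name} on {model}")

# After the reorder, the VRAM lock wraps the patchers, so the patchers always
# receive the on-device copy of the model.
with lock_in_vram("text_encoder") as te:
    with apply_patch(te, "LoRA"), apply_patch(te, "TI"):
        pass  # run the compel / denoise invocation here
```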
Here's the edge case:
- We execute compel node 1. The models may or may not be in VRAM (and may not be in RAM either). When the compel invocation runs, the models are moved into VRAM by the model manager's context manager and locked there for the duration of the context. Within the context the model is patched in VRAM.
- As soon as the compel context is finished, the model is unpatched. It is also likely removed from VRAM unless it happens to fit into the VRAM cache space.
- A new compel node is executed. If the model is no longer in VRAM, a fresh copy of the model weights is copied into VRAM and the process described in step 1 is repeated.
RAM->VRAM operations are about twice as fast as VRAM->RAM on my system. I am tempted to remove the VRAM cache entirely so that we are guaranteed to have a fresh copy of the model weights each time. However, if the patchers are unpatching correctly, this shouldn't be an issue.
> Note: I tested this PR to see if it fixed #6375. It does not.
Rats. I was rather hoping it would. I'm digging into the LoRA loading issue now.
> BTW - seamless working fine for me. Tested SD1.5 and SDXL, all permutations of axes. I wonder if there is some interaction with other settings you were using?
It is working for me as well. I just had to adjust the image dimensions to see the effect. Seamless is not something I ever use.
@psychedelicious @RyanJDick I have included a fix for #6375 in this PR. There was some old model cache code originally written by Stalker that traversed the garbage collector and forcibly deleted local variables from unused stack frames. This code was written to work around a Python 3.9 GC bug, but it seems to wreak havoc on context managers. The low RAM cache setting simply triggered the problem. I suspect this may have caused rare failures in other contexts as well (pun intended).
I removed the code and tested for signs of memory leaks. I didn't see any, but please keep an eye out.
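For reference, a sketch of the general pattern being described (illustrative only, not the actual deleted code):

```python
import gc
import types

def clear_lingering_frame_refs(obj: object) -> None:
    """Walk the garbage collector's referrers and wipe the locals of any
    stack frame that still holds a reference to `obj`."""
    for referrer in gc.get_referrers(obj):
        if isinstance(referrer, types.FrameType):
            try:
                referrer.clear()  # drops the frame's local variables
            except RuntimeError:
                pass  # executing (or, on newer Pythons, suspended) frames can't be cleared
```

On older Python versions, clearing a suspended generator frame also finalizes the generator, which is one plausible way a pattern like this could interfere with `@contextmanager`-based cleanup.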
Going off on a tangent, while reviewing the patching code, I discovered that LoRA patching uses the following pattern:
- Load the LoRA as `lora_info` using the MM
- Get the model on the CPU using `lora_info.model`
- Iterate through each of the LoRA layers, moving them into VRAM
- Apply the LoRA layer weights (saving the original model weights to restore later)
- Move the layer back to RAM.
I think it would be more performant to do the following (a rough sketch follows the list):
- Load the LoRA using the MM
- Enter the context that moves the LoRA weights into VRAM (using the new RAM->VRAM transfer)
- Apply the weights layer by layer
- Exit the context
The downside is that this will transiently use more VRAM because all the LoRA layers are loaded at once.
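Roughly, the layer-by-layer on-device apply might look like this (hypothetical helper names; `lora_deltas` stands in for LoRA layer deltas already moved to the GPU, and the usual `up @ down * alpha/rank` computation is elided):

```python
import torch

def apply_lora_on_device(
    model: torch.nn.Module,
    lora_deltas: dict[str, torch.Tensor],
    weight: float = 1.0,
) -> None:
    """Add precomputed LoRA weight deltas to an on-device model, layer by layer,
    with no per-layer RAM<->VRAM round trips."""
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, delta in lora_deltas.items():
            if name in params:
                params[name] += weight * delta.to(params[name].device, params[name].dtype)
```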
Another potential optimization would be to stop saving the original model's weights on entry to the patcher context and restoring them on exit. Since we are now keeping a virgin copy of the state dictionary in the RAM cache, the patched model in VRAM is cleared out at the end of a node's invocation and will be replaced with a fresh copy the next time it is needed.
I gave both of these things a quick try and the system felt snappier, but I didn't do timing or any stress tests. If you think this is worth pursuing, I'll submit a new PR for them.
[EDIT] I can shave off ~2s of generation walltime (from 10.8s to 8.8s) by avoiding the unnecessary step of restoring weights to the VRAM copy of the model.
The latest commit implements an optimization that circumvents the LoRA unpatching step when working with a model that is resident in CUDA VRAM. This works because the new scheme never copies the model weights back from VRAM into RAM, but instead reinitializes the VRAM copy from a fresh RAM state dict the next time the model is needed. The behavior for CPU and MPS devices has not changed, since these operate on the RAM copy. When generating with SDXL models, this optimization saves roughly 1s per LoRA per generation, which I think makes the special casing worth it.
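A hedged sketch of the special case (simplified names, not the exact patcher code):

```python
import torch

def maybe_restore_original_weights(model: torch.nn.Module, saved: dict[str, torch.Tensor]) -> None:
    """Skip unpatching for CUDA-resident models: the next load rebuilds them
    from the pristine RAM state dict, so the patched VRAM copy is disposable."""
    if next(model.parameters()).device.type == "cuda":
        return
    with torch.no_grad():
        for name, original in saved.items():
            model.get_parameter(name).copy_(original)
```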
The other optimization I tried was to let the model manager load the LoRA into VRAM using its usual model locking mechanism rather than manually moving each layer into VRAM before patching. However, this did not give a performance gain and needed special casing for LoRAs in the model manager because LoRAs don't have load_state_dict.
Other changes in this commit:
- I have removed the VRAM cache along with the configuration variables that control its behavior. The cache is incompatible with the LoRA unpatching optimization, and I was planning to get rid of it anyway given the trouble users have with it.
- I have updated the config schema to 4.0.2 and added a migration script that removes the VRAM settings from the user's invokeai.yaml file.
- I have modified the `test_lora.py` test to accommodate the lack of unpatching when running on CUDA.
Did you test the effect of removing the VRAM cache with a large VRAM cache size (e.g. large enough to hold all working models)? For this usage pattern, I'm afraid that there is going to be a significant speed regression from removing it.
I had a chance to do some testing of this today. As I suspected, removing the VRAM cache does result in a regression when the VRAM cache is large enough to hold the larger models (>1 sec per generation). The improved LoRA patching speed makes up for it in some cases, but not in others.
This PR has grown quite a bit in scope. It now covers:
1. Keep a copy on the CPU to improve VRAM offload speed
2. Remove the hacky garbage collection logic that is causing the LoRA patching bug
3. Remove the VRAM cache altogether
4. Optimize LoRA patch/unpatch time
How about we split these up so that we can properly evaluate and test each one? I feel like we definitely want 1 and 2 (ideally as separate PRs). 3 and 4 come with major tradeoffs, maybe we can find a way to get both the benefit of a VRAM cache and smarter LoRA patching.
4 is dependent on 3. How about I just remove the code changes for 3 and 4 and we can consider them as a separate future PR? This is easier for me as I’ll just reset to an earlier commit.
@RyanJDick I’ve undone the model patching changes and the removal of the VRAM cache, and what’s left is the original cpu->vram optimization, the fix to the TI patching, and the weird context manager bug that was causing LoRAs not to patch. It is a fairly minimal PR now, so I hope we can get it merged. I’ll work on the LoRA patching optimization separately.
@lstein There are a few torch features that might stack nicely on this PR to give even more speedup for Host-to-Device copies:
- `torch.Tensor.pin_memory()`
- `torch.Tensor.to(..., non_blocking=True)`
Have you looked into these at all? I don't want to expand the scope of this PR, but these could be an easy follow-up if you're interested in trying them out (or I can do it).
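For reference, a quick sketch of what that follow-up might look like (not part of this PR):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in model

if torch.cuda.is_available():
    # pin_memory() returns a page-locked host copy, which enables faster and
    # asynchronous host-to-device transfers.
    pinned = {k: v.pin_memory() for k, v in model.state_dict().items()}
    on_gpu = {k: v.to("cuda", non_blocking=True) for k, v in pinned.items()}
    torch.cuda.synchronize()  # ensure the async copies have completed before use
```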
I'm going to merge this in and then will start working on further optimizations including the lora loading/unloading.
Awesome. Thanks for splitting up the PRs.
I did some quick manual regression testing - everything looked good. I tried:
- Text-to-image, LoRA, TI
- CPU-only
- A bunch of model switching - no obvious signs of a memory leak.
I also ran some performance tests. With `vram: 0.25`:
- SDXL T2I, cold cache: 10.4s -> 9.6s
- SDXL T2I, warm cache: 6.9s -> 6.1s
- SDXL T2I + 2 LoRA, warm cache: 9.0s -> 8.6s
With `vram: 16` (no significant change, as expected):
- SDXL T2I, cold cache: 8.0s -> 8.0s
- SDXL T2I, warm cache: 4.7s -> 4.6s
- SDXL T2I + 2 LoRA, warm cache: 6.9s -> 6.9s
Thanks for doing the timings. It's not as big a speedup as I saw, but probably very dependent on hardware.