Add extra performance features for EMAModel: torch._foreach operations and better support for non-blocking CPU offloading
What does this PR do?
This adds a few extra things to EMAModel intended to reduce its overhead and allow overlapping compute and transfers better.
First, I added a new `foreach` option that, if set, will use `torch._foreach` functions for parameter updates and in-place copies. This should reduce kernel launch overhead on these operations by a fair amount. It can increase their peak memory usage, so it is disabled by default.
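For a sense of what this looks like, here is a minimal sketch of a batched EMA update using the `_foreach` ops (an illustration of the technique, not the exact code in this PR):

```python
import torch

@torch.no_grad()
def ema_update_foreach(shadow_params, params, decay):
    """Apply shadow -= (1 - decay) * (shadow - param) across all tensors at once."""
    one_minus_decay = 1.0 - decay
    # A couple of batched dispatches instead of one kernel launch per parameter tensor.
    torch._foreach_sub_(
        shadow_params,
        torch._foreach_sub(shadow_params, params),
        alpha=one_minus_decay,
    )

# Usage sketch:
# shadow = [p.detach().clone() for p in model.parameters()]
# ema_update_foreach(shadow, list(model.parameters()), decay=0.9999)
```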
Second, I added a function for pinning the memory of the shadow parameters, alongside an option to pass through the `non_blocking` parameter to `EMAModel.to()` (defaults to False). Used together, these should let users easily process EMA updates asynchronously while offloading the parameters to the CPU. With this, handling EMA updates should be just as fast as if the weights lived on the GPU, as long as the update wasn't already taking longer than the entire training step under regular synchronous transfers.
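A rough usage sketch of that offloading workflow, assuming the new `pin_memory()` helper and the `non_blocking` pass-through on `to()` described above (the surrounding training objects and `training_step` are placeholders, and exact signatures may differ from the final code):

```python
# Usage sketch only; not verbatim from the PR.
ema_unet = EMAModel(unet.parameters(), decay=0.9999, foreach=True)
ema_unet.pin_memory()  # page-lock the CPU copies so async DtoH/HtoD transfers are possible

for batch in train_dataloader:
    ema_unet.to(device="cuda", non_blocking=True)  # upload overlaps with forward/backward
    training_step(batch)                           # placeholder for the usual forward/backward
    optimizer.step()
    optimizer.zero_grad()
    ema_unet.step(unet.parameters())               # EMA update runs on the GPU copies
    ema_unet.to(device="cpu", non_blocking=True)   # offload overlaps with the next step's compute
```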
I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for-loop implementations and vs. regular blocking CPU offload. I would also greatly appreciate someone verifying that this works on DeepSpeed, as I don't have access to suitable hardware to test a multi-device setup. I will implement an example for the SD training script soon.
Before submitting
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Training examples: @sayakpaul
> I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for-loop implementations and vs. regular blocking CPU offload.

Could you elaborate on this a bit? Do you mean we should expose arguments to the users via CLI args so that they can control EMA stuff with more granularity?
Can you check whether the decay works correctly for you in a small test loop? For me, it does not, and I have to do `ema_unet.optimization_step = global_step` (this is a problem in the prev code too).
> I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for-loop implementations and vs. regular blocking CPU offload.
>
> Could you elaborate on this a bit? Do you mean we should expose arguments to the users via CLI args so that they can control EMA stuff with more granularity?
I do have a basic example for it in the SD1.5 text-to-image training script right now. I think that exposing `foreach` as a CLI arg is a bit superfluous, so I'm only exposing the non-blocking CPU offload for now, since I expect that this would be far more useful for most users. Once I profile it, I will see whether the memory usage impact of `foreach` is low enough for it to be an appropriate default.
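For reference, the CLI wiring could look something like this (the flag names here are illustrative, not necessarily what the final script uses):

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag names for the sketch.
parser.add_argument(
    "--offload_ema",
    action="store_true",
    help="Keep EMA weights in pinned CPU memory between updates, moving them to the GPU non-blockingly for each update.",
)
parser.add_argument(
    "--foreach_ema",
    action="store_true",
    help="Use torch._foreach ops for the EMA update to reduce kernel launch overhead.",
)
args = parser.parse_args()
```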
So early testing on this version is showing that the foreach implementation is visibly faster even without having run a profiler on it, and doesn't seem to use any more VRAM (testing on SD1.5 training with batch size 1 and gradient checkpointing). Based on that, I think it is safe to set it to True by default.
The offloading is being a bit more troublesome with the standard training script, and I'm not seeing the compute overlap that I expect based on just looking at GPU usage. I know that it should work, since I have a training run active on another machine right now with almost nonstop 100% usage while offloading two EMA states at once. I'll have to profile it to see what differences matter here; the main ones are that my working config does not use Accelerate, runs on a headless server (my testing machine uses WSL), and uses a training loop that is wrapped in torch.compile and very carefully avoids materializing tensors on CPU.
> So early testing on this version is showing that the foreach implementation is visibly faster even without having run a profiler on it, and doesn't seem to use any more VRAM (testing on SD1.5 training with batch size 1 and gradient checkpointing). Based on that, I think it is safe to set it to True by default.
I would add that as a CLI argument but not set it to True by default. We should make a note about this feature in our docs and let the users tinker around with it first. If there's sufficient reception, we can definitely change to True as its default.
> The offloading is being a bit more troublesome with the standard training script, and I'm not seeing the compute overlap that I expect based on just looking at GPU usage. I know that it should work, since I have a training run active on another machine right now with almost nonstop 100% usage while offloading two EMA states at once. I'll have to profile it to see what differences matter here; the main ones are that my working config does not use Accelerate, runs on a headless server (my testing machine uses WSL), and uses a training loop that is wrapped in torch.compile and very carefully avoids materializing tensors on CPU.
Yeah would love to know the differences. I think not using accelerate is the primary difference here but of course, we need to know more and better.
> Yeah would love to know the differences. I think not using `accelerate` is the primary difference here but of course, we need to know more and better.
This appears to have been correct. Accelerate's dataloader uses blocking transfers, where my other training script had pinned memory and exclusively non-blocking transfers. Submitted a patch that should fix that and make Accelerate's dataloader perform better in these situations: https://github.com/huggingface/accelerate/pull/2685
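For context, the transfer pattern used in the non-Accelerate script boils down to this (a minimal sketch; the dataset here is a stand-in):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 3, 64, 64))  # stand-in data

# pin_memory=True gives page-locked host buffers, which is what lets the
# subsequent non_blocking copy run asynchronously instead of stalling the GPU.
loader = DataLoader(dataset, batch_size=8, pin_memory=True)

for (images,) in loader:
    images = images.to("cuda", non_blocking=True)  # HtoD overlaps with ongoing GPU work
    # ... forward/backward here ...
```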
While it is faster now, there is one remaining apparent performance issue: there are a bunch of cudaHostMalloc calls for some (but not all) of the offloaded params when the DtoH transfer is initiated.
I strongly suspect this is a WSL issue; when our other training machine is free, I will see if it happens there as well. Regardless of this issue, in its current state this is still faster than a blocking transfer would be.
After testing/profiling on a different machine, I'm fairly confident that the non-blocking offload is working as well as a given environment and the rest of the code can permit. The most significant remaining blockers are the dataloader transfers with Accelerate (patch already submitted) and the `.item()` calls on metrics being reported, along with a few other things that probably can't be fixed/aren't worth addressing without wrapping the training loop in torch.compile(), which is fairly out of scope (and torch.compile is currently far too unstable to include in an example script IMO).
I'll add `foreach` as a command-line arg later and then mark this as ready; after that I can work on docs and possibly propagate the changes to other example scripts.
Thank you for investigating. We can just start with one example and then open it up to the community to follow your PR as a reference.
> and the `.item()` calls on metrics being reported
How significant is this one?
> and the `.item()` calls on metrics being reported
>
> How significant is this one?
They're not a huge issue if it's being called at most once per gradient update, because at that point the loss value necessarily has to be materialized anyway. I don't have profiling data since I knew it would be a problem and commented them out immediately when testing, but the main offenders would be the tqdm postfix update and the updating of train_loss every forward pass (which is fine if not using gradient accumulation, but otherwise it isn't letting the dispatch queue fill up as much as it could).
Mitigating it without losing too much functionality would look like the following (a rough sketch follows the list):

- change `train_loss` to initialize as a scalar tensor on `accelerator.device`
- if `accelerate.gather` is differentiable/isn't already detaching, we should detach the loss value being passed to it (unlikely that this is causing huge issues, but this isn't a bad practice)
- remove the `.item()` call for accumulating `avg_loss` into `train_loss` (this will now be dispatched like every other operation in the chain)
- under the sync_gradients branch:
  - add `.item()` to `train_loss` in `accelerator.log`
  - move the `set_postfix` portion to the sync_gradients branch and change it to use `train_loss.item()` -- we already materialized it so doing it again can't hurt
  - remove the `train_loss = 0` line and replace it with `train_loss *= 0` after the `set_postfix` portion, because that lets us reuse the buffer
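Here is a minimal, self-contained sketch of the idea in plain PyTorch (the real change would additionally touch `accelerator.log` and the tqdm `set_postfix` call as described above):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
grad_accum_steps = 4

# The running loss lives as a tensor on the training device, so accumulating it
# does not force a GPU sync; .item() is called only once per gradient update.
train_loss = torch.zeros((), device=device)

for step in range(20):
    loss = torch.rand((), device=device)            # stand-in for the per-step loss
    train_loss += loss.detach() / grad_accum_steps  # queued asynchronously, no sync

    if (step + 1) % grad_accum_steps == 0:          # i.e. the sync_gradients branch
        loss_value = train_loss.item()              # the single sync point per update
        print(f"step {step}: train_loss {loss_value:.4f}")
        train_loss *= 0                             # reuse the buffer instead of reallocating
```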
In my experience, this causes nearly no perceptible overhead. You won't see the loss for every forward pass, but the loss for the gradient update is arguably what you actually want to be tracking anyway. I could set up a PR for this if interested.
Oh thanks so much. If this is not causing significant overheads, I would like to keep it as is for now.
@sayakpaul is this good to merge?
I assume not. But I will defer to @drhead for confirming.
> I assume not. But I will defer to @drhead for confirming.

It's effectively complete except for linting.
This is good to merge. We'd need to:
* Fix the code linting issues.
* Add a block about this feature in the README of `train_text_to_image.py`. This is quite beneficial.

Additionally, do you think we could add a test for this here: https://github.com/huggingface/diffusers/blob/main/tests/others/test_ema.py?
I have implemented the first two -- which tests are you interested in having? I would think running (nearly?) all of the tests on the foreach implementation to ensure they both pass would be appropriate, but if you think less than that is needed then let me know.
Thanks so much.
Sorry for not being clear about the tests.
> I would think running (nearly?) all of the tests on the foreach implementation to ensure they both pass would be appropriate, but if you think less than that is needed then let me know.
That sounds good to me. We could have a separate test module, `test_ema_for_each.py`, and add all the tests there. But I am okay if you'd rather create a separate `EMAModelForEachTests` class in the same `test_ema.py` script and add the tests there.
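A hypothetical sketch of the second option, assuming the existing test class exposes a helper along the lines of `get_models` for constructing the models (the helper name and constructor arguments here are guesses, not verified against `test_ema.py`):

```python
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Hypothetical sketch; adjust to the actual helpers in tests/others/test_ema.py.
class EMAModelTestsForeach(EMAModelTests):
    def get_models(self, decay=0.9999):
        unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet")
        ema_unet = EMAModel(
            unet.parameters(),
            decay=decay,
            model_cls=UNet2DConditionModel,
            model_config=unet.config,
            foreach=True,  # exercise the torch._foreach path in every inherited test
        )
        return unet, ema_unet
```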
@drhead a gentle ping here :)
I am happy to merge this PR once the conflicts are resolved and a simple test suite is added. I feel the test suite is important here.
> I am happy to merge this PR once the conflicts are resolved and a simple test suite is added. I feel the test suite is important here.
I admittedly haven't done anything with test suites in a very long time -- do I need to do anything other than merely adding the class like I've done?
All good. Will merge once the CI is green. Thanks a bunch for your contributions!
Seems like there is a test failure: https://github.com/huggingface/diffusers/actions/runs/9622457113/job/26543763957?pr=7685#step:10:361
`pin_memory` needs to be blocked from running if `torch.backends.mps.is_available()`:
2024-06-22 06:27:12,238 [ERROR] (__main__) Failed to pin EMA model to CPU: cannot pin 'MPSBFloat16Type' only dense CPU tensors can be pinned
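A minimal sketch of that guard (the surrounding method structure is assumed, not the exact EMAModel code):

```python
import torch

def pin_shadow_params(shadow_params):
    # Pinning only applies to dense CPU tensors and only matters for CUDA-style
    # async transfers; on MPS it raises the error above, so skip it there.
    if torch.backends.mps.is_available():
        return shadow_params
    return [p.pin_memory() for p in shadow_params]
```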
I've been testing a variant of this.

- We can add an arg `ema_cpu_only` to keep EMA on the CPU forever: don't pin anything and instead use `s_param.sub_(one_minus_decay * (s_param - param.to(s_param.device)))` for the calculation, so the base model param is moved to the CPU instead of the EMA weights to the GPU. This didn't noticeably increase calculation runtime but does noticeably reduce VRAM (rough sketch below).
- We can add an arg `ema_update_interval` to only update e.g. every 5-100 steps.
- `save_pretrained` needs `max_shard_size` added as an arg and passed through to the base method.
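A rough sketch of the first idea, assuming a plain list of CPU-resident shadow parameters:

```python
import torch

@torch.no_grad()
def ema_update_cpu_resident(shadow_params, params, decay):
    # Shadow params never leave the CPU; only the live parameter is copied over
    # for the update, so the EMA copy never occupies VRAM.
    one_minus_decay = 1.0 - decay
    for s_param, param in zip(shadow_params, params):
        s_param.sub_(one_minus_decay * (s_param - param.detach().to(s_param.device)))
```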
Perhaps these and the MPS fix could be clubbed in a separate PR?
Well, great work on the conceptual design of this change. On a MacBook Pro M3 Max it improves iteration times such that EMA produces no visible impact on training speed, even if we never move it to the GPU.
Thanks a lot for working on this feature and for iterating on it.