
Add extra performance features for EMAModel, torch._foreach operations and better support for non-blocking CPU offloading

drhead opened this pull request 1 year ago • 16 comments

What does this PR do?

This adds a few extra features to EMAModel intended to reduce its overhead and let compute and transfers overlap more effectively.

First, I added a new foreach option that, when enabled, uses torch._foreach functions for parameter updates and in-place copies. This should reduce kernel launch overhead on these operations by a fair amount. It can increase their peak memory usage, so it is disabled by default.
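For context, a rough sketch of the kind of fused update this enables (a standalone illustration with hypothetical shadow_params/model_params lists, not the exact EMAModel internals):

```python
import torch

@torch.no_grad()
def ema_update_foreach(shadow_params, model_params, decay):
    # shadow <- shadow - (1 - decay) * (shadow - param), done as two multi-tensor
    # kernels instead of a Python loop with per-tensor kernel launches.
    diffs = torch._foreach_sub(shadow_params, model_params)  # temporarily materializes all diffs,
                                                             # which is where the extra peak memory comes from
    torch._foreach_add_(shadow_params, diffs, alpha=-(1.0 - decay))
```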

Second, I added a function for pinning the shadow parameters' memory, along with an option to pass the non_blocking parameter through the EMAModel.to() function (defaults to False). Used together, these let users easily process EMA updates asynchronously while offloading the parameters to the CPU. With this, EMA updates should be handled just as fast as if the parameters lived on the GPU, provided the update wasn't already taking longer than the entire training step under regular synchronous transfers.
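A rough sketch of the intended usage pattern (names follow the description above; treat the exact signatures and the train_step helper as assumptions rather than the final API):

```python
ema_model.pin_memory()  # pin shadow params so CPU<->GPU copies can run asynchronously

for step, batch in enumerate(dataloader):
    # Queue the H2D copy of the EMA weights; it overlaps with already-queued GPU work.
    ema_model.to(device="cuda", non_blocking=True)

    loss = train_step(batch)            # hypothetical training step (forward/backward/optimizer)
    ema_model.step(model.parameters())

    # Queue the D2H copy and keep going; the host never waits on the transfer.
    ema_model.to(device="cpu", non_blocking=True)
```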

I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for loop implementations and versus regular blocking CPU offload. I would also greatly appreciate someone verifying that this works on DeepSpeed, as I don't have access to suitable hardware to test it on a multi-device setup. I will implement an example for the SD training script soon.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

Training examples: @sayakpaul

drhead avatar Apr 16 '24 00:04 drhead

I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for loop implementations and versus regular blocking CPU offload.

Could you elaborate this a bit? Do you mean we should expose arguments to the users via CLI args so that they can control EMA stuff with more granularity?

sayakpaul avatar Apr 16 '24 04:04 sayakpaul

Can you check whether the decay works correctly for you in a small test loop? For me, it does not, and I have to do:

                    ema_unet.optimization_step = global_step

(This is a problem in the previous code too.)
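i.e. something like this in the training loop (a sketch of the workaround only, assuming the usual accelerator/global_step/ema_unet names from the example scripts):

```python
if accelerator.sync_gradients:
    # Keep EMAModel's internal counter (which drives the decay schedule)
    # in sync with the trainer's global step before applying the update.
    ema_unet.optimization_step = global_step
    ema_unet.step(unet.parameters())
```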

bghira avatar Apr 16 '24 16:04 bghira

I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for loop implementations and versus regular blocking CPU offload.

Could you elaborate this a bit? Do you mean we should expose arguments to the users via CLI args so that they can control EMA stuff with more granularity?

I do have a basic example of it in the SD 1.5 text-to-image training script right now. I think exposing foreach as a CLI arg is a bit superfluous, so I'm only exposing the non-blocking CPU offload for now, since I expect that to be far more useful for most users. Once I profile it, I'll check the memory usage impact of foreach and see whether it's low enough to be an appropriate default.

drhead avatar Apr 16 '24 21:04 drhead

So early testing on this version is showing that the foreach implementation is visibly faster even without having run a profiler on it, and doesn't seem to use any more VRAM (testing on SD1.5 training with batch size 1 and gradient checkpointing). Based on that, I think it is safe to set it to True by default.

The offloading is being a bit more troublesome with the standard training script, and I'm not seeing the compute overlap I expect based on just looking at GPU usage. I know it should work, since I have a training run active on another machine right now with almost nonstop 100% usage while offloading two EMA states at once. I'll have to profile it to see which differences matter here; the main ones are that my working config does not use Accelerate, runs on a headless server (my testing machine uses WSL), uses a training loop wrapped in torch.compile, and very carefully avoids materializing tensors on the CPU.

drhead avatar Apr 17 '24 01:04 drhead

So early testing on this version is showing that the foreach implementation is visibly faster even without having run a profiler on it, and doesn't seem to use any more VRAM (testing on SD1.5 training with batch size 1 and gradient checkpointing). Based on that, I think it is safe to set it to True by default.

I would add that as a CLI argument but not set it to True by default. We should make a note about this feature in our docs and let users tinker with it first. If there's sufficient reception, we can definitely change its default to True.

The offloading is being a bit more troublesome with the standard training script, and I'm not seeing the compute overlap I expect based on just looking at GPU usage. I know it should work, since I have a training run active on another machine right now with almost nonstop 100% usage while offloading two EMA states at once. I'll have to profile it to see which differences matter here; the main ones are that my working config does not use Accelerate, runs on a headless server (my testing machine uses WSL), uses a training loop wrapped in torch.compile, and very carefully avoids materializing tensors on the CPU.

Yeah, would love to know the differences. I think not using Accelerate is the primary difference here, but of course we need to know more and understand it better.

sayakpaul avatar Apr 17 '24 03:04 sayakpaul

Yeah, would love to know the differences. I think not using Accelerate is the primary difference here, but of course we need to know more and understand it better.

This appears to have been correct. Accelerate's dataloader uses blocking transfers, whereas my other training script had pinned memory and exclusively non-blocking transfers. I've submitted a patch that should fix this and make Accelerate's dataloader perform better in these situations: https://github.com/huggingface/accelerate/pull/2685
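For anyone following along, the difference comes down to this standard PyTorch pattern (a plain-PyTorch sketch with a placeholder train_dataset, not Accelerate's actual dataloader code):

```python
import torch
from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=16, pin_memory=True)  # pinned host buffers

for batch in loader:
    # With pinned memory, a non_blocking copy lets the H2D transfer overlap with
    # work already queued on the GPU; a blocking copy stalls the host until it finishes.
    batch = batch.to("cuda", non_blocking=True)
```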

While it is faster now, there is one remaining apparent performance issue: a bunch of cudaHostMalloc calls show up for some (but not all) of the offloaded params when the DtoH transfer is initiated (profiler screenshot attached).

I strongly suspect this is a WSL issue, when our other training machine is free I will see if this issue happens there as well. Regardless of this issue though, in its current state this is faster than a blocking transfer would be.

drhead avatar Apr 18 '24 17:04 drhead

After testing/profiling on a different machine, I'm fairly confident that the non-blocking offload is working as well as a given environment and the other parts of the code permit. The most significant remaining bottlenecks are the dataloader transfers with Accelerate (patch already submitted) and the .item() calls on the metrics being reported, plus a few other things that probably can't be (or aren't worth) addressing without wrapping the training loop in torch.compile(), which is fairly out of scope (and torch.compile is currently far too unstable to include in an example script, IMO).

I'll add foreach as a command-line arg later and then mark this as ready; after that I can work on docs and possibly propagate the changes to other example scripts.

drhead avatar Apr 20 '24 03:04 drhead

Thank you for investigating. We can start with just one example and then open it up to the community, with your PR as a reference.

and the .item() calls on metrics being reported

How significant is this one?

sayakpaul avatar Apr 20 '24 03:04 sayakpaul

and the .item() calls on metrics being reported

How significant is this one?

They're not a huge issue if .item() is called at most once per gradient update, because at that point the queued work has to finish anyway to materialize the loss value. I don't have profiling data, since I knew it would be a problem and commented the calls out immediately when testing, but the main offenders would be the tqdm postfix update and the updating of train_loss on every forward pass (which is fine if you're not using gradient accumulation, but doesn't let the dispatch queue fill up as much as it otherwise could).

Mitigating it without losing too much functionality would look like:

  • change train_loss to initialize as a scalar tensor on the accelerator.device
  • if accelerate.gather is differentiable/isn't already detaching, we should detach the loss value being passed to it (unlikely that this is causing huge issues but this isn't a bad practice)
  • remove the .item() call for accumulating avg_loss into train_loss (this will now be dispatched like every other operation in the chain)
  • under the sync_gradients branch:
    • add .item() to train_loss in accelerator.log
    • move the set_postfix portion to the sync_gradients branch and change it to use train_loss.item() -- we already materialized it so doing it again can't hurt
    • remove the train_loss = 0 line and replace it with train_loss *= 0 after the set_postfix portion, because that lets us reuse the buffer

In my experience, this causes nearly no perceptible overhead. You won't see the loss for every forward pass, but the loss per gradient update is arguably what you actually want to track anyway. I could set up a PR for this if there's interest.
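A hedged sketch of what that might look like, loosely following the variable names in the text-to-image example script (compute_loss is a stand-in for the actual loss computation, and the lr scheduler/other bookkeeping is omitted):

```python
# Keep the running loss on-device so no .item() sync happens per forward pass.
train_loss = torch.zeros((), device=accelerator.device)

for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(unet):
        loss = compute_loss(batch)  # stand-in for the usual noise-prediction loss
        avg_loss = accelerator.gather(loss.detach().repeat(args.train_batch_size)).mean()
        train_loss += avg_loss / args.gradient_accumulation_steps  # stays on GPU, no sync
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    if accelerator.sync_gradients:
        global_step += 1
        loss_value = train_loss.item()  # single host sync per gradient update
        accelerator.log({"train_loss": loss_value}, step=global_step)
        progress_bar.update(1)
        progress_bar.set_postfix(loss=loss_value)
        train_loss *= 0  # reuse the same buffer instead of re-creating it
```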

drhead avatar Apr 20 '24 04:04 drhead

Oh thanks so much. If this is not causing significant overheads, I would like to keep it as is for now.

sayakpaul avatar Apr 20 '24 04:04 sayakpaul

@sayakpaul is this good to merge?

yiyixuxu avatar Apr 22 '24 21:04 yiyixuxu

I assume not. But I will defer to @drhead for confirming.

sayakpaul avatar Apr 23 '24 01:04 sayakpaul

I assume not. But I will defer to @drhead for confirming.

It's effectively complete except for linting.

drhead avatar Apr 23 '24 01:04 drhead

This is good to merge. We'd need to:

* Fix the code linting issues.

* Add a block about this feature in the README of `train_text_to_image.py`. This is quite beneficial.

Additionally, do you think we could add a test for this here: https://github.com/huggingface/diffusers/blob/main/tests/others/test_ema.py?

I have implemented the first two -- which tests are you interested in having? I would think running (nearly?) all of the tests on the foreach implementation to ensure they both pass would be appropriate, but if you think less than that is needed then let me know.

drhead avatar Apr 24 '24 19:04 drhead

Thanks so much.

Sorry for not being clear about the tests.

I would think running (nearly?) all of the tests on the foreach implementation to ensure they both pass would be appropriate, but if you think less than that is needed then let me know.

That sounds good to me. We could have a separate test module, test_ema_for_each.py, and add all the tests there. But I'm also okay if you'd rather create a separate EMAModelForEachTests class in the same test_ema.py script and add the tests there.
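e.g. a minimal standalone check could look roughly like this (a sketch only; it assumes the new foreach flag from this PR is a constructor argument, and the shapes/tolerances are arbitrary):

```python
import unittest

import torch

from diffusers.training_utils import EMAModel


class EMAModelForEachTests(unittest.TestCase):
    def test_foreach_matches_for_loop(self):
        torch.manual_seed(0)
        params = [torch.randn(4, 4) for _ in range(3)]

        ema_loop = EMAModel([p.clone() for p in params], decay=0.999)
        ema_foreach = EMAModel([p.clone() for p in params], decay=0.999, foreach=True)

        # One simulated parameter update, applied to both EMA instances.
        updated = [p + 0.1 * torch.randn_like(p) for p in params]
        ema_loop.step(updated)
        ema_foreach.step(updated)

        for a, b in zip(ema_loop.shadow_params, ema_foreach.shadow_params):
            self.assertTrue(torch.allclose(a, b, atol=1e-6))
```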

sayakpaul avatar Apr 25 '24 01:04 sayakpaul

@drhead a gentle ping here :)

sayakpaul avatar Jun 22 '24 02:06 sayakpaul

I am happy to merge this PR once the conflicts are resolved and a simple test suite is added. I feel the test suite is important here.

sayakpaul avatar Jun 22 '24 02:06 sayakpaul

I am happy to merge this PR once the conflicts are resolved and a simple test suite is added. I feel the test suite is important here.

I admittedly haven't done anything with test suites in a very long time -- do I need to do anything other than merely adding the class like I've done?

drhead avatar Jun 22 '24 03:06 drhead

All good. Will merge once the CI is green. Thanks a bunch for your contributions!

sayakpaul avatar Jun 22 '24 03:06 sayakpaul

Seems like there is a test failure: https://github.com/huggingface/diffusers/actions/runs/9622457113/job/26543763957?pr=7685#step:10:361

sayakpaul avatar Jun 22 '24 03:06 sayakpaul

pin_memory needs to be blocked from running if torch.backends.mps.is_available()
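i.e. roughly this kind of guard (a sketch of the suggestion above, not the merged fix, and assuming pin_memory() simply pins each shadow parameter):

```python
def pin_memory(self):
    # Pinning only works for dense CPU tensors; MPS-backed tensors raise the
    # error shown below, so skip pinning entirely on Apple silicon.
    if torch.backends.mps.is_available():
        return
    self.shadow_params = [p.pin_memory() for p in self.shadow_params]
```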

bghira avatar Jun 22 '24 12:06 bghira

2024-06-22 06:27:12,238 [ERROR] (__main__) Failed to pin EMA model to CPU: cannot pin 'MPSBFloat16Type' only dense CPU tensors can be pinned

bghira avatar Jun 22 '24 12:06 bghira

I've been testing a variant of this.

  • We could add an ema_cpu_only arg to keep the EMA on CPU permanently: skip pinning entirely and instead use s_param.sub_(one_minus_decay * (s_param - param.to(s_param.device))) so the base model param is moved to the CPU rather than the EMA to the GPU (see the sketch after this list). This didn't noticeably increase calculation runtime but does noticeably reduce VRAM.
  • We could add an ema_update_interval arg to only update e.g. every 5-100 steps.
  • save_pretrained needs a max_shard_size arg that is passed through to the base method.
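For the first point, a rough sketch of that CPU-resident update (a standalone illustration built from the formula above; EMAModel's actual step() internals may differ):

```python
import torch

@torch.no_grad()
def ema_update_cpu_resident(shadow_params, model_params, decay):
    one_minus_decay = 1.0 - decay
    for s_param, param in zip(shadow_params, model_params):
        # Move only the live parameter over to the CPU-resident shadow's device,
        # instead of moving the whole EMA state to the GPU.
        s_param.sub_(one_minus_decay * (s_param - param.to(s_param.device)))
```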

bghira avatar Jun 22 '24 14:06 bghira

Perhaps these and the MPS fix could be clubbed together into a separate PR?

sayakpaul avatar Jun 22 '24 15:06 sayakpaul

Well, great work on the conceptual design of this change. On a MacBook Pro M3 Max it improves iteration times to the point that EMA has no visible impact on training speed, even if we never move it to the GPU.

bghira avatar Jun 22 '24 15:06 bghira

Thanks a lot for working on this feature and for iterating on it.

sayakpaul avatar Jun 24 '24 08:06 sayakpaul