
Generic weight averaging callback that supports EMA

(Open) senarvi opened this issue 11 months ago • 49 comments

A callback that updates an AveragedModel after every training step

What does this PR do?

This is similar to the existing StochasticWeightAveraging callback, but it wraps the AveragedModel class from PyTorch. Less duplicated code means easier maintenance. Also, any averaging function can be used. By default, the callback updates the average after every step, but this can be customized by overriding the should_update(step_idx, epoch_idx) method.
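The override mechanism can be sketched in plain Python. `WeightAveragingSketch` below is a hypothetical stand-in, not Lightning's class: it only models how a `should_update(step_idx, epoch_idx)` method would gate the averaging step that the real callback runs after each training step and epoch.

```python
from typing import List, Optional


class WeightAveragingSketch:
    """Hypothetical stand-in for the callback; records when an update would happen."""

    def __init__(self) -> None:
        self.updates: List[str] = []

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        # Default behaviour described in the PR: average after every optimizer step.
        return step_idx is not None

    def on_train_batch_end(self, step_idx: int) -> None:
        if self.should_update(step_idx=step_idx):
            self.updates.append(f"step {step_idx}")  # real callback: update the AveragedModel

    def on_train_epoch_end(self, epoch_idx: int) -> None:
        if self.should_update(epoch_idx=epoch_idx):
            self.updates.append(f"epoch {epoch_idx}")


class EveryTenStepsAfterWarmup(WeightAveragingSketch):
    """Example override: start averaging at step 100, then update every 10 steps."""

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        return step_idx is not None and step_idx >= 100 and step_idx % 10 == 0


cb = EveryTenStepsAfterWarmup()
for step in range(130):
    cb.on_train_batch_end(step)
cb.on_train_epoch_end(0)
print(cb.updates)  # ['step 100', 'step 110', 'step 120']
```

Overriding a single predicate keeps the scheduling policy out of the averaging logic itself.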

Fixes #10914

Before submitting
  • [x] Was this discussed/agreed via a GitHub issue? (not for typos and docs) => Discussed in issue #10914
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
  • [x] Did you make sure to update the documentation with your changes?
  • [x] Did you write any new necessary tests? (not for typos and docs)
  • [x] Did you verify new and existing tests pass locally with your changes?
  • [x] Did you list all the breaking changes introduced by this pull request? => There are none.
  • [x] Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR. Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • [ ] Is this pull request ready for review? (if not, please submit in draft mode)
  • [ ] Check that all items from Before submitting are resolved
  • [ ] Make sure the title is self-explanatory and the description concisely explains the PR
  • [ ] Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--20545.org.readthedocs.build/en/20545/

senarvi, Jan 14 '25

Hey @senarvi, this looks great!

I saw you already added support for saving and resuming, which is great. There are many scenarios there (saving every n steps, time-based, every epoch, etc.); let's make sure we cover them all (for inspiration, we added quite a few tests in https://github.com/Lightning-AI/pytorch-lightning/pull/20379).

we could still have different callbacks ("StepwiseAveragingCallback" and "EpochwiseAveragingCallback")

No, I think it's better to have one callback with configurable averaging flags; it's more Lightning-esque.

Constructs the AveragedModel with use_buffers=True, so that an extra step is not needed for updating the batch normalization statistics. StochasticWeightAveraging performs an extra step at the end. Consequently, its implementation is significantly more complex, and it's difficult to make sure that it works in all cases. Should we add this as an option in this class too?

I think this is OK, but my concern with forcing use_buffers=True is what happens when a user's module has buffers that are not meant to be averaged. I guess in that case they will probably stay the same over time (e.g. the RoPE cache), but that's not really a guarantee.

Wdyt about this? I don't necessarily want to make the implementation more complex, so this is just for discussion.

Updates the average model after every step. StochasticWeightAveraging updates the average model after every epoch, and I recall that the original paper updated it only at certain points (the learning rate minima). I guess it would be nice to be able to select whether the average model will be updated after every step, after every epoch, or after certain epochs. Then we would need only one callback and we could remove the StochasticWeightAveraging callback, but would it make this class too complex?

It would be nice to make it configurable; users will probably want to reach some minimum first and then start averaging. The criteria for that may be very bespoke, so allowing the user to implement a custom hook that decides whether to start averaging, or whether to average at a given step, would be super handy. Otherwise, I expect users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging. That's totally fine, but it requires you to stop and resume.

Regarding removing the StochasticWeightAveraging callback, I don't necessarily see that happening. We have a pretty strong commitment to backward compatibility at this point, so keeping that in with a notice to just use this one will not hurt.

lantiga, Jan 14 '25

I think this is OK, but my concern with forcing use_buffers=True is what happens when a user's module has buffers that are not meant to be averaged. I guess in that case they will probably stay the same over time (e.g. the RoPE cache), but that's not really a guarantee.

That's a good point. I don't know what would be a good solution.

Updates the average model after every step. StochasticWeightAveraging updates the average model after every epoch, and I recall that the original paper updated it only at certain points (the learning rate minima). I guess it would be nice to be able to select whether the average model will be updated after every step, after every epoch, or after certain epochs. Then we would need only one callback and we could remove the StochasticWeightAveraging callback, but would it make this class too complex?

It would be nice to make it configurable; users will probably want to reach some minimum first and then start averaging. The criteria for that may be very bespoke, so allowing the user to implement a custom hook that decides whether to start averaging, or whether to average at a given step, would be super handy. Otherwise, I expect users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging. That's totally fine, but it requires you to stop and resume.

That's an interesting idea. We could have the user pass a function update_on_step(global_step) or update_on_epoch(epoch) that returns a boolean. After each optimizer step and after each epoch we would call the function to check whether we should update the average model.
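A rough sketch of that proposal, assuming the callback stores the two optional predicates and consults them after each optimizer step and at the end of each epoch. `PredicateSchedule` and its method names are illustrative only, not the PR's final API.

```python
from typing import Callable, Optional

Predicate = Callable[[int], bool]


class PredicateSchedule:
    """Hypothetical helper: decides when to update the average model."""

    def __init__(self, update_on_step: Optional[Predicate] = None,
                 update_on_epoch: Optional[Predicate] = None) -> None:
        self.update_on_step = update_on_step
        self.update_on_epoch = update_on_epoch

    def after_optimizer_step(self, global_step: int) -> bool:
        # No step predicate means "never update on steps".
        return self.update_on_step is not None and bool(self.update_on_step(global_step))

    def after_epoch(self, epoch: int) -> bool:
        # No epoch predicate means "never update on epochs".
        return self.update_on_epoch is not None and bool(self.update_on_epoch(epoch))


schedule = PredicateSchedule(update_on_step=lambda step: step % 4 == 0)
steps = [s for s in range(10) if schedule.after_optimizer_step(s)]
epochs = [e for e in range(3) if schedule.after_epoch(e)]
print(steps, epochs)  # [0, 4, 8] []
```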

It seems that AveragedModel copies the current model parameters the first time update_parameters() is called, and updates the average on subsequent calls. This means that the first average is computed when update_on_step() or update_on_epoch() returns True for the second time. I don't see a better alternative.
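To see why the first real average appears only on the second `True`, here is a plain-Python model of the equal-weight running mean that AveragedModel uses when no averaging function is given. It is a simplification: scalars stand in for parameter tensors, and `RunningAverage` is a hypothetical name.

```python
from typing import Optional


class RunningAverage:
    """Plain-Python stand-in for AveragedModel's default (equal-weight) averaging."""

    def __init__(self) -> None:
        self.value: Optional[float] = None
        self.n_averaged = 0

    def update_parameters(self, current: float) -> None:
        if self.n_averaged == 0:
            # First call: copy the current parameters; nothing is averaged yet.
            self.value = current
        else:
            # Later calls: incremental mean over all snapshots seen so far.
            self.value += (current - self.value) / (self.n_averaged + 1)
        self.n_averaged += 1


avg = RunningAverage()
for weight in [1.0, 3.0, 5.0]:  # weights at the points where the predicate returned True
    avg.update_parameters(weight)
print(avg.value, avg.n_averaged)  # 3.0 3
```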

I checked how StochasticWeightAveraging does this, and I think it doesn't work correctly. It only ever updates the average model parameters in on_train_epoch_start(), so the average is not updated after the last epoch. It just shows why I'd like to keep the logic as simple as possible.
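The off-by-one can be reproduced with a toy loop. This is a simplified simulation of the hook order, not the actual StochasticWeightAveraging code: if the average is refreshed only at the start of an epoch, the weights produced by the final epoch are never folded in.

```python
from typing import List


def averaged_epochs(num_epochs: int, update_at_epoch_start: bool) -> List[int]:
    """Return the epochs whose end-of-epoch weights reach the average."""
    included: List[int] = []
    for epoch in range(num_epochs):
        if update_at_epoch_start and epoch > 0:
            # The start of `epoch` sees the weights left behind by epoch - 1.
            included.append(epoch - 1)
        # ... training for `epoch` would happen here ...
        if not update_at_epoch_start:
            # Updating at the end of the epoch includes this epoch's weights.
            included.append(epoch)
    return included


print(averaged_epochs(3, update_at_epoch_start=True))   # [0, 1]
print(averaged_epochs(3, update_at_epoch_start=False))  # [0, 1, 2]
```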

senarvi, Jan 15 '25

Hi, I have a couple of questions.

  1. You added the on_validation_epoch_start and on_validation_epoch_end hooks to swap the weights, but shouldn't the same happen for test?
  2. In my current workflow I have a separate script that does the model exporting to ONNX. It's short, and really the only Lightning specific thing is the MyLightningModule.load_from_checkpoint(...) method. Since the averaged weights are a part of the callback, I would have to instantiate the trainer for the weights to be loaded. And even then, I wouldn't have a function I could call to explicitly swap the weights (since _swap_weights is private and not really accessible). So, my question is, can we have some sort of an API, outside of the trainer, that can load the averaged weights instead of the regular weights? Perhaps adding some sort of a parameter to the load_from_checkpoint method?

cyanic-selkie, Jan 16 '25

Hi @cyanic-selkie

During training (stage=fit), the actual LightningModule is what we update using the optimizer (I call it the current model) and an AveragedModel is maintained in the background (I call it the average model).

I assume that validation is only called during training. Before and after validation we swap the current model and the average model, so the average model will be validated.

When saving a checkpoint, we save the average model parameters in the state_dict. So if you later load the checkpoint without WeightAveraging callback and run a test or export to ONNX, you will be using the average parameters.

When training ends, we copy the average model parameters to the current model. So if you run a test or export to ONNX after training, you will be using the average parameters.
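The weight handling described above can be outlined with plain dictionaries standing in for state dicts. `AveragingLifecycle` is hypothetical; its methods mirror the hooks mentioned in the comment (validation start and end, training end) but are not the PR's code, which operates on real module parameters.

```python
from typing import Dict

State = Dict[str, float]


class AveragingLifecycle:
    """Hypothetical sketch of which weights the module holds at each point."""

    def __init__(self, current: State, average: State) -> None:
        self.current = current    # parameters updated by the optimizer
        self.average = average    # parameters maintained by the averaging

    def _swap(self) -> None:
        self.current, self.average = self.average, self.current

    def on_validation_epoch_start(self) -> None:
        self._swap()  # validate the average model

    def on_validation_epoch_end(self) -> None:
        self._swap()  # restore the current model for training

    def on_train_end(self) -> None:
        # Training is over: the module keeps the averaged parameters.
        self.current = dict(self.average)


life = AveragingLifecycle(current={"w": 2.0}, average={"w": 1.5})
life.on_validation_epoch_start()
assert life.current == {"w": 1.5}   # validation sees the average
life.on_validation_epoch_end()
assert life.current == {"w": 2.0}   # training continues with the current weights
life.on_train_end()
print(life.current)  # {'w': 1.5}
```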

That's the idea at least. I'm not confident that I have thought about every possible corner case. It would be great if you could test that it works in your case.

senarvi, Jan 16 '25

@senarvi Ah! Thanks for the clarification, I should've checked the code more carefully. I tried your branch on a model with quantization-aware training and ONNX export at the end, and everything is working beautifully! I hope this gets merged quickly.

cyanic-selkie, Jan 17 '25

The user can now provide either the update_on_step or the update_on_epoch argument. (In theory also both.) It should be a function that takes the step/epoch number and returns True if the average model should be updated at that point of time.

For example:

update_on_step = lambda x: x > 100

or

update_on_epoch = lambda x: x in (3, 5, 7)

Using update_on_epoch, SWA should be possible. I added one unit test for SWA.
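For SWA with a cyclical learning-rate schedule, one option (an assumption for illustration, not something the PR ships) is to build the `update_on_epoch` predicate from the cycle length, so the average is updated only at the last epoch of each cycle, where the learning rate reaches its minimum. `SWAEpochSchedule` is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SWAEpochSchedule:
    """Hypothetical predicate: True at the last epoch of every cycle from swa_start on."""

    swa_start: int
    cycle_length: int

    def __call__(self, epoch: int) -> bool:
        return epoch >= self.swa_start and (epoch - self.swa_start + 1) % self.cycle_length == 0


update_on_epoch = SWAEpochSchedule(swa_start=3, cycle_length=2)
print([e for e in range(10) if update_on_epoch(e)])  # [4, 6, 8]
```

A module-level class rather than a lambda also keeps the predicate picklable, which can matter with spawn-based launchers.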

I tested EMA in an actual learning task and it gave an improvement, so I'm starting to be more confident that this works.

I think the biggest question that is still left is whether it's a problem that we force use_buffers=True. It would be nice if we could provide the option to instead call update_bn() after training and we wouldn't have to duplicate any of that code. That function takes a data loader and iterates through the data. I can imagine that passing the Trainer's data loader might not work in all cases. We could also leave calling this function to the user.

StochasticWeightAveraging increments the number of epochs in on_fit_start() and disables the backward pass during the extra epoch. I could also copy the code from that class, but there are some details that I don't understand, and I'm not that excited about copying code that I don't fully understand.

@tchaton I think you contributed the StochasticWeightAveraging callback, maybe you have some insight?

senarvi, Jan 23 '25

Is there anything blocking this from being merged?

cyanic-selkie, Feb 02 '25

I marked this ready for review. There were no comments whether it's a problem that we force use_buffers=True. Would it make sense to merge this now and perhaps introduce such option later based on the feedback that we receive?

senarvi, Feb 02 '25

Codecov Report

Attention: Patch coverage is 94.68085% with 5 lines in your changes missing coverage. Please review.

Project coverage is 79%. Comparing base (831870a) to head (5deb0bb).

There is a different number of reports uploaded between BASE (831870a) and HEAD (5deb0bb).

HEAD has 349 fewer uploads than BASE.

Flag                 BASE (831870a)   HEAD (5deb0bb)
cpu                  105              27
python3.10           24               6
lightning_fabric     26               0
pytest               57               0
python               12               3
python3.12           10               3
python3.12.7         35               9
lightning            60               15
python3.11           24               6
gpu                  4                0
pytorch2.1           12               6
pytorch_lightning    23               12
pytest-full          52               27
pytorch2.2.2         6                3
pytorch2.3           6                3
pytorch2.5           6                3
pytorch2.6           6                3
pytorch2.4.1         6                3
pytorch2.5.1         5                3
pytorch2.7           5                3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #20545     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         268      266      -2     
  Lines       23449    23488     +39     
=========================================
- Hits        20389    18475   -1914     
- Misses       3060     5013   +1953     

codecov[bot], Feb 02 '25

BTW: I think it's totally fine to merge this as is and open an issue to gather discussions about averaging buffers.

The other question I have (for the future) is related to fitting both models on GPU. It may make sense to give the ability to keep the AveragedModel on a different device (e.g. cpu) to keep the callback usable with larger models.

lantiga, Feb 03 '25

The other question I have (for the future) is related to fitting both models on GPU. It may make sense to give the ability to keep the AveragedModel on a different device (e.g. cpu) to keep the callback usable with larger models.

There's a device argument already, and the default is actually cpu, as with StochasticWeightAveraging.

senarvi, Feb 04 '25

Hi! Thanks for this great PR. The current implementation only uses the avg_fn argument. Should it also support the in-place version, multi_avg_fn?

h2o64, Feb 21 '25

Hi! Thanks for this great PR. The current implementation only uses the avg_fn argument. Should it also support the in-place version, multi_avg_fn?

I think we could just pass **averaged_model_kwargs. I'll look into it over the weekend.

senarvi, Feb 21 '25

Now any extra keyword arguments will be passed to the AveragedModel constructor, so you can provide either avg_fn or multi_avg_fn.

You can also set use_buffers=False, if you don't want the model buffers to be averaged. In that case it's your responsibility to call torch.optim.swa_utils.update_bn() afterwards, if needed.

By the way, @h2o64 would you mind testing that the callback works in your use case and leaving a review? Testing it is pretty simple, because the entire implementation is contained in one class. The easiest way to test it is to simply copy the WeightAveraging class from the source repo to your project.

senarvi, Feb 22 '25

A respectful bump :)

cyanic-selkie, Mar 21 '25

Hi @cyanic-selkie! Thanks for bringing this to my attention. I fixed the merge conflict in CHANGELOG. I also updated the documentation (training tricks). So, from my side this is ready, and I've been using it successfully. You also said that it's working for you. Some tests fail, but I think the problem is not in this PR. @lantiga, does this look like it could be merged?

senarvi, Mar 22 '25

Works well for me too <3

A small suggestion: can we find a better way to print "Loading the average model parameters for validation." and "Recovering the current model parameters after validation."? Those messages make tqdm cry.


catalpaaa, Apr 01 '25

will check it shortly :)

Borda, Apr 03 '25

Thank you @catalpaaa for checking the PR! I removed the logging. I think it's not needed anymore. Originally, my biggest fear was that the weights are not transferred correctly, which could go unnoticed. And actually... I was kind of right.

The test_ema_resume test was failing in the CI pipeline. I had to set the tolerance to a suspiciously high number when comparing the model weights. So I started once again looking into where that difference comes from.

test_ema_resume trains two models, one without interruption and one with an intentional crash and recovery from a checkpoint after N epochs. I noticed that the weights that I load from the checkpoint are not the same as the weights the model has at the beginning of epoch N+1. It appears that Lightning has already loaded the model weights from "state_dict" when entering the on_load_checkpoint() callback. I had assumed that I could swap "current_model_state" into "state_dict" in the callback, but in fact I have to reload the model state from "current_model_state".
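A dictionary-based sketch of that ordering issue (the function names are hypothetical, and real checkpoints hold tensors). Because the framework has already copied `checkpoint["state_dict"]` (the average weights) into the module before the hook runs, the hook restores the averaged copy from `"state_dict"` and explicitly reloads the module from `"current_model_state"`.

```python
import copy
from typing import Any, Dict

State = Dict[str, float]


def save_checkpoint_sketch(checkpoint: Dict[str, Any], current: State, average: State) -> None:
    # "state_dict" gets the average weights, so loading without the callback uses them.
    checkpoint["state_dict"] = copy.deepcopy(average)
    checkpoint["current_model_state"] = copy.deepcopy(current)


def load_checkpoint_sketch(checkpoint: Dict[str, Any], module_state: State) -> State:
    """Runs after the framework has already loaded checkpoint["state_dict"] into module_state."""
    average = copy.deepcopy(checkpoint["state_dict"])       # restore the averaged copy
    module_state.clear()
    module_state.update(checkpoint["current_model_state"])  # reload the current weights explicitly
    return average


ckpt: Dict[str, Any] = {}
save_checkpoint_sketch(ckpt, current={"w": 2.0}, average={"w": 1.5})
module_state = dict(ckpt["state_dict"])  # what the framework loaded before the hook
average = load_checkpoint_sketch(ckpt, module_state)
print(module_state, average)  # {'w': 2.0} {'w': 1.5}
```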

I'm super happy that I found this bug. Now, I think, the only failing test is tests_pytorch/callbacks/test_pruning.py::test_pruning_callback_ddp[True-True] in PyTorch | oldest. I can't find any error other than "Bash exited with code '1'".

senarvi, Apr 04 '25

When training with TPUs, I add

callbacks.append(WeightAveraging(avg_fn=get_ema_avg_fn(0.9999)))

to my callbacks, I get the following error:

concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'get_ema_avg_fn.<locals>.ema_update'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/Theia/train.py", line 139, in <module>
    main()
  File "/root/Theia/train.py", line 135, in main
    trainer.fit(model, datamodule=data)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 567, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/xla.py", line 98, in launch
    process_context = xmp.spawn(
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 39, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'get_ema_avg_fn.<locals>.ema_update'

The code runs fine with single GPU.

If I pull the decay function out of PyTorch's get_ema_avg_fn:

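import torch
from torch import Tensor

@torch.no_grad()
def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged: Tensor) -> Tensor:
    # Module-level function, so it can be pickled when the strategy spawns processes.
    decay = 0.9999
    return decay * ema_param + (1 - decay) * current_param

callbacks.append(WeightAveraging(avg_fn=ema_update))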

Everything runs fine. Please let me know if I'm using this wrong.

catalpaaa, Apr 11 '25

Thanks @catalpaaa . You're using it correctly. I noticed the same thing with ddp_spawn, so I had to use a similar workaround in the unit tests.

The problem seems to be caused by get_ema_avg_fn() returning a closure. I don't know if there's anything I can do about it. If get_ema_avg_fn were a class instead of a function that returns a closure, like in the unit tests, the problem would be solved. Since the whole point was to avoid duplicating code between PyTorch and Lightning, maybe it would be best to fix this in PyTorch.
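The difference is easy to reproduce with the standard `pickle` module. `make_ema_closure` mirrors the closure pattern (it is not PyTorch's source), and `EMAAvgFn` is a hypothetical module-level class with the same `(averaged, current, num_averaged)` call signature that `avg_fn` expects. Only the class instance survives pickling, which is what spawn-based launchers need.

```python
import pickle


def make_ema_closure(decay: float):
    # Factory returning a nested function: the result is not picklable.
    def ema_update(ema_param, current_param, num_averaged):
        return decay * ema_param + (1 - decay) * current_param
    return ema_update


class EMAAvgFn:
    """Picklable EMA averaging function usable as avg_fn (hypothetical helper)."""

    def __init__(self, decay: float = 0.9999) -> None:
        self.decay = decay

    def __call__(self, ema_param, current_param, num_averaged):
        return self.decay * ema_param + (1 - self.decay) * current_param


try:
    pickle.dumps(make_ema_closure(0.9))
    closure_picklable = True
except (AttributeError, pickle.PicklingError):  # the exception type varies across Python versions
    closure_picklable = False

restored = pickle.loads(pickle.dumps(EMAAvgFn(0.9)))
print(closure_picklable, restored(1.0, 2.0, 1))
```

Because the arithmetic works on tensors as well as floats, an instance such as `EMAAvgFn(0.9999)` could presumably be passed as `avg_fn` in place of `get_ema_avg_fn(0.9999)`.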

senarvi, Apr 12 '25

I guess pickle just hates functions inside functions :(

catalpaaa, Apr 12 '25

Another respectful bump

amorehead, Apr 23 '25

@senarvi, while testing out your callback in a local codebase of mine, I discovered an edge case that should be simple to address. Namely, when one is using a LightningModule's configure_model hook to efficiently initialize one's model weights (e.g., when the model is too large to fit into CPU memory), this callback will currently try to wrap pl_module inside AveragedModel without the model weights loaded (since configure_callbacks is called in Lightning before configure_model is). As such, to ensure that users of Lightning are always loading their model's weights before trying to wrap them in AveragedModel, you can simply change the setup hook for WeightAveraging to read as follows:

@override
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
    """Called when fit, validate, test, predict, or tune begins.

    Creates an `AveragedModel` when fit begins.

    Args:
        trainer: The current `~lightning.pytorch.trainer.trainer.Trainer` instance.
        pl_module: The current `~lightning.pytorch.core.LightningModule` instance.
        stage: The `~lightning.pytorch.trainer.trainer.Trainer` state.
    """
    if stage == "fit":
        device = self._device or pl_module.device
        pl_module.configure_model()  # add this to make sure the model is wrapped correctly
        self._average_model = AveragedModel(
            model=pl_module,
            device=device,
            use_buffers=self._use_buffers,
            **self._kwargs,
        )

amorehead, Apr 24 '25

Also, I've fleshed out the EMAWeightAveraging example into what may be a nice default callback configuration for users. Feel free to have a look:

from typing import Any, Optional, Union

import torch
from torch.optim.swa_utils import get_ema_avg_fn

# Assumes the WeightAveraging callback from this PR is importable from here.
from lightning.pytorch.callbacks import WeightAveraging


class EMAWeightAveraging(WeightAveraging):
    """Exponential Moving Average (EMA) Weight Averaging callback."""

    def __init__(
        self,
        device: Optional[Union[torch.device, str, int]] = "cpu",
        use_buffers: bool = True,
        decay: float = 0.999,
        update_every_n_steps: int = 1,
        update_starting_at_step: Optional[int] = None,
        update_starting_at_epoch: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            device=device,
            use_buffers=use_buffers,
            **kwargs,
            avg_fn=get_ema_avg_fn(decay=decay),
        )

        self.update_every_n_steps = update_every_n_steps
        self.update_starting_at_step = update_starting_at_step
        self.update_starting_at_epoch = update_starting_at_epoch

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        """Decide when to update the model weights.

        Args:
            step_idx: The current step index.
            epoch_idx: The current epoch index.
        Returns:
            bool: True if the model weights should be updated, False otherwise.
        """
        if step_idx is not None:
            # Check step-based conditions only if we have a valid step_idx
            meets_step_requirement = (
                self.update_starting_at_step is None
                or step_idx >= self.update_starting_at_step
            )
            meets_step_frequency = (
                self.update_every_n_steps > 0
                and step_idx % self.update_every_n_steps == 0
            )
            if meets_step_requirement and meets_step_frequency:
                return True

        if epoch_idx is not None:
            # Check epoch-based condition only if we specify one
            meets_epoch_requirement = (
                self.update_starting_at_epoch is not None
                and epoch_idx >= self.update_starting_at_epoch
            )
            if meets_epoch_requirement:
                return True

        return False

amorehead, Apr 24 '25

That's a very good catch @amorehead ! Thanks for testing this before it's merged.

I wasn't familiar with the configure_model hook, so I wanted to make sure that we're using it correctly. According to the documentation, it's called in a "strategy and precision aware context", so that when using a sharded strategy, the model is sharded instantly. Normally it's called in _call_configure_model() like this:

with (
    trainer.strategy.tensor_init_context(),
    trainer.strategy.model_sharded_context(),
    trainer.precision_plugin.module_init_context(),
):
    _call_lightning_module_hook(trainer, "configure_model")

Should we call _call_configure_model(trainer) instead of pl_module.configure_model()? Otherwise, I guess the model is not sharded at this point.

senarvi, Apr 24 '25

@senarvi, you are right! I just noticed that the model will not be sharded unless Lightning's private method lightning.pytorch.trainer.call._call_configure_model is called instead of pl_module.configure_model. However, I don't think this will be enough to fully support model-parallel training strategies (such as FSDP2, which I'm currently trying to test with this callback), since each GPU rank will contain only a portion of the model weights (and there's currently only a single _average_model instance being updated).

  1. Currently, if one tries to wrap the pl_module with AveragedModel after sharding the model with _call_configure_model, they will likely run into this PyTorch issue raised by FSDP's use of copy.deepcopy on the model weights.
  2. Another approach is to follow this recent PyTorch issue and try to wrap the (full, potentially large weights) pl_module with AveragedModel before sharding the model weights (by coincidentally calling the original pl_module.configure_model method). Then, the goal would be to (identically) shard both the pl_module's model weights and the _average_model weights, such that calls to self._average_model.update_parameters in on_train_batch_end and on_train_epoch_end would yield updates to the sharded parameters stored on each GPU rank.

To implement approach 2 above, we would need to (1) keep the call to pl_module.configure_model as it is and (2) figure out how to shard self._average_model identically to the model weights in pl_module. One naive approach to sharding self._average_model would be to refactor this callback to instead store it as a property of pl_module so it gets automatically sharded by Lightning's (automatic) call to lightning.pytorch.trainer.call._call_configure_model, but this could get tricky and may complicate this initial implementation of the callback (which should work well for DDP-based (data-parallel) training).

Update: On further reflection, I think the only way to make this callback compatible with model-parallel training would be to store _average_model as a property of pl_module everywhere in this callback's source code. Otherwise, we would be assuming there is only one copy of the averaged model weights, when in fact there can be many averaged (fragment) model weights. Fortunately, it looks like everywhere _average_model is referenced, a pl_module instance is available. Now, what I'm unsure of is whether this pl_module instance (when training with model parallelism) contains only the model weights belonging to the shard of an individual GPU rank (or whether this points to the full model weights somehow).

amorehead, Apr 24 '25

@amorehead , if I understand correctly, each GPU rank is running in its own process that contains a portion (shard) of the model weights and an AveragedModel instance. Because of issue 1, AveragedModel is not able to copy the parameters of an FSDP2 model, but if that issue is solved, the AveragedModel of each GPU rank would contain averaged values of the corresponding shard. In theory, we should be able to construct the final averaged model by gathering the AveragedModel weights from all GPU ranks, right? This should be done before saving a checkpoint, or we'll end up saving only the shard of rank 0.

The other option is to call configure_model() without model_sharded_context, then construct the AveragedModel, and finally shard both models. You're suggesting that if we simply store the AveragedModel as a property of pl_module, it gets sharded automatically. Two questions come to my mind:

  1. Is the sharding of the averaged parameters guaranteed to be identical to the sharding of the original parameters?
  2. If you need to use FSDP, you probably want the averaged parameters to be on CPU, right? I think this would cause the averaged parameters to be stored on the same device as the original parameters.

senarvi avatar Apr 24 '25 21:04 senarvi

@senarvi, for the time being, it seems like the reference implementation mentioned in this GitHub issue is a standardized (PyTorch-tested) way of running EMA weight updates with FSDP. At first glance, with this reference implementation, I'm not sure if the EMA weights can be stored on any device other than a GPU (for the sake of sharding). Also, this GitHub issue makes me realize that no matter what, one will have to instantiate the full set of model weights, either to initialize the (full) EMA weights or to call FSDP.summon_full_params for EMA weight updates after every optimizer step. This means that, in my understanding, EMA will necessarily incur potentially costly GPU/CPU memory usage when gathering all model weights, so very large models probably won't be runnable even with PyTorch's official implementation. Surely someone must have already solved this issue, unless no one training foundation models is using EMA :wink:

Update: It looks like DeepSpeed has implemented their own version of EMA for use with their Zero Stage 3 (model-parallel) training strategy. It seems they also run the equivalent of FSDP.summon_full_params to update the EMA model weights. Huh, maybe the all-gather operation isn't as memory-intensive as I'm thinking it may be.

amorehead, Apr 24 '25

@amorehead , we don't need to gather all the weights at once, right? To me it looks like DeepSpeed is gathering one parameter at a time:

for param, param_ema in zip(model.parameters(), model_ema.parameters()):
    params_to_fetch = _z3_params_to_fetch([param, param_ema]) if zero_stage == 3 else []
    should_gather_param = len(params_to_fetch) > 0
    with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
        ...  # Update param_ema using param.

But that would require a change in AveragedModel.update_parameters(). This could work:

# parameters() returns generators, so convert them to lists before concatenating.
params = list(pl_module.parameters()) + list(self._average_model.parameters())
params_to_fetch = _z3_params_to_fetch(params) if zero_stage == 3 else []
should_gather_param = len(params_to_fetch) > 0
with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
    self._average_model.update_parameters(pl_module)

But I guess that would gather all the parameters at once.

I would be happy, at this point, to have something that works, even if it's not the most memory-efficient way. I just wouldn't want the code to look like this:

if using_fsdp:
    with FSDP.summon_full_params(pl_module):
        pl_module._averaged_model.update_parameters(pl_module.current_model)
elif using_deepspeed:
    params_to_fetch = _z3_params_to_fetch(pl_module.parameters()) if zero_stage == 3 else []
    should_gather_param = len(params_to_fetch) > 0
    with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
        pl_module._averaged_model.update_parameters(pl_module.current_model)
elif ...:
    ...  # and so on for every other strategy

Maybe we could add some hook and let the user decide what to do?
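One way such a hook could look, as a sketch only: the callback wraps each update in a context manager returned by an overridable method (here called `update_context`, a hypothetical name) that defaults to a no-op, and a strategy-specific subclass could return something like `FSDP.summon_full_params(...)` or DeepSpeed's `GatheredParameters(...)` instead. Plain-Python stand-ins replace the real modules so the control flow can be run.

```python
from contextlib import contextmanager, nullcontext
from typing import ContextManager, List


class AveragingCallbackSketch:
    """Hypothetical stand-in: only models where the user hook is called."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def update_context(self, pl_module) -> ContextManager:
        # Default: no gathering needed (e.g. a single device or DDP).
        return nullcontext()

    def update_average(self, pl_module) -> None:
        with self.update_context(pl_module):
            self.log.append("update_parameters")  # stands in for AveragedModel.update_parameters


class GatheringCallbackSketch(AveragingCallbackSketch):
    @contextmanager
    def update_context(self, pl_module):
        # A real override might enter FSDP.summon_full_params(pl_module) here.
        self.log.append("gather")
        try:
            yield
        finally:
            self.log.append("release")


default_cb, gathering_cb = AveragingCallbackSketch(), GatheringCallbackSketch()
default_cb.update_average(pl_module=None)
gathering_cb.update_average(pl_module=None)
print(default_cb.log)    # ['update_parameters']
print(gathering_cb.log)  # ['gather', 'update_parameters', 'release']
```

This keeps strategy-specific branching out of the callback: each user picks the gathering context that matches their strategy.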

senarvi, Apr 25 '25