Add feature Exponential Moving Average (EMA)

Open hankyul2 opened this issue 4 years ago • 52 comments

🚀 Feature

How about adding EMA as a callback?

Motivation

I have had difficulty applying EMA. I think it would be nice to have EMA as a callback.

Pitch

If the user adds EMA as a callback, the EMA weights are used for validation and testing.

Alternatives

Alternatively, EMA could be added as a tutorial, like the snippet below:

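from copy import deepcopy

import torch
from torch import nn
from torch.optim import SGD
from torchvision import models
from pytorch_lightning import LightningModule
from torchmetrics import MetricCollection, Accuracy


class EMA(nn.Module):
    """ Model Exponential Moving Average V2 from timm"""
    def __init__(self, model, decay=0.9999):
        super().__init__()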
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)
    

class BasicModule(LightningModule):
    def __init__(self, lr=0.01, use_ema=False):
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        self.model_ema = EMA(self.model, decay=0.9) if use_ema else None
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        
        metric = MetricCollection({'top@1': Accuracy(top_k=1), 'top@5': Accuracy(top_k=5)})
        self.train_metric = metric.clone(prefix='train_')
        self.valid_metric = metric.clone(prefix='valid_')
    
    def training_step(self, batch, batch_idx, optimizer_idx=None):
        return self.shared_step(*batch, self.train_metric)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(*batch, self.valid_metric)

    def shared_step(self, x, y, metric):
        y_hat = self.model(x) if self.training or self.model_ema is None else self.model_ema.module(x)
        loss = self.criterion(y_hat, y)
        self.log_dict(metric(y_hat, y), prog_bar=True)
        return loss

    def configure_optimizers(self):
        return SGD(self.model.parameters(), lr=self.lr)

    def on_before_backward(self, loss: torch.Tensor) -> None:
        # Runs before optimizer.step(), so the EMA sees the weights from before this step.
        if self.model_ema is not None:
            self.model_ema.update(self.model)
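For reference, update() above applies ema = decay * ema + (1 - decay) * model to every state_dict entry, and set() simply copies the model values. The sketch below replays that rule with plain Python floats standing in for tensors (an assumption made only to keep it dependency-free). A common rule of thumb is that the average covers roughly the last 1 / (1 - decay) updates, so decay=0.9 in BasicModule averages over about 10 steps, while the default 0.9999 covers about 10,000.

```python
from typing import Dict

State = Dict[str, float]  # stand-in for a state_dict: parameter name -> value


def ema_update(ema: State, model: State, decay: float) -> None:
    """In-place EMA step, mirroring EMA.update() above."""
    for key, model_value in model.items():
        ema[key] = decay * ema[key] + (1.0 - decay) * model_value


def ema_set(ema: State, model: State) -> None:
    """Copy the model values into the EMA state, mirroring EMA.set() above."""
    for key, model_value in model.items():
        ema[key] = model_value


ema = {"weight": 0.0}
target = {"weight": 1.0}
for _ in range(10):
    ema_update(ema, target, decay=0.9)
# Starting at 0.0 and pulled towards a constant 1.0, the EMA equals 1 - decay**n.
print(round(ema["weight"], 4))  # 0.6513
```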

Additional context



cc @borda

hankyul2 avatar Dec 03 '21 05:12 hankyul2

Hi, as stated in https://github.com/PyTorchLightning/pytorch-lightning/issues/8100#issuecomment-867819299 this can be done by replacing just one part of our SWA.

justusschock avatar Dec 03 '21 11:12 justusschock

@justusschock thank you for your reply.

Is there a way to implement EMA using SWA?

I checked the link. To me, SWA seems to update the LR scheduler and the model weights together. Am I right?

hankyul2 avatar Dec 03 '21 11:12 hankyul2

@hankyul2, I believe this is how it would be implemented:

import torch
from pytorch_lightning.callbacks import StochasticWeightAveraging

class EMA_Callback(StochasticWeightAveraging):
    def __init__(self, decay=0.9999):
        super().__init__()
        self.decay = decay

    # Instance method (needs self to read self.decay); num_averaged is unused by EMA.
    def avg_fn(
        self, averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
    ) -> torch.Tensor:
        e = averaged_model_parameter
        m = model_parameter
        return self.decay * e + (1. - self.decay) * m

Let me know if you have any problems. I just learned about SWA because of your issue!

mathemusician avatar Dec 04 '21 01:12 mathemusician

@mathemusician thank you.

hankyul2 avatar Dec 04 '21 02:12 hankyul2

@hankyul2 I don't think that @mathemusician's solution is equivalent. avg_fn is called once per epoch, while EMA updates happen every training step. I don't think that EMA can be implemented with the SWA callback.
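To make the difference concrete: as far as I know, the default avg_fn of Lightning's StochasticWeightAveraging (like that of torch.optim.swa_utils.AveragedModel) is an equal-weight running mean, avg + (p - avg) / (n + 1), whereas EMA discounts older values geometrically. This stdlib sketch compares the two on a made-up sequence of scalar "weights"; the values are illustrative only.

```python
from typing import List


def swa_running_mean(values: List[float]) -> float:
    """Equal-weight running mean, same form as the default SWA avg_fn."""
    avg = values[0]
    for n, p in enumerate(values[1:], start=1):
        # n = number of values already folded into avg
        avg = avg + (p - avg) / (n + 1)
    return avg


def ema(values: List[float], decay: float) -> float:
    """Exponential moving average, initialised with the first value."""
    avg = values[0]
    for p in values[1:]:
        avg = decay * avg + (1.0 - decay) * p
    return avg


weights = [0.0, 0.0, 0.0, 0.0, 10.0]  # hypothetical weight values, one per update
print(swa_running_mean(weights))  # 2.0: every value counts equally
print(ema(weights, decay=0.5))    # 5.0: the most recent value dominates
```

If the averaging function is only called once per epoch, as avg_fn is in the SWA callback, the list above would hold one entry per epoch instead of one per training step, which is a different average again.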

hal-314 avatar Dec 10 '21 13:12 hal-314

@hal-314 yeah. I think so.

hankyul2 avatar Dec 10 '21 13:12 hankyul2

@hal-314 @mathemusician

I have implemented an EMA callback with simple functionality:

  • updates the EMA weights
  • swaps forward() during the validation loop

I think many more options should be added, for example save_weight, ema_step_period, etc.

If you find it helpful, please let me know and I will develop it further. If not, feel free to close this issue or leave a comment.

from copy import deepcopy

import torch
from pytorch_lightning import Callback


class EMACallback(Callback):
    def __init__(self, decay=0.995):
        self.decay = decay
        self.module_pair_list = []

    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        def forward_wrapper(module, org, ema):
            def forward(*args, **kwargs):
                return org(*args, **kwargs) if module.training else ema(*args, **kwargs)
            return forward

        modules = list(filter(lambda x: len(list(x[1].parameters())) > 0, pl_module.named_children()))

        for name, module in modules:
            ema_module = deepcopy(module)
            self.module_pair_list.append((ema_module, module))
            pl_module.add_module(f'EMA_{name}', ema_module)
            module.forward_bc = module.forward
            module.forward = forward_wrapper(module, module.forward_bc, ema_module.forward)

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        for ema_module, module in self.module_pair_list:
            self._update(ema_module, module, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def _update(self, ema_module, module, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(ema_module.state_dict().values(), module.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))
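The core trick in this callback is a closure that dispatches on module.training. A PyTorch-free sketch of the same idea follows; FakeModule is a hypothetical stand-in for an nn.Module that only has a training flag and a forward method.

```python
class FakeModule:
    """Hypothetical stand-in for nn.Module: a training flag and a forward."""

    def __init__(self, scale: float) -> None:
        self.training = True
        self.scale = scale

    def forward(self, x: float) -> float:
        return self.scale * x


def forward_wrapper(module, org, ema):
    """Same closure shape as in the callback: choose forward based on mode."""
    def forward(*args, **kwargs):
        return org(*args, **kwargs) if module.training else ema(*args, **kwargs)
    return forward


model = FakeModule(scale=2.0)
ema_model = FakeModule(scale=3.0)  # pretend these are the averaged weights
model.forward_bc = model.forward   # keep the original bound method
model.forward = forward_wrapper(model, model.forward_bc, ema_model.forward)

model.training = True
print(model.forward(1.0))  # 2.0: original weights while training
model.training = False
print(model.forward(1.0))  # 3.0: EMA weights in eval mode
```

Because the instance attribute shadows the class method, every call to forward in eval mode, including calls from your own *_step methods, is routed to the EMA copy; this is the behaviour @hal-314 comments on below.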

hankyul2 avatar Dec 11 '21 05:12 hankyul2

@hankyul2 I prefer the approach from timm, as it doesn't involve changing the forward method: users could not use the forward method in their xxx_step methods. So I would recommend your first snippet, but updating the EMA weights in on_train_batch_end instead of on_before_backward, to match timm's train.py.
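Why the hook matters: on_before_backward runs before the optimizer step, so the EMA is fed the weights from before that step, whereas on_train_batch_end runs after optimizer.step() and sees the freshly updated weights. The stdlib sketch below simulates plain gradient descent on a single scalar with a made-up quadratic loss, using both orderings.

```python
def simulate(steps: int, decay: float, update_before_step: bool) -> float:
    """Gradient descent on loss(w) = (w - 3)**2, tracking an EMA of w.

    update_before_step=True mimics updating in on_before_backward (the EMA sees
    the weights before optimizer.step()); False mimics on_train_batch_end.
    """
    lr = 0.1
    w = 0.0
    ema = w  # EMA starts as a copy of the initial weights, like deepcopy(model)
    for _ in range(steps):
        if update_before_step:
            ema = decay * ema + (1.0 - decay) * w
        w = w - lr * 2.0 * (w - 3.0)  # the "optimizer step"
        if not update_before_step:
            ema = decay * ema + (1.0 - decay) * w
    return ema


before = simulate(steps=20, decay=0.9, update_before_step=True)
after = simulate(steps=20, decay=0.9, update_before_step=False)
after_one_step_earlier = simulate(steps=19, decay=0.9, update_before_step=False)
print(before == after_one_step_earlier)  # True: the "before" variant lags one step
print(after > before)  # True here, since w is still increasing towards 3
```

The lag is small for a long run, but updating after the step is the ordering that includes the most recent weights.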

Finally, I managed to implement EMA as a callback using state_dicts instead of a whole-module copy as in ModelEMAV2. I can try to make a PR or, at least, post the code here if you or anyone else is interested (and I get permission for it). It works on 1 GPU and likely on multiple GPUs, though that isn't tested.

hal-314 avatar Dec 13 '21 17:12 hal-314

@hal-314 Cool 😎. I want to see it. Can you share your code?

hankyul2 avatar Dec 13 '21 21:12 hankyul2

@hankyul2 Here is the code. Be aware that you need the overrides package installed (pip install overrides). If you don't want it, comment out the import and the @overrides decorators. I only use it to make sure that I'm actually overriding the methods correctly.

Bits to be aware:

  • I didn't test the multi-GPU path, though it should work. If it doesn't, a couple of asserts will fail. If your setup is multi-GPU, could you check that it works?
  • You can override the get_state_dict method to filter out parameters that you don't want to include in EMA, for example those in metrics. See its docstring.
  • EMA weights are saved as "ema_state_dict" in the callback state, so you can retrieve them manually outside Lightning.
  • Tested with PyTorch 1.8 & 1.10 and PyTorch Lightning 1.5.4
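A sketch of the filtering mentioned in the second bullet: an override should exclude keys matching the prefixes, i.e. use not key.startswith(prefixes); str.startswith accepts a tuple of prefixes. The key names and the train_metric / valid_metric prefixes are hypothetical, and plain ints stand in for tensors.

```python
from typing import Dict, Tuple, TypeVar

V = TypeVar("V")


def filter_state_dict(state_dict: Dict[str, V], ignore_prefixes: Tuple[str, ...]) -> Dict[str, V]:
    """Drop entries whose key starts with any of ignore_prefixes."""
    return {k: v for k, v in state_dict.items() if not k.startswith(ignore_prefixes)}


# Hypothetical keys, shaped like a LightningModule.state_dict()
state = {
    "model.conv1.weight": 1,
    "model.fc.bias": 2,
    "train_metric.top@1.correct": 3,
    "valid_metric.top@1.correct": 4,
}
print(filter_state_dict(state, ("train_metric", "valid_metric")))
# {'model.conv1.weight': 1, 'model.fc.bias': 2}
```

Inside the callback, a subclass could then return filter_state_dict(pl_module.state_dict(), ...) from its own get_state_dict staticmethod.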

I hope it's useful.

from copy import deepcopy
from typing import Optional, Union, Dict, Any

import pytorch_lightning as pl
import torch
from overrides import overrides
from pytorch_lightning.utilities import rank_zero_only


class EMA(pl.Callback):
    """Implements EMA (exponential moving average) to any kind of model.
    EMA weights will be used during validation and stored separately from original model weights.

    How to use EMA:
        - Sometimes the last EMA checkpoint isn't the best, as EMA weight metrics can show long oscillations in time. See
          https://github.com/rwightman/pytorch-image-models/issues/102
        - Batch Norm layers, and likely any other type of norm layer, don't need to be updated at the end. See
          discussions in: https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088 and
          https://github.com/rwightman/pytorch-image-models/issues/224
        - For object detection, SWA usually works better. See https://github.com/timgaripov/swa/issues/16

    Implementation detail:
        - See EMA in Pytorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914
        - With multiple GPUs, we broadcast the EMA weights and the original weights in order to hold only 1 copy in
          memory. This is especially relevant when storing EMA weights on CPU + pinned memory, as pinned memory is a
          limited resource. In addition, we want to avoid duplicated operations in ranks != 0 to reduce jitter and
          improve performance.
    """
    def __init__(self, decay: float = 0.9999, ema_device: Optional[Union[torch.device, str]] = None, pin_memory=True):
        super().__init__()
        self.decay = decay
        self.ema_device: str = f"{ema_device}" if ema_device else None  # perform ema on different device from the model
        self.ema_pin_memory = pin_memory if torch.cuda.is_available() else False  # Only works if CUDA is available
        self.ema_state_dict: Dict[str, torch.Tensor] = {}
        self.original_state_dict = {}
        self._ema_state_dict_ready = False

    @staticmethod
    def get_state_dict(pl_module: pl.LightningModule):
        """Returns the state dictionary from pl_module. Override it if you want to filter out some parameters and/or
        buffers. For example, if pl_module has metrics, you don't want to return their parameters.

        code:
            # Only consider modules that can be seen by optimizers. Lightning modules can have other nn.Modules
            # attached, like losses, metrics, etc.
            patterns_to_ignore = ("metrics1", "metrics2")
            return dict(filter(lambda i: not i[0].startswith(patterns_to_ignore), pl_module.state_dict().items()))
        """
        return pl_module.state_dict()
        
    @overrides
    def on_train_start(self, trainer: "pl.Trainer", pl_module: pl.LightningModule) -> None:
        # Only keep track of EMA weights in rank zero.
        if not self._ema_state_dict_ready and pl_module.global_rank == 0:
            self.ema_state_dict = deepcopy(self.get_state_dict(pl_module))
            if self.ema_device:
                self.ema_state_dict = {k: tensor.to(device=self.ema_device) for k, tensor in self.ema_state_dict.items()}

            if self.ema_device == "cpu" and self.ema_pin_memory:
                self.ema_state_dict = {k: tensor.pin_memory() for k, tensor in self.ema_state_dict.items()}

        self._ema_state_dict_ready = True

    @rank_zero_only
    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: pl.LightningModule, *args, **kwargs) -> None:
        # Update EMA weights
        with torch.no_grad():
            for key, value in self.get_state_dict(pl_module).items():
                ema_value = self.ema_state_dict[key]
                # Move the current value to the EMA device first (e.g. "cpu"); mixing devices would fail.
                value = value.to(ema_value.device)
                ema_value.copy_(self.decay * ema_value + (1. - self.decay) * value, non_blocking=True)

    @overrides
    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no ema weights has been loaded from a checkpoint.

        self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
        pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
        assert self.ema_state_dict.keys() == self.original_state_dict.keys(), \
            f"There are some keys missing in the ema static dictionary broadcasted. " \
            f"They are: {self.original_state_dict.keys() - self.ema_state_dict.keys()}"
        pl_module.load_state_dict(self.ema_state_dict, strict=False)

        if pl_module.global_rank > 0:
            # Remove ema state dict from the memory. In rank 0, it could be in ram pinned memory.
            self.ema_state_dict = {}

    @overrides
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no ema weights has been loaded from a checkpoint.

        # Replace EMA weights with training weights
        pl_module.load_state_dict(self.original_state_dict, strict=False)

    @overrides
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> dict:
        return {"ema_state_dict": self.ema_state_dict, "_ema_state_dict_ready": self._ema_state_dict_ready}

    @overrides
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
    ) -> None:
        self._ema_state_dict_ready = callback_state["_ema_state_dict_ready"]
        self.ema_state_dict = callback_state["ema_state_dict"]
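The validation-time swap in on_validation_start / on_validation_end boils down to: snapshot the training weights, load the EMA weights, and restore the snapshot afterwards. This is a dependency-free sketch of that flow; Model and EMASwapper are hypothetical stand-ins that use dicts of floats instead of tensors.

```python
from copy import deepcopy
from typing import Dict


class Model:
    """Hypothetical stand-in exposing state_dict() / load_state_dict()."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)

    def state_dict(self) -> Dict[str, float]:
        return self.weights  # like PyTorch, shares storage with the model

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.weights.update(state)  # in-place copy, like param.copy_()


class EMASwapper:
    def __init__(self, decay: float) -> None:
        self.decay = decay
        self.ema_state: Dict[str, float] = {}
        self.original_state: Dict[str, float] = {}

    def on_train_start(self, model: Model) -> None:
        self.ema_state = deepcopy(model.state_dict())

    def on_train_batch_end(self, model: Model) -> None:
        for key, value in model.state_dict().items():
            self.ema_state[key] = self.decay * self.ema_state[key] + (1.0 - self.decay) * value

    def on_validation_start(self, model: Model) -> None:
        # Snapshot the training weights, then evaluate with the EMA weights.
        self.original_state = deepcopy(model.state_dict())
        model.load_state_dict(self.ema_state)

    def on_validation_end(self, model: Model) -> None:
        # Restore the training weights so optimisation continues unaffected.
        model.load_state_dict(self.original_state)
```

The deepcopy in on_validation_start matters: in PyTorch, state_dict() returns tensors that share storage with the parameters, so without a copy, loading the EMA weights would also overwrite the snapshot.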

@justusschock @mathemusician Do you think that Lightning should include an EMA callback? Maybe it could go into Bolts.

hal-314 avatar Dec 14 '21 16:12 hal-314

@hal-314 Wow... I think it is good (but I am not a maintainer).

hankyul2 avatar Dec 14 '21 17:12 hankyul2

Personally, I think we should include this. However, I'm not sure where it belongs (currently I'd say not in Lightning core, but in either Flash or Bolts).

cc @tchaton @ethanwharris for opinions on this

justusschock avatar Dec 15 '21 13:12 justusschock

@hal-314 thanks for your implementation! I agree that EMA would be very useful to have in Lightning.

I tested your implementation with default parameters (so ema_device=None, etc.) and it seems to work well on a single GPU. With multiple GPUs, however, the assertion in on_validation_start fails on GPUs other than 0. It appears that the ema_state_dict is not broadcast successfully to all devices (it is an empty dict).

flukeskywalker avatar Dec 21 '21 01:12 flukeskywalker

@flukeskywalker Glad to see that it's useful for you :)

If you fix the multi-GPU code, would you mind sharing the fix so others can use it?

hal-314 avatar Dec 21 '21 07:12 hal-314

@hal-314 Thank you for sharing your code. I tested it with 2 GPUs in DDP mode. When pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0) is executed, an OOM error occurs. Could you explain how validation steps work in DDP mode, or point me to any documentation I can reference?

The whole error log is below.

Traceback (most recent call last):
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
    return self._run_train()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 146, in run
    self.on_advance_end()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 242, in on_advance_end
    self._run_validation()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 337, in _run_validation
    self.val_loop.run()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 140, in run
    self.on_run_start(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 95, in on_run_start
    self._on_evaluation_start()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 179, in _on_evaluation_start
    self.trainer.call_hook("on_validation_start", *args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1490, in call_hook
    callback_fx(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 216, in on_validation_start
    callback.on_validation_start(self, self.lightning_module)
  File "/home/hankyul/private/SuperConvergence/src/ema.py", line 134, in on_validation_start
    pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 411, in broadcast
    broadcast_object_list(obj, src, group=_group.WORLD)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1840, in broadcast_object_list
    object_list[i] = _tensor_to_object(obj_view, obj_size)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1532, in _tensor_to_object
    return _unpickler(io.BytesIO(buf)).load()
  File "/opt/conda/lib/python3.7/site-packages/torch/storage.py", line 161, in _load_from_bytes
    return torch.load(io.BytesIO(b))
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 608, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 787, in _legacy_load
    result = unpickler.load()
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 743, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
    result = fn(storage, location)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 155, in _cuda_deserialize
    return storage_type(obj.size())
  File "/opt/conda/lib/python3.7/site-packages/torch/cuda/__init__.py", line 606, in _lazy_new
    return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

hankyul2 avatar Dec 28 '21 07:12 hankyul2

@hankyul2 Sorry, but I don't have experience with multi-GPU in Lightning. From the stack trace, it seems that the OOM occurs when broadcasting the state. Here is where I found the broadcast docs in Lightning.

hal-314 avatar Dec 28 '21 12:12 hal-314

@hankyul2 I tried @hal-314's code on my machine with two 1080 Ti GPUs, and it worked well. Maybe the out-of-memory error is literally just that...

My experiment configuration is as follows:

Model: BertModelForSequenceClassification
max_length: 50
padding_to_max_length: True
batch_size: 4(per device)

When I trained without EMA, GPU memory usage was 3311 MB on each GPU. With EMA turned on, GPU 0 used 3851 MB and GPU 1 used 3311 MB.
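A rough back-of-the-envelope check of that difference: one extra fp32 copy of the weights costs num_parameters × 4 bytes. Assuming a BERT-base-sized model of roughly 110M parameters (an assumption; the exact count depends on the checkpoint and head), that is about 420 MiB, in the same range as the ~540 MB extra reported on GPU 0, where only rank 0 keeps the EMA copy. This ignores buffers, temporary copies and allocator caching.

```python
def copy_size_mib(num_parameters: int, bytes_per_element: int = 4) -> float:
    """Memory needed for one extra copy of the weights, in MiB."""
    return num_parameters * bytes_per_element / 2**20


print(round(copy_size_mib(110_000_000)))  # 420 (fp32, ~110M parameters assumed)
```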

sevenights avatar Jan 06 '22 03:01 sevenights

@hal-314 @hankyul2 Sorry, I made a mistake in my last comment. I found that the broadcast in on_validation_start didn't work as expected, so I changed it as follows:

@overrides
def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    if not self._ema_state_dict_ready:
        return  # Skip Lightning sanity validation check if no ema weights has been loaded from a checkpoint.

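    self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
    # broadcast() returns the received object on ranks != 0 instead of updating
    # its argument in place, so the result has to be assigned back.
    ema_state_dict = pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
    self.ema_state_dict = ema_state_dict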

This is because broadcast does not modify its argument in place (in dist.py, when self.rank != 0 is True). The call stack I found is as follows:

pl_module.training_type_plugin.broadcast(self.ema_state_dict, 0)
-> ddp_spawn.broadcast(obj, src)
-> LightningDistributed.broadcast(obj, group)

and the relevant source code is as follows:

# pytorch_lightning/plugins/training_type/ddp_spawn.py
def broadcast(self, obj: object, src: int = 0) -> object:
    if not distributed_available():
        return obj
    return self.dist.broadcast(obj)
# pytorch_lightning/distributed/dist.py
class LightningDistributed:
    def __init__(self, rank=None, device=None):
        self.rank = rank
        self.device = device

    def broadcast(self, obj: Any, group=_group.WORLD):
        # always wrap into a list so it can be broadcasted.
        obj = [obj]

        if self.rank != 0:
            obj = [None] * len(obj)

        broadcast_object_list(obj, 0, group=group or _group.WORLD)

        return obj[0]
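To see why the return value must be kept, here is a single-process simulation of the pattern in LightningDistributed.broadcast above. fake_broadcast_object_list is a hypothetical stand-in for torch.distributed.broadcast_object_list, which fills the list slots in place on receiving ranks; the caller's original object on a rank != 0 is never touched, and only the returned obj[0] holds rank 0's data.

```python
from typing import Any, Dict, List

# Hypothetical data: only rank 0 holds the real EMA weights.
RANK0_EMA = {"layer.weight": [1.0, 2.0]}


def fake_broadcast_object_list(object_list: List[Any], src: int, rank: int) -> None:
    """Stand-in for broadcast_object_list: on receiving ranks, the list slots
    are overwritten in place with rank `src`'s objects."""
    if rank != src:
        for i in range(len(object_list)):
            object_list[i] = RANK0_EMA


def broadcast(obj: Any, rank: int) -> Any:
    """Same shape as LightningDistributed.broadcast from dist.py above."""
    obj = [obj]  # wrap in a new list so its slot can be replaced
    if rank != 0:
        obj = [None] * len(obj)
    fake_broadcast_object_list(obj, src=0, rank=rank)
    return obj[0]


ema_state_dict: Dict[str, Any] = {}  # what a rank != 0 holds before validation
broadcast(ema_state_dict, rank=1)
print(ema_state_dict)  # {}: the received weights were discarded
ema_state_dict = broadcast(ema_state_dict, rank=1)
print(ema_state_dict)  # {'layer.weight': [1.0, 2.0]}
```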

The GPU usage is shown below:

without ema:

|    0   N/A  N/A     15733      C   ...da3/envs/pl149/bin/python     4645MiB |
|    1   N/A  N/A     15750      C   ...da3/envs/pl149/bin/python     4645MiB |

with ema:

|    0   N/A  N/A     12561      C   ...da3/envs/pl149/bin/python     5509MiB |
|    0   N/A  N/A     12577      C   ...da3/envs/pl149/bin/python     1227MiB |
|    1   N/A  N/A     12577      C   ...da3/envs/pl149/bin/python     5063MiB |

environment: pytorch==1.8.0 pytorch-lightning==1.4.9 gpus: 1080ti * 2

sevenights avatar Jan 07 '22 02:01 sevenights

I can confirm that @sevenights's fix works for me too with pytorch-lightning==1.5.8

flukeskywalker avatar Jan 08 '22 01:01 flukeskywalker

I would love to see an implementation of EMA. I just migrated from ignite to lightning and lacking an EMA callback really holds me back.

yoyolicoris avatar Jan 12 '22 03:01 yoyolicoris

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale[bot] avatar Feb 18 '22 00:02 stale[bot]

@hankyul2 It is very likely that gradients were computed during validation, which caused the out-of-memory error. You can try torch.set_grad_enabled(False) in your validation process.

AbyssGaze avatar May 23 '22 09:05 AbyssGaze

@AbyssGaze Thank you for your suggestion.

hankyul2 avatar May 23 '22 12:05 hankyul2

@Borda Hi, any plan on landing this feature?

Ir1d avatar Jun 11 '22 06:06 Ir1d

Picking this back up, as a related issue from @lucidrains requires EMA.

I think we should include this somewhere ASAP. Bolts would be the easiest landing place for such a callback. Any disagreements @Borda @justusschock? If so I can make an issue to get it into Bolts.

SeanNaren avatar Jun 23 '22 11:06 SeanNaren

@SeanNaren Be aware that #5542 prevents EMA weights from being loaded automatically when only validating or testing (trainer.validate / trainer.test).

In those situations, PL doesn't call callbacks.on_load_checkpoint.

To fix it, you will need to use a custom trainer, so you can comment out this line https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/trainer/trainer.py#L1068
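Until that is fixed, one possible workaround (a sketch, not tested against Lightning) is to pull the EMA weights out of the checkpoint yourself and load them before calling trainer.validate / trainer.test. The sketch assumes that the loaded checkpoint stores callback states in a dict under the "callbacks" key and that the EMA callback's entry is the dict returned by on_save_checkpoint in the callback above; find_ema_state_dict is a hypothetical helper name.

```python
from typing import Any, Dict, Optional


def find_ema_state_dict(checkpoint: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return the first callback state's 'ema_state_dict', if any.

    Assumes checkpoint["callbacks"] maps a callback key to that callback's
    saved state (the dict returned by EMA.on_save_checkpoint).
    """
    for state in checkpoint.get("callbacks", {}).values():
        if isinstance(state, dict) and "ema_state_dict" in state:
            return state["ema_state_dict"]
    return None


# Synthetic checkpoint with the assumed layout (real values would be tensors).
ckpt = {
    "state_dict": {"model.w": 1.0},
    "callbacks": {"EMA": {"ema_state_dict": {"model.w": 0.5}, "_ema_state_dict_ready": True}},
}
print(find_ema_state_dict(ckpt))  # {'model.w': 0.5}
```

With a real checkpoint, the result could be passed to pl_module.load_state_dict(..., strict=False), the same call the callback uses in on_validation_start.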

hal-314 avatar Jun 23 '22 19:06 hal-314

Thanks for the heads up! So for fit it would be fine, and it's just an issue for validate/test? We should see what level of effort is required to fix this.

SeanNaren avatar Jun 23 '22 19:06 SeanNaren

Validation and test, yes. Fine-tuning, I don't think so, although I didn't check.

hal-314 avatar Jun 24 '22 15:06 hal-314

Picking this back up as @lucidrains related issue requires EMA.

I think we should include this somewhere ASAP. Bolts would be the easiest landing place for such a callback. Any disagreements @Borda @justusschock? If so I can make an issue to get it into Bolts.

yeah, so i'm trying to figure out whether this issue is a blocker to using lightning for a project

the project involves a model containing multiple subnetworks. during training, each subnetwork has an EMA version that is updated every N training steps (say 10)

at validation time, i need to be able to call the EMA versions of all the subnetworks sequentially. this does not have to be distributed

will that be doable given this open issue?

lucidrains avatar Jun 30 '22 16:06 lucidrains

yes, absolutely. Because the EMA weights are so closely tied to the actual model/subnetworks, I think it would be best to start by adding the logic directly into the pl.LightningModule. This open issue addresses a more general approach of keeping an EMA version of the entire model, but in your case this generality isn't necessary.
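A framework-free sketch of what that could look like for the use case above: one EMA per subnetwork, updated every update_every training steps, with all EMA copies available at validation time. The class name MultiEMA, the subnetwork names and update_every=10 are hypothetical; in a LightningModule the step call would naturally sit in on_train_batch_end, and the float dicts would be real state dicts.

```python
from copy import deepcopy
from typing import Dict

State = Dict[str, float]


class MultiEMA:
    """Keeps an EMA copy per subnetwork, updated every `update_every` steps."""

    def __init__(self, subnets: Dict[str, State], decay: float, update_every: int) -> None:
        self.decay = decay
        self.update_every = update_every
        self.ema: Dict[str, State] = {name: deepcopy(s) for name, s in subnets.items()}
        self.step_count = 0

    def step(self, subnets: Dict[str, State]) -> None:
        self.step_count += 1
        if self.step_count % self.update_every != 0:
            return  # only update on every `update_every`-th training step
        for name, state in subnets.items():
            ema_state = self.ema[name]
            for key, value in state.items():
                ema_state[key] = self.decay * ema_state[key] + (1.0 - self.decay) * value

    def ema_states(self) -> Dict[str, State]:
        """EMA weights of every subnetwork, e.g. to run them in turn at validation."""
        return self.ema


nets = {"net_a": {"w": 0.0}, "net_b": {"w": 0.0}}
tracker = MultiEMA(nets, decay=0.5, update_every=10)
nets["net_a"]["w"], nets["net_b"]["w"] = 4.0, 8.0  # pretend training moved the weights
for _ in range(10):
    tracker.step(nets)
print(tracker.ema_states())  # {'net_a': {'w': 2.0}, 'net_b': {'w': 4.0}}
```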

SeanNaren avatar Jul 01 '22 12:07 SeanNaren