
[Bug] Standardize transform breaks DistributedDataParallel

Open samuelstanton opened this issue 2 years ago • 3 comments

🐛 Bug

Recently the Standardize transform was modified to better support checkpoint loading with torch.load. The proposed solution was to add a boolean tensor _is_trained to the module's state dict. Unfortunately, this breaks DistributedDataParallel (DDP): the communication backends (gloo on CPU, nccl on CUDA) really only want float tensors in the model state dict.

I'm not sure how important DDP support is to the BoTorch devs, but I thought I would flag the issue in case someone else runs into it. For myself, I'm going to have to figure out a workaround in any case.
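
For reference, a quick way to see the offending entry (just a sketch on my end; the buffer names are whatever Standardize registers in 0.9.4):

from botorch.models.transforms.outcome import Standardize

# Print the dtypes Standardize puts in its state dict. On 0.9.4 the
# `_is_trained` entry should show up as torch.bool, which is what DDP's
# initial broadcast of module states trips over.
tf = Standardize(m=10)
for name, tensor in tf.state_dict().items():
    print(name, tensor.dtype)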

To reproduce

**Code snippet to reproduce**

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

from botorch.models.transforms.outcome import Standardize


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.transform = Standardize(m=10)
        # self.transform = None
        
    def forward(self, x):
        x = self.linear(x)
        if self.transform is None:
            return x
        # Standardize.forward returns (Y_tf, Yvar_tf); keep only the outcomes
        return self.transform(x)[0]


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    # model = nn.Linear(10, 10).to(rank)
    model = Model().to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()

def main():
    world_size = 2
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()
    print("success")

**Stack trace/error message**

Traceback (most recent call last):
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 56, in <module>
    main()
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 45, in main
    mp.spawn(example,
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 30, in example
    ddp_model = DDP(model, device_ids=[rank])
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 676, in __init__
    _sync_module_states(
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/distributed/utils.py", line 142, in _sync_module_states
    _sync_params_and_buffers(
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/distributed/utils.py", line 160, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: Invalid scalar type

Expected Behavior

If you remove Standardize from the model, this script runs without errors.

System information

Please complete the following information:

  • BoTorch Version: 0.9.4
  • GPyTorch Version: 1.11
  • PyTorch Version: 2.0.1+cu117
  • Computer OS: Red Hat Enterprise Linux 8.6

Additional context

This example is a bit contrived; however, I encountered this issue while trying to do multi-GPU training of a model that uses Standardize with PyTorch Lightning. Note that DDP is strongly recommended by the Lightning team.

To see an example of a similar problem involving non-float dtypes and DDP, see this issue.

In addition to the dtype being a problem, using register_buffer appears to break DDP as well.

samuelstanton avatar Nov 22 '23 20:11 samuelstanton

This is a rough solution:

import warnings
from typing import Any, Mapping, Optional

import torch
from torch import Tensor, nn

from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.transforms import normalize_indices


class DistributableStandardize(Standardize):
    r"""Standardize outcomes (zero mean, unit variance).

    This module is stateful: If in train mode, calling forward updates the
    module state (i.e. the mean/std normalizing constants). If in eval mode,
    calling forward simply applies the standardization using the current module
    state.
    """

    def __init__(
        self,
        m: int,
        outputs: Optional[list[int]] = None,
        batch_shape: torch.Size = torch.Size(),  # noqa: B008
        min_stdv: float = 1e-8,
    ) -> None:
        r"""Standardize outcomes (zero mean, unit variance).

        Args:
            m: The output dimension.
            outputs: Which of the outputs to standardize. If omitted, all
                outputs will be standardized.
            batch_shape: The batch_shape of the training targets.
            min_stdv: The minimum standard deviation for which to perform
                standardization (if lower, only de-mean the data).
        """
        OutcomeTransform.__init__(self)
        self.register_parameter(
            "means", nn.Parameter(torch.zeros(*batch_shape, 1, m), requires_grad=False)
        )
        self.register_parameter(
            "stdvs", nn.Parameter(torch.ones(*batch_shape, 1, m), requires_grad=False)
        )
        self.register_parameter(
            "_stdvs_sq", nn.Parameter(torch.ones(*batch_shape, 1, m), requires_grad=False)
        )
        self.register_parameter(
            "_is_trained", nn.Parameter(torch.tensor(0.0), requires_grad=False)
        )
        self._outputs = normalize_indices(outputs, d=m)
        self._m = m
        self._batch_shape = batch_shape
        self._min_stdv = min_stdv

    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = True
    ) -> None:
        r"""Custom logic for loading the state dict."""
        if "_is_trained" not in state_dict:
            warnings.warn(
                "Key '_is_trained' not found in state_dict. Setting to True. "
                "In a future release, this will result in an error.",
                DeprecationWarning,
            )
            state_dict = {**state_dict, "_is_trained": torch.tensor(1.0)}
        super().load_state_dict(state_dict, strict=strict)

    def forward(
        self, Y: Tensor, Yvar: Optional[Tensor] = None
    ) -> tuple[Tensor, Optional[Tensor]]:
        r"""Standardize outcomes.

        If the module is in train mode, this updates the module state (i.e. the
        mean/std normalizing constants). If the module is in eval mode, simply
        applies the normalization using the module state.

        Args:
            Y: A `batch_shape x n x m`-dim tensor of training targets.
            Yvar: A `batch_shape x n x m`-dim tensor of observation noises
                associated with the training targets (if applicable).

        Returns:
            A two-tuple with the transformed outcomes:

            - The transformed outcome observations.
            - The transformed observation noise (if applicable).
        """
        if self.training:
            if Y.shape[:-2] != self._batch_shape:
                raise RuntimeError(
                    f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
                    "the `batch_shape` argument to `Standardize`, but got "
                    f"Y.shape[:-2]={Y.shape[:-2]}."
                )
            if Y.size(-1) != self._m:
                raise RuntimeError(
                    f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
                    f"{self._m}."
                )
            stdvs = Y.std(dim=-2, keepdim=True)
            stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0))
            means = Y.mean(dim=-2, keepdim=True)
            if self._outputs is not None:
                unused = [i for i in range(self._m) if i not in self._outputs]
                means[..., unused] = 0.0
                stdvs[..., unused] = 1.0
            self.means.data = means
            self.stdvs.data = stdvs
            self._stdvs_sq.data = stdvs.pow(2)
            self._is_trained.data = torch.tensor(1.0)

        Y_tf = (Y - self.means) / self.stdvs
        Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
        return Y_tf, Yvar_tf

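As a quick sanity check against the repro above (a sketch, assuming the Model class and the DistributableStandardize definition in this thread):

# Hypothetical usage sketch: drop the patched transform into the Model class
# from the repro above and confirm nothing non-float is left in its state dict.
model = Model()
model.transform = DistributableStandardize(m=10)
print({name: t.dtype for name, t in model.transform.state_dict().items()})
# Every dtype should now be torch.float32, so DDP's initial
# _sync_module_states broadcast should no longer hit "Invalid scalar type".
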
samuelstanton avatar Nov 22 '23 21:11 samuelstanton

Hi @samuelstanton. DDP not supporting non-float tensors and buffers seems like a strange design to me. We could replace the bool tensor with a float tensor, but I doubt that this is a one-off usage in BoTorch. On the other hand, I know that register_buffer is commonly used throughout BoTorch and I'd be hesitant to replace it without a strong motivator.

Overall, these seem like issues that should be resolved on the side of DDP. It seems strange to me that it wouldn't support such core parts of PyTorch.
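For concreteness, a minimal sketch of the "replace the bool tensor with a float tensor" option (an illustration, not an actual BoTorch patch): keep the buffer, but store the flag as a float so gloo/nccl can broadcast it; 0.0 is falsy and 1.0 is truthy, so existing truthiness checks on _is_trained keep working. This only addresses the dtype, not the separate register_buffer concern raised above.

import torch
from botorch.models.transforms.outcome import Standardize

class FloatFlagStandardize(Standardize):
    # Hypothetical subclass: replace the bool `_is_trained` buffer registered
    # by Standardize with a float tensor carrying the same value.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_trained = self._is_trained.to(torch.float32)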

saitcakmak avatar Nov 22 '23 21:11 saitcakmak

Seems strange to me too. It's up to the BoTorch core team how you want to deal with this; it just took me a while to figure out, so I thought I'd share. Standardize is a simple enough utility that I could always just migrate to something else. However, I could imagine this coming up if, e.g., someone wanted to scale up SVGP training across multiple devices. Maybe we could touch base with whoever does distributed torch development and see if there's a solution on the roadmap.

samuelstanton avatar Nov 22 '23 22:11 samuelstanton