
[Bug] Standardize transform breaks DistributedDataParallel

Open samuelstanton opened this issue 1 year ago • 3 comments

🐛 Bug

The Standardize transform was recently modified to better support checkpoint loading with torch.load. The fix adds a boolean tensor _is_trained to the module's state dict. Unfortunately, this breaks DistributedDataParallel (DDP): the communication backends (gloo on CPU and nccl on CUDA) effectively only accept float tensors in the model state dict when broadcasting module states.
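
For reference, a quick way to see the offending entry is to dump the dtypes of the Standardize state dict (a minimal sketch; exact buffer names may vary by BoTorch version):

from botorch.models.transforms.outcome import Standardize

transform = Standardize(m=10)
# per the description above, _is_trained should show up with a non-float
# dtype (torch.bool) alongside the float buffers
for name, tensor in transform.state_dict().items():
    print(name, tensor.dtype)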

I'm not sure how important DDP support is to the BoTorch devs, but I thought I would flag the issue in case someone else runs into it. For myself, I'm going to have to figure out a workaround in any case; one idea is sketched below.
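
One possible workaround (untested here, and it relies on a private DDP hook, so treat it as a sketch) is to tell DDP to skip the non-float buffer when syncing module states. The buffer name assumes Standardize lives at model.transform, as in the repro script further down:

from torch.nn.parallel import DistributedDataParallel as DDP

model = Model().to(rank)  # `Model` and `rank` as in the repro below
# Private/experimental API: exclude the boolean buffer from DDP's
# parameter/buffer broadcast so _broadcast_coalesced never sees it.
DDP._set_params_and_buffers_to_ignore_for_model(model, ["transform._is_trained"])
ddp_model = DDP(model, device_ids=[rank])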

To reproduce

Code snippet to reproduce

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

from botorch.models.transforms.outcome import Standardize


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.transform = Standardize(m=10)
        # self.transform = None
        
    def forward(self, x):
        x = self.linear(x)
        return x if self.transform is None else self.transform(x)


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    # model = nn.Linear(10, 10).to(rank)
    model = Model().to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()

def main():
    world_size = 2
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__ == "__main__":
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()
    print("success")

Stack trace/error message

Traceback (most recent call last):
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 56, in <module>
    main()
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 45, in main
    mp.spawn(example,
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 30, in example
    ddp_model = DDP(model, device_ids=[rank])
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 676, in __init__
    _sync_module_states(
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/distributed/utils.py", line 142, in _sync_module_states
    _sync_params_and_buffers(
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/distributed/utils.py", line 160, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: Invalid scalar type

Expected Behavior

If you remove Standardize from the model, this script runs without errors.

System information

Please complete the following information:

  • BoTorch Version: 0.9.4
  • GPyTorch Version: 1.11
  • PyTorch Version: 2.0.1+cu117
  • Computer OS: Red Hat Enterprise Linux 8.6

Additional context

This example is a bit contrived; however, I encountered this issue while trying to do multi-GPU training of a model that uses Standardize with PyTorch Lightning. Note that DDP is strongly recommended by the Lightning team.

To see an example of a similar problem involving non-float dtypes and DDP, see this issue.

In addition to the dtype being a problem, using register_buffer appears to break DDP as well.
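
If the root cause really is the non-float buffer broadcast, the failure should be reproducible without BoTorch at all. An untested sketch of a minimal module that mimics Standardize's flag:

import torch
import torch.nn as nn


class BoolBufferModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        # mimics Standardize's _is_trained flag: a non-float buffer that
        # DDP will try to broadcast at construction time
        self.register_buffer("_is_trained", torch.tensor(False))

    def forward(self, x):
        return self.linear(x)

Swapping Model for BoolBufferModel in the repro above would confirm whether the error is independent of BoTorch.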

samuelstanton · Nov 22 '23 20:11