[Bug] Standardize transform breaks DistributedDataParallel
🐛 Bug
Recently the `Standardize` transform was modified to better support checkpoint loading with `torch.load`. The proposed solution was to add a boolean tensor `_is_trained` to the module's state dict. Unfortunately, this breaks `DistributedDataParallel`: the communication backends (gloo on CPU, nccl on CUDA) really only want float tensors in the model state dict.

I'm not sure how important DDP support is to the BoTorch devs, but I thought I would flag the issue in case someone else runs into it. For myself, I'm going to have to figure out a workaround in any case.
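For context, the offending entry is easy to see by inspecting the transform's state dict directly (a minimal sketch against BoTorch 0.9.4; per the change described above, `_is_trained` should be the only non-float entry):

```python
from botorch.models.transforms.outcome import Standardize

transform = Standardize(m=10)
for name, tensor in transform.state_dict().items():
    # _is_trained shows up here with a boolean dtype, which is what
    # DDP's initial module-state sync chokes on (see stack trace below).
    print(name, tensor.dtype)
```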
To reproduce
**Code snippet to reproduce**
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

from botorch.models.transforms.outcome import Standardize


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.transform = Standardize(m=10)
        # self.transform = None

    def forward(self, x):
        x = self.linear(x)
        return x if self.transform is None else self.transform(x)


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    # model = nn.Linear(10, 10).to(rank)
    model = Model().to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()
    print("success")
**Stack trace/error message**
Traceback (most recent call last):
File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 56, in <module>
main()
File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 45, in main
mp.spawn(example,
File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
while not context.join():
File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 30, in example
ddp_model = DDP(model, device_ids=[rank])
File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 676, in __init__
_sync_module_states(
File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/distributed/utils.py", line 142, in _sync_module_states
_sync_params_and_buffers(
File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/distributed/utils.py", line 160, in _sync_params_and_buffers
dist._broadcast_coalesced(
RuntimeError: Invalid scalar type
Expected Behavior
If you remove `Standardize` from the model, this script runs without errors.
System information
Please complete the following information:
- BoTorch Version: 0.9.4
- GPyTorch Version: 1.11
- PyTorch Version: 2.0.1+cu117
- Computer OS: Red Hat Enterprise Linux 8.6
Additional context
This example is a bit contrived; however, I encountered this issue while trying to do multi-GPU training of a model that uses `Standardize` with PyTorch Lightning. Note that DDP is strongly recommended by the Lightning team.
For an example of a similar problem involving non-float dtypes and DDP, see this issue.
In addition to the dtype being a problem, using `register_buffer` appears to break DDP as well.
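One possible workaround (a sketch only, relying on a private PyTorch helper that may change between releases) is to tell DDP to skip the non-float buffer when it syncs module states:

```python
from torch.nn.parallel import DistributedDataParallel as DDP

model = Model().to(rank)  # Model and rank as in the snippet above

# "transform._is_trained" is the fully qualified buffer name for the Model
# above; adjust it to match your own module hierarchy.
DDP._set_params_and_buffers_to_ignore_for_model(model, ["transform._is_trained"])

ddp_model = DDP(model, device_ids=[rank])
```

This sidesteps broadcasting the boolean buffer rather than fixing the underlying dtype issue, so each rank keeps its own local copy of `_is_trained`.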