[Bug] Standardize transform breaks DistributedDataParallel
🐛 Bug
Recently the Standardize transform was modified to better support checkpoint loading with torch.load. The proposed solution was to add a boolean tensor _is_trained to the module's state dict. Unfortunately, this breaks DistributedDataParallel (DDP): the communication backends (gloo on CPU, nccl on CUDA) really only want float tensors in the model state dict.
I'm not sure how important supporting DDP is to the BoTorch devs, but I thought I would flag the issue in case someone else runs into it. For myself, I'm going to have to figure out a workaround in any case.
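For context, the offending entry is easy to spot by inspecting the transform's state dict directly. This is a small illustrative snippet; the printed dtypes are what I would expect on BoTorch 0.9.4, not verbatim output:

from botorch.models.transforms.outcome import Standardize

tf = Standardize(m=10)
# every state dict entry except `_is_trained` should be a float tensor
print({name: t.dtype for name, t in tf.state_dict().items()})
# expected, roughly: {'means': torch.float32, 'stdvs': torch.float32,
#                     '_stdvs_sq': torch.float32, '_is_trained': torch.bool}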
To reproduce
**Code snippet to reproduce**
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

from botorch.models.transforms.outcome import Standardize


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.transform = Standardize(m=10)
        # self.transform = None

    def forward(self, x):
        x = self.linear(x)
        return x if self.transform is None else self.transform(x)


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    # model = nn.Linear(10, 10).to(rank)
    model = Model().to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()
    print("success")
**Stack trace/error message**
Traceback (most recent call last):
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 56, in <module>
    main()
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 45, in main
    mp.spawn(example,
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/gpfs/scratchfs01/site/u/stantos5/code/remote/prescient/cortex/scripts/botorch_ddp_bug.py", line 30, in example
    ddp_model = DDP(model, device_ids=[rank])
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 676, in __init__
    _sync_module_states(
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/distributed/utils.py", line 142, in _sync_module_states
    _sync_params_and_buffers(
  File "/home/stantos5/scratch/miniconda/envs/cortex-env/lib/python3.10/site-packages/torch/distributed/utils.py", line 160, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: Invalid scalar type
Expected Behavior
If you remove Standardize from the model this script runs without errors.
System information
Please complete the following information:
- BoTorch Version: 0.9.4
- GPyTorch Version: 1.11
- PyTorch Version: 2.0.1+cu117
- Computer OS: Red Hat Enterprise Linux 8.6
Additional context
This example is a bit contrived; however, I encountered this issue while trying to do multi-GPU training of a model using Standardize with PyTorch Lightning. Note that DDP is strongly recommended by the Lightning team.
To see an example of a similar problem involving non-float dtypes and DDP, see this issue.
In addition to the dtype being a problem, using register_buffer appears to break DDP as well.
Here is a rough workaround:
from typing import Any, Mapping, Optional
import warnings

import torch
import torch.nn as nn
from torch import Tensor

from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.transforms import normalize_indices


class DistributableStandardize(Standardize):
    r"""Standardize outcomes (zero mean, unit variance).

    This module is stateful: If in train mode, calling forward updates the
    module state (i.e. the mean/std normalizing constants). If in eval mode,
    calling forward simply applies the standardization using the current module
    state.
    """

    def __init__(
        self,
        m: int,
        outputs: Optional[list[int]] = None,
        batch_shape: torch.Size = torch.Size(),  # noqa: B008
        min_stdv: float = 1e-8,
    ) -> None:
        r"""Standardize outcomes (zero mean, unit variance).

        Args:
            m: The output dimension.
            outputs: Which of the outputs to standardize. If omitted, all
                outputs will be standardized.
            batch_shape: The batch_shape of the training targets.
            min_stdv: The minimum standard deviation for which to perform
                standardization (if lower, only de-mean the data).
        """
        # register everything as a non-trainable float parameter (rather than
        # a buffer / bool tensor) so DDP's state broadcast only sees floats
        OutcomeTransform.__init__(self)
        self.register_parameter(
            "means", nn.Parameter(torch.zeros(*batch_shape, 1, m), requires_grad=False)
        )
        self.register_parameter(
            "stdvs", nn.Parameter(torch.ones(*batch_shape, 1, m), requires_grad=False)
        )
        self.register_parameter(
            "_stdvs_sq", nn.Parameter(torch.ones(*batch_shape, 1, m), requires_grad=False)
        )
        self.register_parameter(
            "_is_trained", nn.Parameter(torch.tensor(0.0), requires_grad=False)
        )
        self._outputs = normalize_indices(outputs, d=m)
        self._m = m
        self._batch_shape = batch_shape
        self._min_stdv = min_stdv

    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = True
    ) -> None:
        r"""Custom logic for loading the state dict."""
        if "_is_trained" not in state_dict:
            warnings.warn(
                "Key '_is_trained' not found in state_dict. Setting to True. "
                "In a future release, this will result in an error.",
                DeprecationWarning,
            )
            state_dict = {**state_dict, "_is_trained": torch.tensor(1.0)}
        super().load_state_dict(state_dict, strict=strict)

    def forward(
        self, Y: Tensor, Yvar: Optional[Tensor] = None
    ) -> tuple[Tensor, Optional[Tensor]]:
        r"""Standardize outcomes.

        If the module is in train mode, this updates the module state (i.e. the
        mean/std normalizing constants). If the module is in eval mode, simply
        applies the normalization using the module state.

        Args:
            Y: A `batch_shape x n x m`-dim tensor of training targets.
            Yvar: A `batch_shape x n x m`-dim tensor of observation noises
                associated with the training targets (if applicable).

        Returns:
            A two-tuple with the transformed outcomes:
            - The transformed outcome observations.
            - The transformed observation noise (if applicable).
        """
        if self.training:
            if Y.shape[:-2] != self._batch_shape:
                raise RuntimeError(
                    f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
                    "the `batch_shape` argument to `Standardize`, but got "
                    f"Y.shape[:-2]={Y.shape[:-2]}."
                )
            if Y.size(-1) != self._m:
                raise RuntimeError(
                    f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
                    f"{self._m}."
                )
            stdvs = Y.std(dim=-2, keepdim=True)
            stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0))
            means = Y.mean(dim=-2, keepdim=True)
            if self._outputs is not None:
                unused = [i for i in range(self._m) if i not in self._outputs]
                means[..., unused] = 0.0
                stdvs[..., unused] = 1.0
            self.means.data = means
            self.stdvs.data = stdvs
            self._stdvs_sq.data = stdvs.pow(2)
            self._is_trained.data = torch.tensor(1.0)
        Y_tf = (Y - self.means) / self.stdvs
        Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
        return Y_tf, Yvar_tf
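As a quick sanity check (a sketch I haven't run in isolation), dropping this subclass into the repro Model above should let DDP construct cleanly, since every entry in the transform's state dict is now a plain float parameter:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        # float-only state dict, so DDP's initial parameter/buffer
        # broadcast no longer sees a bool tensor
        self.transform = DistributableStandardize(m=10)

    def forward(self, x):
        return self.transform(self.linear(x))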
Hi @samuelstanton. DDP not supporting non-float tensors and buffers seems like a strange design choice to me. We could replace the bool tensor with a float tensor, but I doubt that this is a one-off usage in BoTorch. On the other hand, I know that register_buffer is commonly used throughout BoTorch, and I'd be hesitant to replace it without a strong motivator.
Overall, these seem like issues that should be resolved on the DDP side. It seems strange to me that it wouldn't support such core parts of PyTorch.
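To make that concrete, here is an untested sketch that takes BoTorch out of the picture entirely; assuming the same torch 2.0.1 / gloo setup as in the report, a plain bool buffer should trip the same broadcast error during DDP construction:

import torch
import torch.nn as nn


class BoolBufferModel(nn.Module):
    # no BoTorch involved; just a non-float entry in the state dict
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.register_buffer("flag", torch.tensor(False))  # dtype torch.bool

    def forward(self, x):
        return self.linear(x)


# Substituting BoolBufferModel for Model in the repro script above should
# hit the same `RuntimeError: Invalid scalar type` inside
# dist._broadcast_coalesced, if the failure really is about non-float
# dtypes rather than anything specific to Standardize.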
Seems strange to me too. It's up to the BoTorch core team how you want to deal with this; it just took me a while to figure out, so I thought I'd share. Standardize is a simple enough utility that I could always just migrate to something else. However, I could imagine this coming up again if, e.g., someone wanted to scale up SVGP training on multiple devices. Maybe we could touch base with whoever does distributed torch development and see if there's a solution on the roadmap.
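In the meantime, another possible stopgap (just a sketch, untested, and it leans on a private, undocumented DDP hook that could change between releases) would be to ask DDP to leave the offending buffer out of its broadcasts altogether:

from torch.nn.parallel import DistributedDataParallel as DDP

model = Model().to(rank)
# private hook: fully-qualified names of parameters/buffers that DDP
# should skip when syncing module states across ranks
DDP._set_params_and_buffers_to_ignore_for_model(model, ["transform._is_trained"])
ddp_model = DDP(model, device_ids=[rank])

The obvious caveat is that `_is_trained` would then never be synchronized across ranks, so each process has to be responsible for putting the transform into a trained state itself.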