Randomness found during training on GPU

Open lidq92 opened this issue 3 years ago • 5 comments

Randomness is observed when training the ME model on GPU (A100, 2080Ti, P40, or T4) in the same environment: the same machine with Ubuntu 18.04, torch==1.12.1, and MinkowskiEngine==0.5.4 (installed with the system Python) with CUDA 10.2.

I'm confused by this randomness. What might have caused it?

Thanks for your help.

Best regards, Dingquan


To Reproduce

  • [Skip] Results are reproducible on both CPU and GPU when MinkowskiEngine is not used (i.e., torch only).

  • [Skip] Results are reproducible on CPU when OMP_NUM_THREADS is kept the same across runs:

CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=1 python MWE.py
CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=16 python MWE.py

  • [To Reproduce] Results are not reproducible on GPU when MinkowskiEngine is used:

# Run the command several times
CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=1 python MWE.py
CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=16 python MWE.py

  • Minimal reproducible example (MWE.py):

import os
import torch
import random
import argparse
import numpy as np
import torch.nn as nn
import MinkowskiEngine as ME


class MWEDataset(torch.utils.data.Dataset):
    def __init__(self):
        super(MWEDataset, self).__init__()

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        coords = torch.randint(0, 1023, (10000,3), dtype=torch.int)
        feats = torch.rand((10000,3), dtype=torch.float)
        label = np.random.rand(1, 1).astype(np.float32)

        return {
            "coordinates": coords, 
            "features": feats,
            "labels": label,
        }


def minkowski_collate_fn(list_data):
    coordinates_batch, features_batch, labels_batch = ME.utils.sparse_collate(
        [d["coordinates"] for d in list_data],
        [d["features"] for d in list_data],
        [d["labels"] for d in list_data],
        dtype=torch.float32,
    )
    
    return {
        "coordinates": coordinates_batch,
        "features": features_batch,
        "labels": labels_batch,
    }


def global_avg_pool(inputs):
    batch_size = torch.max(inputs.coordinates[:, 0]).item() + 1
    outputs = []
    for k in range(batch_size):
        input = inputs.features[inputs.coordinates[:, 0] == k]
        output = torch.mean(input, dim=0)
        outputs.append(output)
    outputs = torch.stack(outputs, dim=0)
    return outputs


class MWEModel(ME.MinkowskiNetwork):
    def __init__(self, D=3, CHANNELS=[3, 3, 3, 1]): 
        ME.MinkowskiNetwork.__init__(self, D)
        self.conv1 = ME.MinkowskiConvolution(CHANNELS[0], CHANNELS[1], kernel_size=3, dimension=D)
        self.bn1 = ME.MinkowskiBatchNorm(CHANNELS[1])
        self.relu1 = ME.MinkowskiReLU()
        self.pool1 = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)
        self.gap = ME.MinkowskiGlobalAvgPooling()
        self.feature = ME.MinkowskiToFeature()
        self.fc1 = nn.Linear(CHANNELS[1], CHANNELS[2])
        self.bn = nn.BatchNorm1d(CHANNELS[2])
        self.fc2 = nn.Linear(CHANNELS[2], CHANNELS[3])
        
    def forward(self, x):
        # out = x # not OK (at least two results are observed among several runs)
        # out = self.relu1(x) # randomness observed after several runs
        # out = self.relu1(self.bn1(x)) # randomness (no consistency among several runs)
        # out = self.relu1(self.bn1(self.conv1(x))) # randomness (no consistency among several runs)
        # out = self.pool1(self.relu1(self.bn1(x))) # randomness (no consistency among several runs)
        out = self.pool1(self.relu1(self.bn1(self.conv1(x)))) # randomness (no consistency among several runs)

        out = self.gap(out).F # self.feature(self.gap(out)) # global_avg_pool(out) # 

        # out = torch.rand((2,3), device=x.device) # reproducible on both CPU and GPU without MinkowskiEngine used (i.e., torch only)
        out = self.fc2(self.bn(self.fc1(out)))
        
        return out


def run(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MWEModel().to(device)  
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    loss_func = torch.nn.SmoothL1Loss()
    trainset = MWEDataset()
    
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(0)
    train_loader = torch.utils.data.DataLoader(trainset, config.batch_size, shuffle=True, 
                                               num_workers=1, worker_init_fn=seed_worker, generator=g,
                                               collate_fn=minkowski_collate_fn)
    
    model.train()
    for k in range(config.max_epoch):
        for i, batch in enumerate(train_loader):
            coords, feats, labels = batch["coordinates"], batch["features"], batch["labels"] 
            x = ME.SparseTensor(feats, coords.int(), device=device) 
            y = labels.to(device)
            y_pred = model(x)
            loss = loss_func(y_pred, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print(loss.item())
        

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MWEtestME')
    parser.add_argument("--seed", type=int, default=19920517)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--max_epoch', type=int, default=10)
    config = parser.parse_args()
    
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    np.random.seed(config.seed)
    random.seed(config.seed)
    os.environ['PYTHONHASHSEED'] = str(config.seed)
    
    run(config)


Expected behavior

Results should be identical across several runs of the same command.
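
To quantify how far two runs drift apart, the losses printed by MWE.py can be saved per run and compared. The sketch below is only an illustration: parse_losses and first_divergence are hypothetical helper names, and it assumes each run's captured output contains one loss value per line, as produced by print(loss.item()) in the script above.

```python
def parse_losses(text):
    """Parse one float per non-empty line of a run's captured output."""
    return [float(line) for line in text.splitlines() if line.strip()]


def first_divergence(run_a, run_b, tol=0.0):
    """Return (step, loss_a, loss_b) for the first step whose losses differ
    by more than tol, or None if all compared steps agree."""
    for step, (a, b) in enumerate(zip(run_a, run_b)):
        if abs(a - b) > tol:
            return step, a, b
    return None


# Made-up numbers for illustration only, not output from a real run.
run_1 = parse_losses("0.5\n0.25\n0.125\n")
run_2 = parse_losses("0.5\n0.25\n0.375\n")
print(first_divergence(run_1, run_2))
```

With tol=0.0 the check is for bitwise-equal printed values, which matches the expectation stated above; a small positive tol would instead show whether the runs are merely close.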


Desktop

  • OS: Ubuntu 18.04
  • Python version: 3.7.10
  • Pytorch version: 1.12.1
  • CUDA version: 10.2
  • NVIDIA Driver version: 450.80
  • Minkowski Engine version: 0.5.4
  • Output of python -c "import MinkowskiEngine as ME; ME.print_diagnostics()":
==========System==========
Linux-4.4.0-142-generic-x86_64-with-debian-buster-sid
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.5 LTS"
3.7.10 (default, Feb 26 2021, 18:47:35) 
[GCC 7.3.0]
==========Pytorch==========
1.12.1
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 450.80.02
CUDA Version 11.0
VBIOS Version 90.04.96.00.01
Image Version G183.0200.00.02
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
==========CC==========
/usr/bin/c++
c++ (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.4
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 10020
CUDART version MinkowskiEngine is compiled: 10020

lidq92, Nov 09 '22 10:11

Info for another machine that also reproduced the randomness:

==========System==========
Linux-4.4.0-176-generic-x86_64-with-debian-buster-sid
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.4 LTS"
3.7.10 | packaged by conda-forge | (default, Feb 19 2021, 16:07:37) 
[GCC 9.3.0]
==========Pytorch==========
1.9.0
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 440.33.01
CUDA Version 10.2
VBIOS Version 86.02.23.00.01
Image Version G610.0200.00.03
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
==========CC==========
/usr/bin/c++
c++ (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.4
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 10020
CUDART version MinkowskiEngine is compiled: 10020

Besides, on this machine the CPU results are the same regardless of the value set for OMP_NUM_THREADS.

lidq92, Nov 09 '22 11:11

Hello,

I see the same issue with my own MinkowskiEngine model: it affects all the metric values during training.

I tried to reproduce your issue and observe the same problem with your code. The differences seem smaller with your simple example than with the large models I use in my project.

I would be very interested to know whether you have found a solution, because this is a serious problem for reproducibility, even if the values remain "somewhat" similar.

loic-lb, Jun 13 '23 16:06

#554 solved a similar issue for me. Essentially, sorting by coordinates before passing a tensor through each layer produced deterministic outputs (in my tests).

mic-rud, Oct 15 '23 15:10
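
The sorting idea from the last comment could look roughly like the sketch below. It is not the code from #554: sort_by_coordinates is a hypothetical helper, and whether reordering the inputs removes the run-to-run differences for a given model is an assumption that would need to be verified. The helper reorders the rows of a coordinate tensor lexicographically and applies the same permutation to the features, using only torch (the stable argument of torch.sort requires torch 1.9 or later).

```python
import torch


def sort_by_coordinates(coords, feats):
    """Reorder coords lexicographically by row and permute feats to match.

    Hypothetical helper for illustration; coords is (N, C), feats is (N, F).
    """
    order = torch.arange(coords.shape[0], device=coords.device)
    # Stable sorts from the last column to the first yield lexicographic order.
    for col in reversed(range(coords.shape[1])):
        _, idx = torch.sort(coords[order, col], stable=True)
        order = order[idx]
    return coords[order], feats[order]


coords = torch.tensor([[1, 0, 2], [0, 5, 1], [0, 2, 9], [1, 0, 0]], dtype=torch.int)
feats = torch.arange(4, dtype=torch.float).unsqueeze(1)
sorted_coords, sorted_feats = sort_by_coordinates(coords, feats)
print(sorted_coords)
print(sorted_feats.squeeze(1))
```

In the MWE above, such a helper would be applied to coords and feats before constructing ME.SparseTensor(feats, coords.int(), device=device). The comment describes sorting before each layer, which would additionally require re-sorting the coordinates and features of every intermediate tensor.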