Unexpected difference in torchvision.transforms.Normalize results when using DataLoader with num_workers>=1 and pin_memory=True.

Open • Chen-Bo-Yang opened this issue 4 months ago • 1 comment

🐛 Describe the bug

Bug

Unexpected difference in torchvision.transforms.Normalize results when using DataLoader with num_workers>=1 and pin_memory=True.


Description

I found an unexpected behavior when using torch.utils.data.DataLoader.

  • When I set num_workers >= 1 and pin_memory=True, the output of torchvision.transforms.Normalize in the main process becomes slightly different once the DataLoader has been iterated over.
  • However, if I also run the transform once before iterating over the DataLoader, the results stay consistent.

Steps to Reproduce

Here is a minimal reproducible example.

Python script
```python
# test_torchvision.py
import argparse
import random
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--workers', help='number of data loading workers', default=0, type=int)
    parser.add_argument(
        '--pin', help='pin memory for DataLoader', action='store_true')
    parser.add_argument(
        '--pre_run', help='run transform before accessing DataLoader', action='store_true')
    # args, rest = parser.parse_known_args()
    args = parser.parse_args()

    return args

class ToyDataset(Dataset):
    def __init__(self, length=10, transform=None):
        self.length = length
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        image = np.ones((512, 960, 3), dtype=np.uint8) * idx
        if self.transform:
            image = self.transform(image)
        label = idx
        return image, label

def try_transform(description=None):
    if description:
        print(f"{description}")

    image = np.ones((512, 960, 3), dtype=np.uint8)

    h, w, c = image.shape
    half = h // 2
    image[:half+1, :, :] = 1
    image[half-1:, :, :] = 0

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    print(f"\tsum before transform: {image.sum()}")
    image = transform(image)
    print(f"\tsum after transform: {image.sum()}")


if __name__ == "__main__":
    args = parse_args()
    print(f"workers: {args.workers}, pin_memory: {args.pin}, pre_run: {args.pre_run}")

    dataset = ToyDataset(
        length=1,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]),
    )

    toy_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=args.pin)

    if args.pre_run:
        try_transform(description="Use transform before start accessing DataLoader")

    for images, labels in toy_loader:
        pass
        # print(f"Batch images shape: {images.shape}, Labels: {labels}")

    try_transform(description="Use transform after start accessing DataLoader")
    print()
```

Bash script
```bash
echo "1"
python test_torchvision.py --workers=0 --pin
echo "2"
python test_torchvision.py --workers=1 --pin
echo "3"
python test_torchvision.py --workers=0
echo "4"
python test_torchvision.py --workers=1

echo "5"
python test_torchvision.py --workers=0 --pin --pre_run
echo "6"
python test_torchvision.py --workers=1 --pin --pre_run
echo "7"
python test_torchvision.py --workers=0 --pre_run
echo "8"
python test_torchvision.py --workers=1 --pre_run
```

Observed Behavior

Only in run 2 (workers=1, pin_memory=True, no pre-run) does the sum of the transforms.Normalize output differ, and only slightly: -2915762.25 instead of -2915763.0. All other runs, including run 6 (workers=1, pin_memory=True, with pre-run), give -2915763.0.

Example output:

```text
Testing torchvision transforms
1
workers: 0, pin_memory: True, pre_run: False
Use transform after start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0

2
workers: 1, pin_memory: True, pre_run: False
Use transform after start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915762.25

3
workers: 0, pin_memory: False, pre_run: False
Use transform after start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0

4
workers: 1, pin_memory: False, pre_run: False
Use transform after start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0

5
workers: 0, pin_memory: True, pre_run: True
Use transform before start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0
Use transform after start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0

6
workers: 1, pin_memory: True, pre_run: True
Use transform before start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0
Use transform after start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0

7
workers: 0, pin_memory: False, pre_run: True
Use transform before start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0
Use transform after start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0

8
workers: 1, pin_memory: False, pre_run: True
Use transform before start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0
Use transform after start accessing DataLoader
        sum before transform: 734400
        sum after transform: -2915763.0
```
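As a rough sanity check on these numbers, the expected sum can be worked out in closed form. The sketch below assumes the image built in try_transform (rows 0 to 254 equal to 1, the remaining 257 rows equal to 0, 960 columns, 3 channels), that ToTensor maps a uint8 value v to v / 255, and that the tensor sum is accumulated in float32. It evaluates the reference in float64 and ignores float32 rounding of the mean, std and individual elements, so it is only accurate to within roughly one float32 ulp. The helper names are hypothetical.

```python
import math

# Assumed geometry of the test image built in try_transform():
# 512 x 960 x 3 uint8; image[half-1:] = 0 runs after image[:half+1] = 1,
# so rows 0..254 equal 1 and rows 255..511 equal 0.
H, W, C = 512, 960, 3
ONE_ROWS = 255
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def raw_sum() -> int:
    """Sum of the uint8 image before ToTensor/Normalize."""
    return ONE_ROWS * W * C


def normalized_sum_reference() -> float:
    """Closed-form sum of Normalize(ToTensor(image)), evaluated in float64.

    Ignores float32 rounding of mean, std and of each element, so it is
    only a reference to within roughly one float32 ulp.
    """
    n_pixels = H * W         # pixels per channel
    n_ones = ONE_ROWS * W    # pixels equal to 1/255 after ToTensor
    total = 0.0
    for m, s in zip(MEAN, STD):
        # sum over one channel of (x - m) / s == (sum(x) - n_pixels * m) / s
        total += (n_ones / 255.0 - n_pixels * m) / s
    return total


def float32_ulp(x: float) -> float:
    """Gap between adjacent float32 values near x (normal range only)."""
    _, exp = math.frexp(abs(x))  # abs(x) == mant * 2**exp with 0.5 <= mant < 1
    return 2.0 ** (exp - 24)     # float32 significand has 24 bits


if __name__ == "__main__":
    ref = normalized_sum_reference()
    ulp = float32_ulp(ref)
    a, b = -2915763.0, -2915762.25  # sums reported in runs 1 and 2
    print(f"raw sum:             {raw_sum()}")
    print(f"float64 reference:   {ref:.3f}")
    print(f"float32 ulp here:    {ulp}")
    print(f"a - b in ulps:       {abs(a - b) / ulp}")
    print(f"relative difference: {abs(a - b) / abs(ref):.2e}")
```

Under these assumptions the raw sum comes out as 734400, matching the log, and the reference is about -2915762.44. At that magnitude adjacent float32 values are 0.25 apart, so -2915763.0 and -2915762.25 are 3 ulps apart, a relative difference of roughly 2.6e-7 (about twice float32's machine epsilon of ~1.19e-7). A difference of that size is consistent with a change in accumulation order rather than a change in what the transform computes.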

Expected Behavior

The output of torchvision.transforms.Normalize should remain consistent regardless of num_workers or pin_memory settings.

Versions

```text
Collecting environment information...
PyTorch version: 2.8.0+cu129
Is debug build: False
CUDA used to build PyTorch: 12.9
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.11 | packaged by conda-forge | (main, Jun 4 2025, 14:45:31) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.14.0-27-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5090
Nvidia driver version: 575.64.03
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i9-14900K
CPU family: 6
Model: 183
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 1
Stepping: 1
CPU(s) scaling MHz: 39%
CPU max MHz: 6000.0000
CPU min MHz: 800.0000
BogoMIPS: 6374.40
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect user_shstk avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr ibt flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 896 KiB (24 instances)
L1i cache: 1.3 MiB (24 instances)
L2 cache: 32 MiB (12 instances)
L3 cache: 36 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Ghostwrite: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.9.1.4
[pip3] nvidia-cuda-cupti-cu12==12.9.79
[pip3] nvidia-cuda-nvrtc-cu12==12.9.86
[pip3] nvidia-cuda-runtime-cu12==12.9.79
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.4.1.4
[pip3] nvidia-curand-cu12==10.3.10.19
[pip3] nvidia-cusolver-cu12==11.7.5.82
[pip3] nvidia-cusparse-cu12==12.5.10.65
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.9.86
[pip3] nvidia-nvtx-cu12==12.9.79
[pip3] torch==2.8.0+cu129
[pip3] torchvision==0.23.0+cu129
[pip3] triton==3.4.0
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.9.1.4 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.9.79 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.9.86 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.9.79 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.4.1.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.10.19 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.5.82 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.10.65 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.3 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.9.86 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.9.79 pypi_0 pypi
[conda] torch 2.8.0+cu129 pypi_0 pypi
[conda] torchvision 0.23.0+cu129 pypi_0 pypi
[conda] triton 3.4.0 pypi_0 pypi
```

Chen-Bo-Yang • Aug 22 '25 04:08

Hi @Chen-Bo-Yang, thanks for the report.

I am not sure why pin_memory would affect the result, but we can explain why setting num_workers > 0 can produce different results: num_workers > 0 makes the DataLoader create worker processes, and in particular each worker calls torch.set_num_threads(1):

https://github.com/pytorch/pytorch/blob/6aef9f3a6906c011a57541c1de7a246222bc9ac9/torch/utils/data/_utils/worker.py#L257

which makes torch operators in the worker run single-threaded. In contrast, when num_workers=0, torch operators are multi-threaded. Floating-point addition is not associative, so when a sum is split across several threads and the partial sums are combined afterwards, the result can differ slightly from a single-threaded sum, in the same way that (a + b) + c can differ slightly from a + (b + c). I think this is why you're observing small differences.
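Strictly speaking, floating-point addition is commutative (a + b == b + a) but not associative, so regrouping a reduction changes how intermediate results are rounded. The pure-Python sketch below illustrates that effect. It is a toy model, not PyTorch's actual reduction kernel: `f32`, `sum_sequential` and `sum_chunked` are hypothetical helpers, float32 rounding is emulated with the standard `struct` module, and each "thread" is assumed to sum one contiguous chunk before the partial sums are combined.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (a C double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_sequential(values: list[float]) -> float:
    """Single accumulator, left to right, rounding to float32 after every add."""
    acc = 0.0
    for v in values:
        acc = f32(acc + v)
    return acc


def sum_chunked(values: list[float], n_chunks: int) -> float:
    """Toy model of a multi-threaded reduction: each hypothetical 'thread'
    sums one contiguous chunk in float32, then the partials are combined."""
    size = -(-len(values) // n_chunks)  # ceiling division
    partials = [sum_sequential(values[i:i + size])
                for i in range(0, len(values), size)]
    return sum_sequential(partials)


if __name__ == "__main__":
    # Addition is commutative...
    print(0.1 + 0.2 == 0.2 + 0.1)                  # True
    # ...but not associative.
    print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

    # Same float32 data, different grouping, different float32 result.
    data = [f32(1.0)] + [f32(1e-8)] * 1000
    print(sum_sequential(data))  # 1.0: every 1e-8 is rounded away next to 1.0
    print(sum_chunked(data, 4))  # slightly above 1.0
```

With one chunk the toy model reduces to the sequential sum. With four, most of the small terms are first accumulated among themselves, so they are large enough to survive when finally added to 1.0. A real reduction differs far less dramatically than this contrived input, but changing the number of threads changes the grouping in a similar way.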

NicolasHug • Aug 26 '25 08:08