`OnDiskDataset` locks when `DataLoader` has `num_workers > 1`

Open ilsenatorov opened this issue 1 year ago • 2 comments

🐛 Describe the bug

If I create an OnDiskDataset instance (I'm using the PCQM4Mv2 implementation as a reference here) and pass it to a DataLoader with num_workers > 1, like this:

from torch_geometric.datasets import PCQM4Mv2
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

ds = PCQM4Mv2(root="data/pcqm4m_small")
dl = DataLoader(ds, batch_size=4, num_workers=2)
batch = next(iter(dl))

I get a `database is locked` error:

OperationalError: Caught OperationalError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ilya/miniconda3/envs/step/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ilya/miniconda3/envs/step/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilya/miniconda3/envs/step/lib/python3.11/site-packages/torch_geometric/loader/dataloader.py", line 54, in collate_fn
    return self(self.dataset.multi_get(batch))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilya/miniconda3/envs/step/lib/python3.11/site-packages/torch_geometric/data/on_disk_dataset.py", line 151, in multi_get
    data_list = self.db.multi_get(indices, batch_size)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilya/miniconda3/envs/step/lib/python3.11/site-packages/torch_geometric/data/database.py", line 361, in multi_get
    self.cursor.execute(query)
sqlite3.OperationalError: database is locked

Versions

Collecting environment information...
PyTorch version: 2.1.2
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Manjaro Linux (x86_64)
GCC version: (GCC) 13.2.1 20230801
Clang version: 16.0.6
CMake version: Could not collect
Libc version: glibc-2.38

Python version: 3.11.7 | packaged by conda-forge | (main, Dec 15 2023, 08:38:37) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.10.203-1-MANJARO-x86_64-with-glibc2.38
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 SUPER
Nvidia driver version: 545.29.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             16
On-line CPU(s) list:                0-15
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen 7 5800X 8-Core Processor
CPU family:                         25
Model:                              33
Thread(s) per core:                 2
Core(s) per socket:                 8
Socket(s):                          1
Stepping:                           0
Frequency boost:                    enabled
CPU(s) scaling MHz:                 71%
CPU max MHz:                        6328.7100
CPU min MHz:                        2200.0000
BogoMIPS:                           8404.84
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                     AMD-V
L1d cache:                          256 KiB (8 instances)
L1i cache:                          256 KiB (8 instances)
L2 cache:                           4 MiB (8 instances)
L3 cache:                           32 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] pytorch-lightning==2.1.3
[pip3] torch==2.1.2
[pip3] torch-cluster==1.6.3
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.1.2
[pip3] torchmetrics==1.2.1
[pip3] torchvision==0.16.2
[pip3] triton==2.1.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0           py311h5eee18b_1  
[conda] mkl_fft                   1.3.8           py311h5eee18b_0  
[conda] mkl_random                1.2.4           py311hdb19cb5_0  
[conda] numpy                     1.26.2          py311h08b1b3b_0  
[conda] numpy-base                1.26.2          py311hf175353_0  
[conda] pyg                       2.4.0           py311_torch_2.1.0_cu121    pyg
[conda] pytorch                   2.1.2           py3.11_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cluster           1.6.3           py311_torch_2.1.0_cu121    pyg
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-lightning         2.1.3                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.1.2               py311_cu121    pytorch
[conda] torchmetrics              1.2.1                    pypi_0    pypi
[conda] torchtriton               2.1.0                     py311    pytorch
[conda] torchvision               0.16.2              py311_cu121    pytorch

ilsenatorov, Dec 22 '23 10:12

It seems that the issue stems from the way multi_get is handled in SQLiteDatabase, specifically this part:

# We create a temporary ID table to then perform an INNER JOIN.
# This avoids having a long IN clause and guarantees sorted outputs:
join_table_name = f'{self.name}__join__{uuid4().hex}'
query = (f'CREATE TABLE {join_table_name} (\n'
         f'  id INTEGER,\n'
         f'  row_id INTEGER\n'
         f')')
self.cursor.execute(query)

query = f'INSERT INTO {join_table_name} (id, row_id) VALUES (?, ?)'
self.cursor.executemany(query, zip(indices, range(len(indices))))

query = f'SELECT * FROM {join_table_name}'
self.cursor.execute(query)

query = (f'SELECT {self._joined_col_names} '
         f'FROM {self.name} INNER JOIN {join_table_name} '
         f'ON {self.name}.id = {join_table_name}.id '
         f'ORDER BY {join_table_name}.row_id')
self.cursor.execute(query)

AFAIK SQLite allows concurrent reads from multiple processes, but only one process can write at a time. Since multi_get also writes to the temporary join table, the DataLoader workers contend for the database write lock, which leads to the error.

ilsenatorov, Dec 22 '23 13:12
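
The locking behaviour described in the comment above can be reproduced without PyG or a DataLoader. The following is a minimal sketch using only Python's standard `sqlite3` module. It assumes SQLite's default rollback-journal mode, and the names `demo_write_lock`, `data`, `join_a` and `join_b` are made up for illustration. Two connections in one process stand in for two worker processes, and `timeout=0` makes the second writer fail immediately instead of waiting for the lock.

```python
import os
import sqlite3
import tempfile


def demo_write_lock():
    """Return (read_ok, message) for a second connection during a write.

    Hypothetical helper for illustration only; it is not part of PyG.
    """
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.db")
        # isolation_level=None gives autocommit mode, so transactions are
        # controlled explicitly with BEGIN / ROLLBACK below.
        conn_a = sqlite3.connect(path, timeout=0, isolation_level=None)
        conn_b = sqlite3.connect(path, timeout=0, isolation_level=None)
        try:
            conn_a.execute("CREATE TABLE data (id INTEGER PRIMARY KEY, value TEXT)")

            # Connection A starts a write transaction and creates a join
            # table, similar to what multi_get does in one worker.
            conn_a.execute("BEGIN IMMEDIATE")
            conn_a.execute("CREATE TABLE join_a (id INTEGER, row_id INTEGER)")

            # Reading from connection B still works while A holds the lock.
            rows = conn_b.execute("SELECT COUNT(*) FROM data").fetchall()
            read_ok = rows == [(0,)]

            # Writing from connection B fails because A holds the write lock.
            try:
                conn_b.execute("CREATE TABLE join_b (id INTEGER, row_id INTEGER)")
                message = "no error"
            except sqlite3.OperationalError as exc:
                message = str(exc)

            conn_a.execute("ROLLBACK")
        finally:
            conn_a.close()
            conn_b.close()
    return read_ok, message


if __name__ == "__main__":
    print(demo_write_lock())
```

With a non-zero `timeout`, the second writer would wait for the lock for up to that many seconds before raising the same error.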

Strange, I am pretty sure this worked in an earlier code version. I fixed this in https://github.com/pyg-team/pytorch_geometric/pull/8667, but I feel we can definitely do more optimizations here, since the applied fix is still gated behind locks.

rusty1s, Dec 23 '23 11:12
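
One way to avoid the write lock altogether is to keep the lookup read-only. The sketch below is an assumption for illustration and is not what either linked pull request implements: the hypothetical helper `read_only_multi_get` fetches rows with an `IN` clause and restores the requested order in Python, so no temporary table is written. It assumes a table with `id` and `value` columns, and it brings back the long `IN` clause that the original join-table approach was designed to avoid.

```python
import sqlite3


def read_only_multi_get(conn, table, indices):
    """Fetch `value` for each id in `indices`, preserving order and duplicates.

    Hypothetical helper for illustration only; it is not part of PyG.
    `table` is interpolated into the query, so it must be a trusted name.
    """
    if not indices:
        return []
    # De-duplicate while keeping first-seen order to keep the IN clause short.
    unique = list(dict.fromkeys(indices))
    placeholders = ",".join("?" * len(unique))
    query = f"SELECT id, value FROM {table} WHERE id IN ({placeholders})"
    # Map id -> value, then rebuild the requested order in Python.
    rows = dict(conn.execute(query, unique).fetchall())
    # Raises KeyError if a requested id does not exist in the table.
    return [rows[i] for i in indices]


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE data (id INTEGER PRIMARY KEY, value TEXT)")
    conn.executemany("INSERT INTO data VALUES (?, ?)", [(1, "a"), (2, "b"), (3, "c")])
    print(read_only_multi_get(conn, "data", [3, 1, 3]))
```

SQLite limits the number of bound parameters per statement, so very large batches would need to be split into chunks.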

This should fix it: https://github.com/pyg-team/pytorch_geometric/pull/9140

jay-bhambhani, Apr 02 '24 02:04