
multiprocessing for dataloader

xiao233333 opened this issue 1 year ago • 2 comments

Hi, I have multiple datasets for training. When I build the dataloader, multiprocessing doesn't seem to work. Here is the code:

import torch
import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)

dataloader_train = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True, num_workers=8)
trainer.fit(model, train_dataloaders=dataloader_train)

and here is the error:

AttributeError                            Traceback (most recent call last)
File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path)
    984 # ----------------------------
    985 # RUN THE TRAINER
    986 # ----------------------------
--> 987 results = self._run_stage()
    989 # ----------------------------
    990 # POST-Training CLEAN UP
    991 # ----------------------------

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1033, in Trainer._run_stage(self)
   1032 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1033     self.fit_loop.run()
   1034 return None

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:197, in _FitLoop.run(self)
    196 def run(self) -> None:
--> 197     self.setup_data()
    198     if self.skip:

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:263, in _FitLoop.setup_data(self)
    262 self._data_fetcher.setup(combined_loader)
--> 263 iter(self._data_fetcher)  # creates the iterator inside the fetcher
    264 max_batches = sized_len(combined_loader)

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:104, in _PrefetchDataFetcher.__iter__(self)
    102 @override
     103     def __iter__(self) -> "_PrefetchDataFetcher":
--> 104     super().__iter__()
    105     if self.length is not None:
    106         # ignore pre-fetching, it's not necessary

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py:51, in _DataFetcher.__iter__(self)
     49 @override
      50     def __iter__(self) -> "_DataFetcher":
---> 51     self.iterator = iter(self.combined_loader)
     52     self.reset()

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:351, in CombinedLoader.__iter__(self)
    350 iterator = cls(self.flattened, self._limits)
--> 351 iter(iterator)
    352 self._iterator = iterator

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:92, in _MaxSizeCycle.__iter__(self)
     90 @override
     91 def __iter__(self) -> Self:
---> 92     super().__iter__()
     93     self._consumed = [False] * len(self.iterables)

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:43, in _ModeIterator.__iter__(self)
     41 @override
     42 def __iter__(self) -> Self:
---> 43     self.iterators = [iter(iterable) for iterable in self.iterables]
     44     self._idx = 0

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py:43, in <listcomp>(.0)
     41 @override
     42 def __iter__(self) -> Self:
---> 43     self.iterators = [iter(iterable) for iterable in self.iterables]
     44     self._idx = 0

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/torch/utils/data/dataloader.py:439, in DataLoader.__iter__(self)
    438 else:
--> 439     return self._get_iterator()

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/torch/utils/data/dataloader.py:387, in DataLoader._get_iterator(self)
    386 self.check_worker_number_rationality()
--> 387 return _MultiProcessingDataLoaderIter(self)

File ~/micromamba/envs/spatialdata/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1040, in _MultiProcessingDataLoaderIter.__init__(self, loader)
   1034 # NB: Process.start() actually take some time as it needs to
   1035 #     start a process and pass the arguments over via a pipe.
   1036 #     Therefore, we only add a worker to self._workers list after
   1037 #     it started, so that we do not call .join() if program dies
   1038 #     before it starts, and __del__ tries to join but will get:
   1039 #     AssertionError: can only join a started process.
-> 1040 w.start()
   1041 self._index_queues.append(index_queue)

File ~/micromamba/envs/spatialdata/lib/python3.10/multiprocessing/process.py:121, in BaseProcess.start(self)
    120 _cleanup()
--> 121 self._popen = self._Popen(self)
    122 self._sentinel = self._popen.sentinel

File ~/micromamba/envs/spatialdata/lib/python3.10/multiprocessing/context.py:224, in Process._Popen(process_obj)
    222 @staticmethod
    223 def _Popen(process_obj):
--> 224     return _default_context.get_context().Process._Popen(process_obj)

File ~/micromamba/envs/spatialdata/lib/python3.10/multiprocessing/context.py:288, in SpawnProcess._Popen(process_obj)
    287 from .popen_spawn_posix import Popen
--> 288 return Popen(process_obj)

File ~/micromamba/envs/spatialdata/lib/python3.10/multiprocessing/popen_spawn_posix.py:32, in Popen.__init__(self, process_obj)
     31 self._fds = []
---> 32 super().__init__(process_obj)

File ~/micromamba/envs/spatialdata/lib/python3.10/multiprocessing/popen_fork.py:19, in Popen.__init__(self, process_obj)
     18 self.finalizer = None
---> 19 self._launch(process_obj)

File ~/micromamba/envs/spatialdata/lib/python3.10/multiprocessing/popen_spawn_posix.py:47, in Popen._launch(self, process_obj)
     46     reduction.dump(prep_data, fp)
---> 47     reduction.dump(process_obj, fp)
     48 finally:

File ~/micromamba/envs/spatialdata/lib/python3.10/multiprocessing/reduction.py:60, in dump(obj, file, protocol)
     59 '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60 ForkingPickler(file, protocol).dump(obj)

AttributeError: Can't pickle local object '_get_points.<locals>.<lambda>'
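
For context, this failure can be reproduced outside of spatialdata and PyTorch Lightning: with the "spawn" start method, each DataLoader worker is created by pickling the dataset, and pickle cannot serialize functions or lambdas defined inside another function ("local objects"), which is what the '_get_points.<locals>.<lambda>' in the error refers to. A minimal sketch of the same error, independent of the code in this issue:

import pickle

def build_dataset():
    # A lambda defined inside a function is a "local object"; pickle (and the
    # ForkingPickler used when spawning DataLoader workers) cannot serialize it.
    return {"transform": lambda x: x + 1}

try:
    pickle.dumps(build_dataset())
except (AttributeError, pickle.PicklingError) as err:
    print(err)  # e.g. Can't pickle local object 'build_dataset.<locals>.<lambda>'

Setting num_workers=0 sidesteps worker pickling entirely; whether the default "fork" start method is usable instead depends on the platform and the rest of the training setup.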


xiao233333 avatar Jun 18 '24 03:06 xiao233333

Hi @xiao233333, thank you for your interest in spatialdata. Can you share how you created the train_dataset that you pass to the dataloader?

giovp avatar Jun 19 '24 15:06 giovp

Hi @giovp, thanks for looking into this issue. I have multiple Xenium datasets. I loaded each of them into an sdata object and then used ImageTilesDataset to build a dataset for each sample:

dataset = ImageTilesDataset(
    sdata=sdata,
    regions_to_images={"cell_boundaries": "HE"},
    regions_to_coordinate_systems={"cell_boundaries": "global"},
    table_name="table",
    tile_dim_in_units=tile_size,
    transform=transform,
    rasterize=True,
    rasterize_kwargs={"target_width": target_size},
)

where each item contains an image tile and the corresponding gene expression and label. The transform is:

from typing import Callable, List

import pandas as pd
import torch
import torch.nn.functional as F
from spatialdata import SpatialData


def tile_gene_obs_label_transform_fn(cell_types: pd.Index | List[str],
                                     obs_label: str) -> Callable:
    def tile_gene_obs_label_transform(
            sdata: SpatialData) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        tile = sdata["HE"].data.compute()
        tile = torch.tensor(tile, dtype=torch.float32)
        gene = sdata["table"].to_df().values[0]
        gene = torch.tensor(gene, dtype=torch.float32)
        expected_category = sdata["table"].obs[obs_label].values[0]
        # list() so this also works when cell_types is a pd.Index
        expected_category = list(cell_types).index(expected_category)
        cell_type = F.one_hot(
            torch.tensor(expected_category),
            num_classes=len(cell_types)).type(torch.float32)
        return tile, gene, cell_type
    return tile_gene_obs_label_transform
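
As an aside, closures returned by a factory function, like tile_gene_obs_label_transform above, are themselves "local objects" and would hit the same pickling error under "spawn". A sketch of one picklable alternative (not part of spatialdata's API, and only relevant when the unpicklable object is in user code rather than inside spatialdata itself) is a module-level callable class:

from typing import Sequence

import torch
import torch.nn.functional as F


class TileGeneObsLabelTransform:
    # Picklable stand-in for the closure above: the class is defined at module
    # level and its state lives on the instance, so "spawn" workers can pickle it.
    def __init__(self, cell_types: Sequence[str], obs_label: str) -> None:
        self.cell_types = list(cell_types)
        self.obs_label = obs_label

    def __call__(self, sdata) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        tile = torch.tensor(sdata["HE"].data.compute(), dtype=torch.float32)
        gene = torch.tensor(sdata["table"].to_df().values[0], dtype=torch.float32)
        category = self.cell_types.index(sdata["table"].obs[self.obs_label].values[0])
        cell_type = F.one_hot(
            torch.tensor(category),
            num_classes=len(self.cell_types)).type(torch.float32)
        return tile, gene, cell_type

An instance, e.g. transform=TileGeneObsLabelTransform(cell_types, obs_label), could then be passed to ImageTilesDataset in place of the closure.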

Once I had all the datasets, I used a LightningDataModule to split the data by sample and generate the dataloaders:

import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader


class TilesDataModule(LightningDataModule):
    def __init__(
            self,
            dataset_dict: dict,
            batch_size: int,
            num_workers: int,
            dataset: torch.utils.data.Dataset):
        super().__init__()
        self.dataset_dict = dataset_dict
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.samples = list(self.dataset_dict.keys())
        self.dataset = torch.utils.data.ConcatDataset(
            [self.dataset_dict[i] for i in self.samples])

    def setup(self, stage=None):
        # Split the sample names 60 / 20 / 20 into train / validation / test.
        train, valid, test = np.split(
            self.samples,
            [int(.6 * len(self.samples)),
             int(.8 * len(self.samples))])
        self.train = torch.utils.data.ConcatDataset(
            [self.dataset_dict[i] for i in train])
        self.val = torch.utils.data.ConcatDataset(
            [self.dataset_dict[i] for i in valid])
        self.test = torch.utils.data.ConcatDataset(
            [self.dataset_dict[i] for i in test])

    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
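
As a quick, illustrative check of the split in setup() (with made-up sample names): np.split at the 60% and 80% indices yields a 60/20/20 split by sample.

import numpy as np

samples = [f"sample_{i}" for i in range(10)]
train, valid, test = np.split(samples, [int(.6 * len(samples)), int(.8 * len(samples))])
print(len(train), len(valid), len(test))  # 6 2 2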

Please point out if I did anything wrong. Thanks!

xiao233333 avatar Jul 03 '24 05:07 xiao233333