pycave
pycave copied to clipboard
multi-GPU error
Hi,
I'm wondering whether multi-GPU training is supported.
Fitting works when using a single GPU, but with multiple GPUs (e.g. devices=3) I couldn't get either the KMeans or the GMM estimator to fit — the run fails with the ProcessRaisedException shown below. Thanks!
estimator_gmm = GaussianMixture(8, trainer_params=dict(accelerator='gpu', devices=3,
enable_progress_bar = True),
init_means = estimator.model_.centroids)
ProcessRaisedException Traceback (most recent call last)
Cell In[29], line 1
----> 1 estimator_gmm.fit(torch.from_numpy(X))
File conda_envs/squidpy/lib/python3.8/site-packages/pycave/bayes/gmm/estimator.py:175, in GaussianMixture.fit(self, data)
168 else:
169 module = GaussianMixtureRandomInitLightningModule(
170 self.model_,
171 covariance_regularization=self.covariance_regularization,
172 is_batch_training=is_batch_training,
173 use_model_means=self.init_means is not None,
174 )
--> 175 self.trainer(max_epochs=1 + int(is_batch_training)).fit(module, loader)
177 # Fit model
178 logger.info("Fitting Gaussian mixture...")
File conda_envs/squidpy/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:700, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
681 r"""
682 Runs the full optimization routine.
683
(...)
697 datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
698 """
699 self.strategy.model = model
--> 700 self._call_and_handle_interrupt(
701 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
702 )
File conda_envs/squidpy/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:652, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
650 try:
651 if self.strategy.launcher is not None:
--> 652 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
653 else:
654 return trainer_fn(*args, **kwargs)
File conda_envs/squidpy/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py:103, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
100 else:
101 process_args = [trainer, function, args, kwargs, return_queue]
--> 103 mp.start_processes(
104 self._wrapping_function,
105 args=process_args,
106 nprocs=self._strategy.num_processes,
107 start_method=self._start_method,
108 )
109 worker_output = return_queue.get()
110 if trainer is None:
File conda_envs/squidpy/lib/python3.8/site-packages/torch/multiprocessing/spawn.py:198, in start_processes(fn, args, nprocs, join, daemon, start_method)
195 return context
197 # Loop on join until it returns True or raises an exception.
--> 198 while not context.join():
199 pass
File conda_envs/squidpy/lib/python3.8/site-packages/torch/multiprocessing/spawn.py:160, in ProcessContext.join(self, timeout)
158 msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
159 msg += original_trace
--> 160 raise ProcessRaisedException(msg, error_index, failed_process.pid)
ProcessRaisedException:
-- Process 2 terminated with the following error:
Traceback (most recent call last):
File "conda_envs/squidpy/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "conda_envs/squidpy/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 129, in _wrapping_function
results = function(*args, **kwargs)
File "conda_envs/squidpy/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "squidpy/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1101, in _run
self.strategy.setup_environment()
File "squidpy/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 130, in setup_environment
self.accelerator.setup_environment(self.root_device)
File "squidpy/lib/python3.8/site-packages/pytorch_lightning/accelerators/cuda.py", line 43, in setup_environment
torch.cuda.set_device(root_device)
File "squidpy/lib/python3.8/site-packages/torch/cuda/__init__.py", line 314, in set_device
torch._C._cuda_setDevice(device)
File "squidpy/lib/python3.8/site-packages/torch/cuda/__init__.py", line 207, in _lazy_init
raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method