Lightning 2.x support
Hi @cweniger et al.,
I notice that `swyft` is currently based on `pytorch-lightning==1.9.5`. Do you have any plans to upgrade this legacy dependency to a newer version, e.g., `lightning==2.4.x`?
A bit of background: in a collaboration with DAMTP (Cambridge), I am currently porting the `swyft` library to Intel GPUs, because their supercomputer "Dawn" is powered by Intel GPUs. My colleagues and I have made a version of `lightning` that supports Intel GPUs, but it is based on Lightning 2.x. I therefore made some changes to the `swyft` code to bump `lightning` to 2.4, but it seems that `swyft` relies on an API that is no longer available in 2.x. The following error occurs when I call `trainer.infer(network, obs, prior_samples)`:
Traceback (most recent call last):
File "/nfs/site/home/xucai/Works/swyft/tests/truncation.py", line 78, in <module>
predictions, bounds, samples = round(obs, bounds = bounds)
File "/nfs/site/home/xucai/Works/swyft/tests/truncation.py", line 68, in round
predictions = trainer.infer(network, obs, prior_samples)
File "/nfs/site/home/xucai/Works/swyft/swyft/lightning/core.py", line 318, in infer
ratio_batches = self.predict(model, dl)
File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 858, in predict
return call._call_and_handle_interrupt(
File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 897, in _predict_impl
results = self._run(model, ckpt_path=ckpt_path)
File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
results = self._run_stage()
File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1020, in _run_stage
return self.predict_loop.run()
File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
return loop_run(self, *args, **kwargs)
File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/prediction_loop.py", line 107, in run
self.reset()
File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/prediction_loop.py", line 176, in reset
raise ValueError('`trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.')
ValueError: `trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.
How can we get rid of this error? Do we need to stop using `CombinedLoader` in a non-sequential mode?
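One possible direction, sketched under assumptions: if `infer` currently combines an observation loader and a prior-sample loader with a non-sequential `CombinedLoader` mode (e.g. `"max_size_cycle"`) so that every batch contains both, the pairing could instead happen inside a single dataset, so that `trainer.predict()` receives one ordinary `DataLoader`. The class below is hypothetical (not part of `swyft` or Lightning). It assumes `prior_samples` can be indexed one sample at a time, and the keys `"A"`/`"B"` are placeholders.

```python
from typing import Any, Dict, Sequence


class PairedObservationDataset:
    """Hypothetical map-style dataset pairing one observation with each prior sample.

    It implements __len__/__getitem__, the protocol that
    torch.utils.data.DataLoader accepts for map-style datasets, so a single
    loader yields both halves of every pair and no CombinedLoader is needed.
    """

    def __init__(self, obs: Dict[str, Any], prior_samples: Sequence[Dict[str, Any]]):
        self.obs = obs                      # the single fixed observation
        self.prior_samples = prior_samples  # assumed indexable per sample

    def __len__(self) -> int:
        # One item per prior sample; the observation is reused for each.
        return len(self.prior_samples)

    def __getitem__(self, i: int) -> Dict[str, Dict[str, Any]]:
        # "A" = observation, "B" = i-th prior sample (placeholder key names).
        return {"A": self.obs, "B": self.prior_samples[i]}
```

Wrapped as `torch.utils.data.DataLoader(PairedObservationDataset(obs, prior_samples), batch_size=...)`, the default collate function would repeat the observation along the batch dimension, so the network's prediction step might need adapting if it currently expects the observation with batch size 1.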
If it helps, I can submit a PR with the changes that bump the `lightning` version to 2.4.
Thanks, Maxwell