fastMRI
Training abruptly crashes on single GPU
While training a VarNet with PyTorch Lightning on the knee dataset, using the FastMriDataModule data loaders, I observed that training is unstable and crashes fairly often. I tried looking for similar issues within this repo but couldn't find any. On the PyTorch side I found a matching report (https://github.com/pytorch/pytorch/issues/8976) describing crashes when the data loader interacts badly with multiprocessing; the recommendation there is to set num_workers=0, which did stabilize my training for some time, but after a while it crashed as well.
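A minimal sketch of the num_workers=0 workaround, assuming the standard FastMriDataModule constructor (paths and transforms below are placeholders, not my exact values):

```python
# Sketch: force single-process data loading, per pytorch/pytorch#8976.
# Paths and transforms are placeholders, not the exact values from this run.
from pathlib import Path

from fastmri.data.transforms import VarNetDataTransform
from fastmri.pl_modules import FastMriDataModule

data_module = FastMriDataModule(
    data_path=Path("/path/to/knee"),  # placeholder dataset root
    challenge="singlecoil",
    train_transform=VarNetDataTransform(),
    val_transform=VarNetDataTransform(),
    test_transform=VarNetDataTransform(),
    batch_size=8,
    num_workers=0,  # workaround: no DataLoader worker subprocesses
)
```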
- Training on a single GPU with:
  - backend = "gpu"
  - num_gpus = 1
  - batch_size = 8
- Using the FastMriDataModule on the single-coil knee dataset. Reproduced on both a V100 and an RTX8000.
- lightning 1.8.6
- torch 2.0.1
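In Lightning 1.8 terms, that setup corresponds roughly to the following (VarNetModule and the hyperparameters are stand-ins for what I actually trained):

```python
# Sketch of the trainer setup under Lightning 1.8; hyperparameters are placeholders.
import pytorch_lightning as pl

from fastmri.pl_modules import VarNetModule

model = VarNetModule()  # stand-in for the actual model configuration

trainer = pl.Trainer(
    accelerator="gpu",  # backend = "gpu"
    devices=1,          # num_gpus = 1
    max_epochs=50,      # placeholder
)
trainer.fit(model, datamodule=data_module)  # data_module from the sketch above
```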
- The entire traceback is as follows:
File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data data = self._data_queue.get(timeout=timeout) File "/ext3/miniconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get return _ForkingPickler.loads(res) File "/ext3/miniconda3/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 307, in rebuild_storage_fd fd = df.detach() File "/ext3/miniconda3/lib/python3.10/multiprocessing/resource_sharer.py", line 57, in detach with _resource_sharer.get_connection(self._id) as conn: File "/ext3/miniconda3/lib/python3.10/multiprocessing/resource_sharer.py", line 86, in get_connection c = Client(address, authkey=process.current_process().authkey) File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 508, in Client answer_challenge(c, authkey) File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 752, in answer_challenge message = connection.recv_bytes(256) # reject large message File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 216, in recv_bytes buf = self._recv_bytes(maxlength) File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 414, in _recv_bytes buf = self._recv(4) File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 379, in _recv chunk = read(handle, remaining) ConnectionResetError: [Errno 104] Connection reset by peer
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/scratch/ps4364/fmri2020/varnet_l1_2/unet_knee_sc.py", line 192, in
Hello @pranavsinghps1, this is a confusing error. I don't see a single line in the trace that mentions fastMRI. Are you sure there isn't an issue with your install?
Also, we don't actually test VarNet with the single-coil data; it's really meant for multicoil with a batch size of 1. Is there a reference showing that VarNet works for single-coil that you're trying to reproduce?
I see. Thank you for your prompt response. I will try to realign with the requirements mentioned here: https://github.com/facebookresearch/fastMRI/blob/main/setup.cfg
As for the use of VarNet for single-coil reconstruction: I did see that [1] uses VarNet exclusively for multicoil reconstruction, while the U-Net is used for both. Is there a rationale for this? I was trying to figure that out. For my VarNet, I removed the sensitivity network and am just using a vanilla VarNet with a ResNet-18 backbone.
[1] Sriram, Anuroop, et al. "End-to-End Variational Networks for Accelerated MRI Reconstruction." Medical Image Computing and Computer Assisted Intervention (MICCAI 2020), Lima, Peru, October 4–8, 2020, Proceedings, Part II. Springer International Publishing, 2020.
Hello @pranavsinghps1, the main innovation of that paper is the end-to-end aspect, where the model estimates both the sensitivity maps and the final image. In non-end-to-end VarNets, the sensitivity maps are precomputed via another method (such as ESPIRiT), so the pipeline is not trainable end to end.
However, in the single-coil case there are no sensitivities, so you just have a regular VarNet.
We never prioritized the development of a single-coil VarNet because in the real world all MRI scanners are multicoil. Multicoil has enormous benefits over single-coil in terms of SNR and image quality. The single-coil data is only a sort of toy setting for people initially getting into the area, but only work done on the multicoil data is likely to have any impact on real-world scanners.
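To make "a regular VarNet" concrete: one cascade in the single-coil case reduces to a soft data-consistency step plus a learned image-domain refinement, with no sensitivity maps anywhere. A minimal sketch of that update rule (an illustration, not fastMRI's implementation; the refiner network is a placeholder):

```python
# Sketch of one single-coil VarNet cascade: soft data consistency + learned refiner.
# Illustrative only; fastMRI's multicoil VarNet additionally estimates sensitivity maps.
import torch
import torch.nn as nn

import fastmri  # provides fft2c/ifft2c for (..., H, W, 2) real-valued complex tensors


class SingleCoilCascade(nn.Module):
    def __init__(self, refiner: nn.Module):
        super().__init__()
        # Placeholder regularizer, e.g. a small U-Net mapping (B, H, W, 2) -> (B, H, W, 2).
        self.refiner = refiner
        self.dc_weight = nn.Parameter(torch.ones(1))  # learned data-consistency step size

    def forward(self, kspace, ref_kspace, mask):
        # Pull sampled k-space locations back toward the measured data.
        zero = torch.zeros_like(kspace)
        dc = torch.where(mask.bool(), kspace - ref_kspace, zero)
        # Regularize in the image domain, then map back to k-space.
        refinement = fastmri.fft2c(self.refiner(fastmri.ifft2c(kspace)))
        return kspace - self.dc_weight * dc - refinement
```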
Thank you @mmuckley for the detailed information. I had one question: why is multicoil trained with a batch size of 1?
Update on the issue: rewriting the data loaders to use SliceDataset directly solved the crashes.
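For anyone hitting the same crash, the fix looked roughly like this (a sketch; the paths and the transform are placeholders, not my exact values):

```python
# Sketch: build a plain DataLoader over SliceDataset instead of FastMriDataModule.
# Paths and the transform are placeholders, not the exact values from this run.
from pathlib import Path

from torch.utils.data import DataLoader

from fastmri.data import SliceDataset
from fastmri.data.transforms import VarNetDataTransform

train_dataset = SliceDataset(
    root=Path("/path/to/knee/singlecoil_train"),  # placeholder
    challenge="singlecoil",
    transform=VarNetDataTransform(),
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
```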
Hello @pranavsinghps1, the main reason is that many of the multicoil volumes have different matrix sizes. With the VarNet we need to enforce data consistency on the raw k-space, so there is no way to do simple batching. In the end we made the VarNet large enough that it used all of one GPU's memory, so we found that a batch size of 1, with a large model, was the most effective training strategy.
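To see why simple batching fails here: the default collate function has to stack the k-space tensors, and tensors with different matrix sizes cannot be stacked (the shapes below are made up for illustration):

```python
# Illustration: two volumes with different matrix sizes cannot form one batch.
import torch

kspace_a = torch.randn(15, 640, 368, 2)  # hypothetical multicoil slice
kspace_b = torch.randn(15, 640, 322, 2)  # another volume with a different width

try:
    torch.stack([kspace_a, kspace_b])  # what default collation would attempt
except RuntimeError as err:
    print(err)  # "stack expects each tensor to be equal size ..."
```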
As for the issue itself, could you post more details of your solution? If there is no problem with the core repository, please close the issue.