When using TorchRec version 0.8.0 or later, we cannot train for more than one epoch when `persistent_workers=True` is set in the DataLoader.
In the `_next_batch` method of `TrainPipelineSparseDist`, we check whether the incoming `dataloader_iter` is the same object as the stored `self._dataloader_iter`, and we only reset the `_dataloader_exhausted` flag when they differ. However, when `persistent_workers=True` is set in the DataLoader, `iter(dataloader)` returns the same iterator instance for every epoch. As a result, the exhausted flag is never cleared and we cannot fetch any data once training goes beyond the first epoch.
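For context, this follows from standard PyTorch `DataLoader` behavior: with `persistent_workers=True` (which requires `num_workers > 0`), `iter(dataloader)` returns the same cached iterator object on every call instead of a new one. A minimal sketch outside of TorchRec (the dataset and sizes are just placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    # Placeholder dataset; any map-style dataset shows the same behavior.
    dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
    dl = DataLoader(dataset, batch_size=2, num_workers=2, persistent_workers=True)

    it_epoch_1 = iter(dl)
    _ = list(it_epoch_1)   # drain epoch 1
    it_epoch_2 = iter(dl)  # the DataLoader reuses its cached worker iterator

    # Prints True: both epochs see the same iterator object, so the identity
    # check in _next_batch never resets _dataloader_exhausted.
    print(it_epoch_1 is it_epoch_2)
```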
https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline/train_pipelines.py#L578
```python
def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
    if dataloader_iter is not self._dataloader_iter:
        self._dataloader_iter = dataloader_iter
        self._dataloader_exhausted = False
    if self._dataloader_exhausted:
        batch = None
    else:
        with record_function("## next_batch ##"):
            batch = next(dataloader_iter, None)
        if batch is None:
            self._dataloader_exhausted = True
    return batch
```
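A possible user-side workaround until this is fixed (just a sketch, not an official TorchRec API) is to give the pipeline a fresh generator object each epoch. The generator delegates to the persistent `DataLoader`, but because a new generator object is created every epoch, the identity check above sees a different iterator and resets `_dataloader_exhausted`:

```python
from torch.utils.data import DataLoader


def epoch_iterator(dataloader: DataLoader):
    # A brand-new generator object is created on every call, while the
    # underlying DataLoader still reuses its persistent workers internally.
    yield from dataloader


# Hypothetical training loop; `pipeline` is a TrainPipelineSparseDist instance
# and `train_dataloader` uses persistent_workers=True.
# for epoch in range(num_epochs):
#     dataloader_iter = epoch_iterator(train_dataloader)
#     while True:
#         try:
#             pipeline.progress(dataloader_iter)
#         except StopIteration:
#             break
```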
Hi @henrylhtsang @IvanKobzarev @joshuadeng @PaulZhang12, could you take a look at this problem?
Try setting `num_workers=0` and see if that solves your problem. In my case, it works.
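For reference, a minimal sketch of that configuration; note that PyTorch rejects `persistent_workers=True` when `num_workers=0`, so the flag has to be dropped together with the workers:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset; use your real dataset here.
dataset = TensorDataset(torch.arange(8, dtype=torch.float32))

# With num_workers=0, data loading runs in the main process and a fresh
# iterator object is built for every epoch, so the identity check in
# _next_batch resets _dataloader_exhausted as expected.
dl = DataLoader(dataset, batch_size=2, num_workers=0)
```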