gritlm
CustomRandomSampler not working in huggingface Trainer and Accelerator
Issue
When I test the custom sampler, it seems that the Hugging Face Trainer and Accelerator replace the sampler with a new object. Please refer to the `get_train_dataloader` function in `Trainer` and the `prepare_data_loader` function in Accelerate.
When I print the sampler class of the dataloader before and after `self.accelerator.prepare()`, I get the following output:

```
<finetune.data.InTaskRandomSampler object at 0x7ff1a4b7c310>
<torch.utils.data.sampler.SequentialSampler object at 0x7ff1a4131c00>
```
The same issue is discussed at https://discuss.huggingface.co/t/accelerator-prepare-replaces-custom-dataloader-sampler/43392.
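For context, a sampler like the one being replaced only needs to implement `__iter__` and `__len__` (plus, typically, `set_epoch` for deterministic reshuffling). Below is a minimal, hypothetical sketch of such a sampler; the real `InTaskRandomSampler` lives in `finetune.data` and its internals are not shown here, so names and behavior beyond the interface are assumptions. It uses only the stdlib so the interface is clear even without torch installed:

```python
import random

class CustomRandomSampler:
    """Hypothetical sketch of a custom random sampler.

    A torch Sampler only needs __iter__ and __len__; set_epoch is the
    convention DistributedSampler uses so each epoch reshuffles
    deterministically with a shared seed.
    """

    def __init__(self, num_samples: int, seed: int = 0):
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Called by the training loop once per epoch.
        self.epoch = epoch

    def __iter__(self):
        indices = list(range(self.num_samples))
        # Seeding with seed + epoch gives a different but reproducible
        # permutation every epoch.
        random.Random(self.seed + self.epoch).shuffle(indices)
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples
```

It is exactly this object on `dataloader.sampler` that `accelerator.prepare()` discards when it rebuilds the dataloader.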
Solution
A possible solution is to write a custom `torch.utils.data.distributed.DistributedSampler` and avoid passing the dataloader through `self.accelerator.prepare` in the trainer. This also requires overriding the `get_train_dataloader` function in the trainer.
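The core logic such a custom distributed sampler has to reproduce is: shuffle with a seed shared by all ranks, then hand each rank a disjoint shard that together covers the dataset. A minimal sketch of that sharding step (a hypothetical helper, not torch's implementation; the padding mirrors what `DistributedSampler` does with `drop_last=False`):

```python
import random

def shard_indices(num_samples: int, rank: int, world_size: int,
                  seed: int = 0, epoch: int = 0) -> list:
    """Return the shuffled index shard for one rank.

    Hypothetical sketch of DistributedSampler-style sharding: every rank
    shuffles with the same seed (so all ranks agree on the order), then
    takes every world_size-th index starting at its rank.
    """
    indices = list(range(num_samples))
    # Same seed on every rank -> identical permutation everywhere.
    random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating leading indices so the total divides evenly
    # across ranks, as torch does when drop_last=False.
    padded = indices + indices[: (-len(indices)) % world_size]
    return padded[rank::world_size]
```

Inside an overridden `get_train_dataloader`, the dataloader would be built with this sampler directly and returned without calling `self.accelerator.prepare` on it, so the sampler survives.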
Sure, feel free to open a PR.