gritlm

CustomRandomSampler not working in huggingface Trainer and Accelerator

Open · YanshekWoo opened this issue 1 year ago · 1 comment

Issue

When I test the custom sampler, it seems that the Hugging Face Trainer and Accelerator replace the sampler with a new object. See the get_train_dataloader function in Trainer and the prepare_data_loader function in Accelerate.

When I print the class of the dataloader's sampler before and after self.accelerator.prepare(), I get the following output:

<finetune.data.InTaskRandomSampler object at 0x7ff1a4b7c310>
<torch.utils.data.sampler.SequentialSampler object at 0x7ff1a4131c00>
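Before concluding that the sampler was dropped, it may help to check where it ended up. PyTorch's DataLoader creates a default SequentialSampler for its sampler attribute whenever it is constructed with batch_sampler=..., so a prepared loader can report SequentialSampler even if the original sampler survives inside a wrapped batch sampler. The following is a minimal duck-typed sketch, assuming the wrappers expose the original objects through .sampler or .batch_sampler attributes; find_samplers is a hypothetical helper, not part of Transformers or Accelerate, and the stand-in classes in the demo are illustrative only.

```python
from types import SimpleNamespace
from typing import Any, List


def find_samplers(loader: Any, max_depth: int = 6) -> List[Any]:
    """Collect every object reachable from `loader` via `.sampler` / `.batch_sampler`.

    Duck-typed on purpose: torch's DataLoader and BatchSampler expose these
    attributes. Assumption: wrapper batch samplers keep the wrapped object under
    `.batch_sampler`; check this against your installed library versions.
    """
    found: List[Any] = []
    seen = set()
    stack = [(loader, 0)]
    while stack:
        obj, depth = stack.pop()
        # Skip missing attributes, already-visited objects (guards against cycles)
        # and anything nested deeper than max_depth.
        if obj is None or id(obj) in seen or depth > max_depth:
            continue
        seen.add(id(obj))
        if depth > 0:  # do not report the loader itself
            found.append(obj)
        for attr in ("batch_sampler", "sampler"):
            stack.append((getattr(obj, attr, None), depth + 1))
    return found


if __name__ == "__main__":
    # Hypothetical stand-ins for a prepared loader: the top-level `.sampler`
    # is a fresh default sampler, while the custom one sits deeper in the chain.
    class InTaskRandomSampler: ...
    class SequentialSampler: ...

    custom = InTaskRandomSampler()
    prepared = SimpleNamespace(
        sampler=SequentialSampler(),
        batch_sampler=SimpleNamespace(batch_sampler=SimpleNamespace(sampler=custom)),
    )
    print(any(s is custom for s in find_samplers(prepared)))
```

Running find_samplers on the real loader before and after self.accelerator.prepare() and comparing type(s).__name__ for each result would show whether the custom sampler is still in the chain.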

The same issue is reported in https://discuss.huggingface.co/t/accelerator-prepare-replaces-custom-dataloader-sampler/43392.

Solution

A possible solution is to write a custom sampler modeled on torch.utils.data.distributed.DistributedSampler and to avoid calling self.accelerator.prepare() on the dataloader in the trainer. This also requires overriding the trainer's get_train_dataloader function.
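A minimal sketch of such a sampler, assuming each training example carries a task id (task_ids) and that the goal is for every step to draw all devices' batches from the same task. DistributedInTaskSampler and its parameters are hypothetical, not GritLM's InTaskRandomSampler; the sketch borrows the DistributedSampler conventions (num_replicas, rank, seed, set_epoch) but is written in plain Python so the sharding logic is easy to test.

```python
import random
from collections import defaultdict
from typing import Hashable, Iterator, List, Sequence


class DistributedInTaskSampler:
    """Hypothetical sketch: each global batch comes from one task and is split across ranks."""

    def __init__(self, task_ids: Sequence[Hashable], batch_size: int,
                 num_replicas: int = 1, rank: int = 0, seed: int = 0) -> None:
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, num_replicas)")
        self.task_ids = list(task_ids)
        self.batch_size = batch_size  # per-device batch size
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Same role as DistributedSampler.set_epoch: a different shuffle per epoch.
        self.epoch = epoch

    def _global_batches(self) -> List[List[int]]:
        # Every rank builds the identical list because the RNG seed is shared.
        rng = random.Random(self.seed + self.epoch)
        by_task = defaultdict(list)
        for idx, task in enumerate(self.task_ids):
            by_task[task].append(idx)
        global_bs = self.batch_size * self.num_replicas
        batches = []
        for indices in by_task.values():
            rng.shuffle(indices)
            # Drop the incomplete tail so no global batch mixes tasks.
            for start in range(0, len(indices) - global_bs + 1, global_bs):
                batches.append(indices[start:start + global_bs])
        rng.shuffle(batches)  # interleave tasks across steps
        return batches

    def __iter__(self) -> Iterator[int]:
        # Each rank takes its own contiguous slice of every global batch.
        lo = self.rank * self.batch_size
        for batch in self._global_batches():
            yield from batch[lo:lo + self.batch_size]

    def __len__(self) -> int:
        global_bs = self.batch_size * self.num_replicas
        counts = defaultdict(int)
        for task in self.task_ids:
            counts[task] += 1
        return sum(c // global_bs for c in counts.values()) * self.batch_size
```

With PyTorch, each process would build its loader as DataLoader(train_dataset, batch_size=per_device_batch_size, sampler=sampler, drop_last=True) and call sampler.set_epoch(epoch) at the start of every epoch. DataLoader accepts any iterable with __len__ as its sampler, so no torch base class is needed here. Returning that loader from an overridden get_train_dataloader without passing it through self.accelerator.prepare() matches the approach proposed above, but it has not been tested against the Trainer here.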

YanshekWoo · Apr 25 '24 04:04

Sure, feel free to open a PR.

Muennighoff · Apr 25 '24 04:04