wespeaker
Is it possible to implement dataloader for triplet loss/GE2E
Hi, is it possible to make a batch containing M speaker and N utterances for each speaker?
I don't think this can be supported in UIO mode, but it may be implementable in dataset_deprecated.py.
Thanks for your response, but what is UIO mode? Also, couldn't it be implemented in the DistributedSampler class in dataset.py?
Check our paper for an introduction to UIO data management; we designed this mode for training on large datasets. BTW, DistributedSampler is designed to shuffle data.list and distribute it across different GPUs, so it might not satisfy your demands. I think a proper way is to design your own collate_fn for the dataloader. You can refer to our implementation in the DINO SSL training code in collate_fn; of course, some modification is needed.
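Since the DINO collate_fn isn't reproduced here, a minimal sketch of the M-speakers-by-N-utterances idea follows. It assumes (hypothetically) that each dataset item is a dict holding all N utterances of one speaker; the field names `spk` and `feats` are illustrative, not wespeaker's actual schema:

```python
import torch


def ge2e_collate_fn(batch):
    """Collate a list of per-speaker items into a GE2E-style batch.

    Each item is assumed to be {'spk': str, 'feats': Tensor of shape
    (N, T, F)} with N utterances from one speaker. Returns features of
    shape (M * N, T, F) plus an integer speaker label per utterance.
    """
    feats = torch.cat([item['feats'] for item in batch], dim=0)
    labels = []
    for idx, item in enumerate(batch):
        labels.extend([idx] * item['feats'].shape[0])
    return feats, torch.tensor(labels)
```

With M items of N utterances each, the loss can then reshape the output embeddings back to (M, N, D) for the GE2E similarity matrix.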
I understand DistributedSampler is designed to distribute data across different GPUs, but can we distribute according to the speaker id? For example, the code could be (removing shuffle in processor.py):
```python
import json
import random

import torch
import torch.distributed as dist


class DistributedSampler:

    def __init__(self, lists, num_utts, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition
        self.num_utts = num_utts
        # Group line indices of data.list by speaker id
        self.spk = {}
        for i in range(len(lists)):
            obj = json.loads(lists[i])
            if obj['spk'] not in self.spk:
                self.spk[obj['spk']] = []
            self.spk[obj['spk']].append(i)

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, lists):
        # Partition speakers (rather than utterances) across ranks/workers
        spk = list(self.spk.keys())
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(spk)
            spk = spk[self.rank::self.world_size]
        spk = spk[self.worker_id::self.num_workers]
        data = []
        for s in spk:
            # choices() samples with replacement, so each speaker always
            # contributes exactly num_utts indices
            data = data + random.choices(self.spk[s], k=self.num_utts)
        return data
```
I just noticed that I set the data type to 'raw'; the above code is not appropriate for 'shard'.
Yeah... it makes sense in 'raw' mode. Hope it works for you! Good luck!
In raw mode it's much easier to implement your function (though slow). In shard mode it's also possible, but it takes some effort to implement.
Possible approaches:
- Rewrite the write_shard function to make sure each shard contains multiple occurrences of the same speaker.
- As suggested in other comments, the easiest way is to implement it in the collate_fn, but the number of samples in one batch is limited. One workaround: if you look at the shuffle function in the processor, you can write a similar processor that does the grouping in a buffer, which can be set much larger than the batch size; the collate_fn then handles the real batch organization.
But overall, you need to balance randomness against data-processing difficulty.
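A rough sketch of such a buffer-based processor, loosely modeled on the generator style of wespeaker's shuffle processor (the `spk` key, the group-dropping policy, and the helper name are assumptions for illustration, not the project's API):

```python
import random
from collections import defaultdict


def group_by_speaker(data, buffer_size=1000, num_utts=4):
    """Buffer samples, then yield them regrouped so that consecutive
    samples share a speaker, in contiguous runs of num_utts.

    `data` is an iterable of dicts assumed to carry a 'spk' key.
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= buffer_size:
            yield from _emit_groups(buf, num_utts)
            buf = []
    yield from _emit_groups(buf, num_utts)


def _emit_groups(buf, num_utts):
    # Bucket the buffered samples by speaker id
    by_spk = defaultdict(list)
    for sample in buf:
        by_spk[sample['spk']].append(sample)
    spks = list(by_spk.keys())
    random.shuffle(spks)
    for spk in spks:
        utts = by_spk[spk]
        random.shuffle(utts)
        # Emit only full groups; leftover utterances are dropped
        for i in range(0, len(utts) - num_utts + 1, num_utts):
            yield from utts[i:i + num_utts]
```

A downstream collate_fn then only needs to slice the stream into batches of M * num_utts samples; the randomness is bounded by the buffer size, which matches the trade-off noted above.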