wekws
wekws copied to clipboard
内存占用溢出问题
尊敬的开发者您好,首先感谢您们这项有价值的工作,但是在实际使用过程中我遇到了一些问题:训练过程中内存的占用在逐渐升高,直到最后训练被迫中止;我目前正在review源码检查问题,但是由于不是很熟悉代码,所以阅读的比较慢;同时在这里想了解一下您目前是否有解决方案,以解决当前的问题呢?
目前已解决,解决方法如下:
问题原因:fsmn.py 文件 class FSMN(nn.Module) 类的 forward 函数中的 torch.cat(in_cache, dim=-1) 会一直复制占用内存,导致内存不断升高。
更改方法:将这一行代码拆开这样写就可以了
# x7 = self.softmax(x6)
x7, _ = x6
# return x7, None
# ===============================
cat_size = sum(tensor.size(-1) for tensor in in_cache)
ret_cache = torch.zeros([in_cache[0].shape[0], in_cache[0].shape[1], in_cache[0].shape[2], cat_size])
for i in range(cat_size):
ret_cache[:, :, :, i] = in_cache[i].detach().squeeze(-1)
# ===============================
return x7, ret_cache
我对 fsmn 不太熟,@duj12 靖哥你看这里的 detach() 会影响 fsmn 的训练吗
如果是单卡训练的情况下,试试去掉prefetch_factor或者减小prefetch_factor值 , 例如这个 train_data_loader = DataLoader(train_dataset, batch_size=None, pin_memory=args.pin_memory, num_workers=args.num_workers) #prefetch_factor=args.prefetch)