IO refactor: improve multi-machine multi-GPU training efficiency + code reuse
Where is the main IO bottleneck right now?
- https://github.com/wenet-e2e/wenet/issues/2095
- GPU utilization is not saturated, with many spikes
- once multimodal support lands, IO also has to handle the other modalities
For point 3 that Xingchen mentioned, we could add multi-task support; here is an example. The following is an answer from ChatGPT:
In multi-task training, batch organization and the loss definition must account for the relationships and trade-offs between tasks. Some common approaches:
Batch organization. When assembling batches, one can:
- draw a fixed number of samples from each task's dataset to form one batch; this suits tasks with similar data volumes or that need balanced treatment
- give each task its own batch size, tuned by task importance or data distribution; more important tasks can get larger batches so model parameters are updated more thoroughly
- cap the batch size within what hardware and memory allow
......
In this example we need a convenient way to initialize dataset1, dataset2, dataset3:
dataset = interleave(dataset1, dataset2, dataset3)
We need this kind of functionality, and not only for multi-task training: for speech tasks we also want to mix data from different domains within one batch for tuning.
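As a sketch of what interleave could look like over plain Python iterables (illustrative, not the HF implementation):

def interleave(*datasets):
    # Round-robin over several iterable datasets until all are exhausted.
    iterators = [iter(d) for d in datasets]
    while iterators:
        alive = []
        for it in iterators:
            try:
                yield next(it)
                alive.append(it)
            except StopIteration:
                pass  # drop exhausted datasets
        iterators = alive

# e.g. mix three task datasets into one stream:
# dataset = interleave(dataset1, dataset2, dataset3)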
Also, one addition: the dataset that recognize.py currently depends on must have labels; the refactored dataset should also cover pure inference, e.g.:
dataset(split='train') # speech and label
dataset(split='cv') # speech and label
dataset(split='infer') # only speech
For integration with large models (a sketch follows the list):
textdataset # only text
speechdataset # only speech
...
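A hypothetical sketch of such a split-aware factory (build_dataset and its fields are illustrative, not an existing wenet API):

def build_dataset(samples, split='train', tokenizer=None):
    # Hypothetical factory: 'train'/'cv' yield speech and label,
    # 'infer' yields speech only, so recognize.py needs no labels.
    assert split in ('train', 'cv', 'infer')
    for sample in samples:
        example = {'speech': sample['wav']}
        if split != 'infer':
            example['label'] = tokenizer(sample['txt'])
        yield example

# usage sketch
data = [{'wav': b'...', 'txt': 'hello'}]
for ex in build_dataset(data, split='infer'):
    print(ex.keys())  # dict_keys(['speech'])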
The current bottlenecks:
1. Data is fetched synchronously (even with prefetch).
2. There are parts that could run in parallel: e.g. shards/wavs already downloaded into memory could have fbank computed in parallel. Today this relies entirely on num_workers; too large a num_workers triggers bus errors, and with many workers many tars are open at once, while what we actually need is concurrency below the tar level. Done properly, the effective parallelism could be num_workers * parallel_num.
Option 1: integrate huggingface datasets
1 Pros:
- native primitives:
  - map (parallel_map)
  - filter
  - sort (this one has pitfalls: no buffer size; haven't looked at the implementation yet)
  - shuffle
  - interleave
  - etc
With these native primitives our dataset code would be drastically simplified and multimodal integration would be easy. This should reduce the GPU spikes, though not as much as tf.data (which decouples producer and consumer; HF datasets does not).
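A minimal sketch of how these primitives compose in streaming mode (file names and the 'text' field are illustrative):

from datasets import load_dataset, interleave_datasets

# streaming=True returns an IterableDataset with lazy map/filter/shuffle
ds1 = load_dataset('json', data_files='task1.jsonl', split='train', streaming=True)
ds2 = load_dataset('json', data_files='task2.jsonl', split='train', streaming=True)

mixed = interleave_datasets([ds1, ds2], probabilities=[0.7, 0.3], seed=42)
mixed = mixed.map(lambda e: {**e, 'n_chars': len(e['text'])})
mixed = mixed.filter(lambda e: e['n_chars'] > 0)
mixed = mixed.shuffle(seed=42, buffer_size=1000)  # buffered shuffle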
2 Cons:
- prefetch relies on the PyTorch DataLoader's prefetching
- no tar support; Arrow is supported natively. Supporting tar means writing our own _generate_examples, but that yield path is pure Python; if it is slow, the parallelizable parts above get throttled too (the native Arrow backend is implemented in C++)
Option 2: roll our own
Abstract the existing process functions, fill in the gaps, and add new ones:
- parallel_map (multiprocessing)
- filter
- sort
- interleave
- etc
Cons:
- new features have to be added by ourselves, continuously
Pros:
- simple
- supports all existing IO forms
- other modalities can be supported as well
ref: https://huggingface.co/docs/datasets/stream
One more note: in the LLM era many frameworks have dropped dynamic batching, and HF doesn't support it either (tf.data does). For example, Whisper feeds the encoder a fixed 30 s, and LLMs often pad straight to a preset maximum length (not per batch).
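For illustration, a Whisper-style fixed-length pad/trim is tiny (the helper name is ours; the 30 s follows Whisper's encoder input):

import torch
import torch.nn.functional as F

def pad_or_trim(waveform: torch.Tensor, sample_rate: int = 16000, seconds: int = 30):
    # Pad (or truncate) every utterance to a fixed length: no dynamic batching needed.
    target = sample_rate * seconds
    if waveform.size(-1) >= target:
        return waveform[..., :target]
    return F.pad(waveform, (0, target - waveform.size(-1)))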
Can option 2 quickly support data in Arrow format?
It should be able to: HF itself uses the Arrow Python package.
ref: https://arrow.apache.org/docs/python/
Yes. Then let's go with option 2.
The new IO also needs determinism (reproducibility guaranteed at the data level); shuffle must take a seed.
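A minimal sketch of a deterministic buffered shuffle (standalone; set_epoch mirrors the hook the refactored Dataset would need to expose):

import random

class ShuffleDataset:
    """Buffered shuffle with a fixed seed: seeding the RNG with
    seed + epoch makes the order reproducible across restarts."""

    def __init__(self, dataset, buffer_size=1000, seed=2023):
        self._dataset = dataset
        self.buffer_size = buffer_size
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        rng = random.Random(self.seed + self.epoch)
        buffer = []
        for data in self._dataset:
            buffer.append(data)
            if len(buffer) >= self.buffer_size:
                rng.shuffle(buffer)
                yield from buffer
                buffer = []
        rng.shuffle(buffer)
        yield from buffer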
Benchmark of multiprocessing vs multithreading (reproduction sketch below):
100 utterances, 9 s each, computing fbank "concurrently":
- one utterance: ~0.02 s
- 100 utterances sequentially: ~2 s
- multithreading (nthread=100): also ~2 s (CPU-bound work is serialized by the GIL)
- multiprocessing (nproc=100): ~0.1 s
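A minimal script to reproduce the comparison (synthetic 9 s, 16 kHz audio; exact numbers vary by machine):

import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

import torch
from torchaudio.compliance.kaldi import fbank

def extract(_):
    # ~9 s of 16 kHz audio; fbank is CPU-bound, so the GIL serializes threads
    wav = torch.randn(1, 9 * 16000)
    return fbank(wav, num_mel_bins=80, sample_frequency=16000.0).shape

def bench(executor_cls, n=100, workers=100):
    start = time.time()
    with executor_cls(max_workers=workers) as ex:
        list(ex.map(extract, range(n)))
    return time.time() - start

if __name__ == '__main__':
    print('threads  :', bench(ThreadPoolExecutor))
    print('processes:', bench(ProcessPoolExecutor))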
On code reusability, here is a code snippet:
# Here we could inherit from IterableDataset
class Dataset:

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    @staticmethod
    def from_source(source):
        return Dataset(source)

    def __iter__(self):
        return self

    def __next__(self):
        if self._dataset is None:
            raise StopIteration
        data = next(self._dataset)
        return self.func(data) if self.func is not None else data

    def map(self, func, *args, **kwargs):
        return MapperDataset(self, func, *args, **kwargs)

    def filter(self, func, *args, **kwargs):
        return FilterDataset(self, func, *args, **kwargs)


class MapperDataset(Dataset):

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self

    def __next__(self):
        if self._dataset is None:
            raise StopIteration
        data = next(self._dataset)
        return self.func(data)


class FilterDataset(Dataset):

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self

    def __next__(self):
        if self._dataset is None:
            raise StopIteration
        data = next(self._dataset)
        while not self.func(data):
            data = next(self._dataset)
        return data


source = iter([1, 2, 3, 4])
dataset = Dataset(source, lambda elem: elem)
dataset = dataset.map(lambda elem: {"speech": elem * 2})
dataset = dataset.filter(lambda elem_dict: elem_dict['speech'] > 2)
for d in dataset:
    print(d)
# output:
# {'speech': 4}
# {'speech': 6}
# {'speech': 8}
During evaluation, wenet's training script can only run the whole eval dataset on a single GPU; this also looks optimizable. Could we subclass torch's sampler to get a DDP speedup? For example:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from catalyst.data.sampler import DistributedSamplerWrapper

dataset = ...
shuffle = ...
sampler = ...
# If DDP on
if torch.distributed.is_initialized():
    # If using a custom sampler, make it distributed
    if sampler is not None:
        sampler = DistributedSamplerWrapper(
            sampler,
            shuffle=shuffle,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank())
    # If no custom sampler, just use DistributedSampler
    else:
        sampler = DistributedSampler(
            dataset,
            shuffle=shuffle,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank())
# shuffle shouldn't be specified in DataLoader when using a sampler
shuffle = shuffle if sampler is None else None
dataloader = DataLoader(dataset, sampler=sampler, shuffle=shuffle, ...)
Right, eval currently runs everything on one GPU and doesn't exploit multiple cards. That's a good optimization point.
Looking forward to the new wenet scripts, haha
modified code:
class Dataset:

    def __init__(self, source, f=lambda elem: elem, *args, **kw):
        assert callable(f)
        self._dataset = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self._dataset.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self._dataset is not None
        assert callable(self.f)
        for data in self._dataset:
            yield self.f(data, *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Dataset(self, f, *self.args, **self.kw)

    def map(self, func, *args, **kwargs):
        return MapperDataset(self, func, *args, **kwargs)

    def filter(self, func, *args, **kwargs):
        return FilterDataset(self, func, *args, **kwargs)

    def sort(self, func, *args, **kwargs):
        return SortDataset(self, func, *args, **kwargs)

    def zip(self, *datasets):
        return ZipDataset(self, *datasets)
class MapperDataset(Dataset):

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for data in self._dataset:
            yield self.func(data, *self.args, **self.kwargs)


class FilterDataset(Dataset):

    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for data in self._dataset:
            if self.func(data, *self.args, **self.kwargs):
                yield data


class SortDataset(Dataset):

    def __init__(self, dataset, key=None, reverse=False, buffer_size=None):
        self._dataset = dataset
        self.key = key
        self.reverse = reverse
        self.buffer_size = buffer_size

    def __iter__(self):
        return self._generator()

    def _generator(self):
        buffer = []
        for data in self._dataset:
            buffer.append(data)
            if self.buffer_size is not None and len(buffer) >= self.buffer_size:
                sorted_buffer = sorted(buffer, key=self.key, reverse=self.reverse)
                for sorted_data in sorted_buffer:
                    yield sorted_data
                buffer.clear()
        if buffer:
            sorted_buffer = sorted(buffer, key=self.key, reverse=self.reverse)
            for sorted_data in sorted_buffer:
                yield sorted_data


class ZipDataset(Dataset):

    def __init__(self, *datasets):
        self.datasets = datasets

    def __iter__(self):
        return self._generator()

    def _generator(self):
        iterators = [iter(dataset) for dataset in self.datasets]
        while True:
            try:
                data = [next(iterator) for iterator in iterators]
                yield tuple(data)
            except StopIteration:
                return
class PaddingBatchDataset(Dataset):

    def __init__(self, dataset, batch_size, padding_fn, max_length_fn):
        self.dataset = dataset
        self.batch_size = batch_size
        self.padding_fn = padding_fn
        # max_length_fn(data) -> length of a single element
        self.max_length_fn = max_length_fn

    def __iter__(self):
        return self._generator()

    def _generator(self):
        batch = []
        max_length = 0
        for data in self.dataset:
            batch.append(data)
            max_length = max(max_length, self.max_length_fn(data))
            if len(batch) == self.batch_size:
                padded_batch = self._pad_batch(batch, max_length)
                yield padded_batch
                batch = []
                max_length = 0
        if batch:
            padded_batch = self._pad_batch(batch, max_length)
            yield padded_batch

    def _pad_batch(self, batch, max_length):
        padded_batch = []
        for data in batch:
            padding_length = max_length - self.max_length_fn(data)
            padded_data = self.padding_fn(data, padding_length)
            padded_batch.append(padded_data)
        return padded_batch
# Create a data source
def generator(data):
    for d in data:
        yield d

source = generator([1, 2, 3, 4, 1])

# Create a Dataset instance
speech_dataset = Dataset(source)

# preprocess
speech_dataset = speech_dataset.map(lambda elem: {"speech": elem * 2})
speech_dataset = speech_dataset.filter(lambda elem_dict: elem_dict['speech'] >= 2)
speech_dataset = speech_dataset.sort(lambda elem_dict: elem_dict['speech'], buffer_size=2)

# fbank
speech_dataset = speech_dataset.map(
    lambda elem_dict: {'fbank': elem_dict['speech'] + 1, 'speech': elem_dict['speech']})

llm_dataset = Dataset(generator([10, 20, 30, 40, 50, 60]))
# e.g. tokenize
llm_dataset = llm_dataset.map(lambda elem: {"tokens": elem + 1, "text": elem})

task_dataset = speech_dataset.zip(llm_dataset)
task_dataset = task_dataset.sort(lambda elem: elem[1]['tokens'])

# Iterate and print the results
for data in task_dataset:
    print(data)
# output:
# ({'fbank': 3, 'speech': 2}, {'tokens': 11, 'text': 10})
# ({'fbank': 5, 'speech': 4}, {'tokens': 21, 'text': 20})
# ({'fbank': 7, 'speech': 6}, {'tokens': 31, 'text': 30})
# ({'fbank': 9, 'speech': 8}, {'tokens': 41, 'text': 40})
# ({'fbank': 3, 'speech': 2}, {'tokens': 51, 'text': 50})
Could we just use lhotse directly? NeMo also integrates the lhotse shard mode.
import logging
import tarfile
from pathlib import Path

import torchaudio
from lhotse.serialization import open_best

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


def iterate_tarfile_pairwise(
    tar_file: tarfile.TarFile,
):
    result = []
    for tarinfo in tar_file:
        if len(result) == 2:
            yield tuple(result)
            result = []
        result.append(parse_tarinfo(tarinfo, tar_file))
    if len(result) == 2:
        yield tuple(result)
    if len(result) == 1:
        raise RuntimeError(
            "Uneven number of files in the tarfile "
            "(expected to iterate pairs of text and binary data).")


def parse_tarinfo(
    tarinfo: tarfile.TarInfo,
    tar_file: tarfile.TarFile,
):
    """
    Parse a tarinfo object and return the data it points to as well as the
    internal path.
    """
    path = Path(tarinfo.path)
    suffix = path.suffix.strip(".")
    raw_data = tar_file.extractfile(tarinfo)
    if suffix == "txt":
        txt = raw_data.read().decode("utf-8").strip()
        return (path.name, txt)
    elif suffix in AUDIO_FORMAT_SETS:
        waveform, sample_rate = torchaudio.load(raw_data)
        return (waveform, sample_rate)
    else:
        raise RuntimeError(f"Unsupported file format: {suffix}")


def parse_tar(data):
    for sample in data:
        assert "src" in sample, sample.keys()
        url = sample["src"]
        try:
            with tarfile.open(fileobj=open_best(url, mode="rb"), mode="r|*") as tar:
                for (key, txt), (waveform, sample_rate) in iterate_tarfile_pairwise(tar):
                    yield {
                        "key": key,
                        "wav": waveform,
                        "sample_rate": sample_rate,
                        "txt": txt,
                    }
        except Exception as ex:
            logging.warning(f"Failed to open {url}: {ex}")
As a first step, simplify the tar-parsing logic, so that url_opener and tar_file_and_group can be replaced by parse_tar.
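A sketch of how parse_tar could then slot into the pipeline (assuming a shard list file in wenet's usual format):

def shard_source(shard_list_file):
    # yield one {'src': url} dict per shard, as parse_tar expects
    with open(shard_list_file, 'r') as f:
        for line in f:
            yield {'src': line.strip()}


for sample in parse_tar(shard_source('data.shards.list')):
    print(sample['key'], sample['sample_rate'], sample['txt'])
    break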
lhotse and the current wenet implementation are essentially the same:
- To speed up CPU-bound transforms over multiple items, multithreading won't help; multiprocessing is needed (e.g. once a tar is already in memory, the current concurrency stops at the shard level, not inside the shard; multiprocessing could extract features for several wavs at once). Both have this problem.
- A wrapper is needed to adapt the pipeline to various tasks; a simple abstraction over the existing wenet code is enough.
PyTorch now officially ships a chained data API: https://github.com/pytorch/pytorch/tree/main/torch/utils/data/datapipes (not torchdata), and wenet has already been upgraded to torch 2.x, so I lean toward using the official torch one. Preliminary code below @xingchensong @robin1001
import io
import json
import logging
import tarfile

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterDataPipe
from torch.utils.data import datapipes
from torchaudio.compliance.kaldi import fbank

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
class WenetSourceDataPipe(IterDataPipe):

    def __init__(self, dp, data_type='raw', **fmtparams):
        # ShardingFilter lets each DataLoader worker / DDP rank take
        # a distinct subset of the input files (auto sharding)
        self.dp = datapipes.iter.ShardingFilter(
            datapipes.iter.FileOpener(dp, mode='r'))
        self.data_type = data_type
        self.fmtparams = fmtparams

    def __iter__(self):
        for _, stream in self.dp:
            for line in stream:
                line = line.strip('\n')
                if self.data_type == 'raw':
                    json_obj = json.loads(line)
                    with open(json_obj['wav'], 'rb') as f:
                        json_obj['wav'] = f.read()
                    yield json_obj
                else:
                    yield {'stream': open(line, 'rb')}
class TarFileGroupSourceDataPipe(IterDataPipe):

    def __init__(self, dp) -> None:
        super().__init__()
        self.dp = dp

    def __iter__(self):
        for sample in self.dp:
            try:
                with tarfile.open(fileobj=sample['stream'],
                                  mode="r:*") as stream:
                    prev_prefix = None
                    example = {}
                    valid = True
                    for tarinfo in stream:
                        name = tarinfo.name
                        pos = name.rfind('.')
                        assert pos > 0
                        prefix, postfix = name[:pos], name[pos + 1:]
                        if prev_prefix is not None and prefix != prev_prefix:
                            example['key'] = prev_prefix
                            if valid:
                                yield example
                            example = {}
                            valid = True
                        with stream.extractfile(tarinfo) as file_obj:
                            try:
                                if postfix == 'txt':
                                    example['txt'] = file_obj.read().decode(
                                        'utf8').strip()
                                elif postfix in AUDIO_FORMAT_SETS:
                                    example['wav'] = file_obj.read()
                                else:
                                    example[postfix] = file_obj.read()
                            except Exception as ex:
                                valid = False
                                logging.warning(
                                    'error to parse {}'.format(name))
                        prev_prefix = prefix
                    if prev_prefix is not None:
                        example['key'] = prev_prefix
                        yield example
            except Exception as ex:
                logging.warning(
                    'In tar_file_and_group: {} when processing'.format(ex))
            finally:
                # the `with` block above already closes the tar stream
                if 'process' in sample:
                    sample['process'].communicate()
                sample['stream'].close()
def decode_wav(elem):
    wav = elem['wav']
    key = elem['key']
    txt = elem['txt']
    with io.BytesIO(wav) as file_obj:
        waveform, sr = torchaudio.load(file_obj)
    return {"key": key, "txt": txt, 'waveform': waveform, "sample_rate": sr}


def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    sample_rate = data['sample_rate']
    waveform = data['waveform']
    waveform = waveform * (1 << 15)
    mat = fbank(waveform,
                num_mel_bins=num_mel_bins,
                frame_length=frame_length,
                frame_shift=frame_shift,
                dither=dither,
                energy_floor=0.0,
                sample_frequency=sample_rate)
    data['feat'] = mat
    return data
def padding(data):
    assert isinstance(data, list)
    sample = data
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
                                 dtype=torch.int32)
    sorted_feats = [sample[i]['feat'] for i in order]
    sorted_keys = [sample[i]['key'] for i in order]
    padded_feats = pad_sequence(sorted_feats,
                                batch_first=True,
                                padding_value=0)
    batch = {
        "keys": sorted_keys,
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
    }
    return batch
def get_dataloader(data_type, files):
    # shard by files
    dataset = WenetSourceDataPipe(files, data_type=data_type)
    if data_type == 'shard':
        dataset = TarFileGroupSourceDataPipe(dataset)
    dataset = dataset.map(decode_wav)
    dataset = dataset.map(compute_fbank)
    dataset = dataset.batch(wrapper_class=padding, batch_size=2)
    dataloader = DataLoader(dataset,
                            batch_size=None,
                            num_workers=4,
                            persistent_workers=True)
    return dataloader


if __name__ == '__main__':
    raw_dataloader = get_dataloader('raw',
                                    ['test/resources/dataset/data.list'])
    tar_dataloader = get_dataloader(
        'shard', ['test/resources/dataset/data.shards.list'])
    print("--------" + "wenet raw data type" + '---------\n')
    for raw_batch in raw_dataloader:
        print(raw_batch)
    print("\n--------" + "wenet shard data type" + '---------\n')
    for shard_batch in tar_dataloader:
        print(shard_batch)
Refactor plan from here: 1. a datasetsource (supports auto sharding, shard by line/files); 2. the processor keeps the per-elem logic of the existing for-loops, invoked through map / filter etc.
Advantages:
- E.g. a whisper hybrid tokenizer can build its own task: 1 datasetsource, 2 feats, 3 hybrid tokenizer (only the per-elem function needs to be written, then map it), 4 batch, while reusing the current yaml (a sketch follows the list):
dataset: whisper_dataset # previously ASRDataset
- Same for tts/llm
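For instance, a hypothetical whisper_dataset could be composed from the datapipes pieces above (hybrid_tokenize and tokenizer.encode are placeholders for the real hybrid tokenizer):

from functools import partial

def hybrid_tokenize(sample, tokenizer):
    # Placeholder for a Whisper-style hybrid tokenizer applied per elem.
    sample['label'] = tokenizer.encode(sample['txt'])
    return sample

def whisper_dataset(files, tokenizer, data_type='shard', batch_size=2):
    dataset = WenetSourceDataPipe(files, data_type=data_type)
    if data_type == 'shard':
        dataset = TarFileGroupSourceDataPipe(dataset)
    dataset = dataset.map(decode_wav)
    dataset = dataset.map(compute_fbank)
    # only this per-elem function is task specific; everything else is reused
    dataset = dataset.map(partial(hybrid_tokenize, tokenizer=tokenizer))
    return dataset.batch(wrapper_class=padding, batch_size=batch_size)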
The file handle stream in WenetSourceDataPipe can be wrapped with torch's StreamWrapper, avoiding the manual stream.close() calls elsewhere:
from torch.utils.data.datapipes.utils.common import StreamWrapper
stream = StreamWrapper(open(line, 'rb'))