wenet

IO refactoring: improve multi-node, multi-GPU training efficiency + code reuse

Open robin1001 opened this issue 1 year ago • 20 comments

robin1001 · Nov 18 '23 02:11

Where is the main IO bottleneck right now?

robin1001 · Nov 18 '23 02:11

  1. https://github.com/wenet-e2e/wenet/issues/2095
  2. GPU utilization is not saturated, with many spikes
  3. Once multimodality is introduced, IO will also need to handle other modalities

xingchensong · Nov 18 '23 04:11

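Regarding Xingchen's point 3, we could add multi-task support. Here is an example; the following is an answer from ChatGPT:

In multi-task training, batch organization and the loss definition must take into account the relationships and trade-offs between tasks. Some common approaches:

Batch organization: when organizing batches, the following strategies can be used:

Sample a certain number of examples from each task's dataset at the same time to form one batch. This suits tasks with similar data volumes, or tasks that need to be balanced.
Set a different batch size for each task, adjusted to task importance or data distribution. More important tasks can get larger batch sizes so that they update the model parameters more fully.
Keep the batch size within an acceptable range given hardware resources and memory limits.
......

In this example, we need an easy way to initialize dataset1, dataset2 and dataset3:

dataset = interleave(dataset1, dataset2, dataset3)

We need functionality like this, and not only for multi-task training: for speech tasks we also need to combine data from different domains within a single batch for fine-tuning.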
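The `interleave` call above is pseudocode. As a minimal pure-Python sketch of what such a helper might look like (the `interleave` name and its `weights`/`seed` parameters are hypothetical, not an existing wenet or PyTorch API), it can either alternate round-robin between sources or draw from them with fixed probabilities, which is one way to mix domains in the same batch stream:

```python
import random
from typing import Iterable, Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def interleave(sources: Sequence[Iterable[T]],
               weights: Optional[Sequence[float]] = None,
               seed: int = 0) -> Iterator[T]:
    """Interleave several datasets into one stream (hypothetical helper).

    weights=None: round-robin over the sources, skipping exhausted ones.
    weights given: at each step pick a still-active source at random with
    probability proportional to its weight (seeded, so reproducible).
    Iteration ends when every source is exhausted.
    """
    iterators: List[Iterator[T]] = [iter(s) for s in sources]
    active = list(range(len(iterators)))
    if weights is None:
        while active:
            still_active = []
            for i in active:
                try:
                    item = next(iterators[i])
                except StopIteration:
                    continue  # drop this source from the next round
                still_active.append(i)
                yield item
            active = still_active
    else:
        rng = random.Random(seed)
        while active:
            # renormalize over the sources that still have data
            i = rng.choices(active, weights=[weights[j] for j in active])[0]
            try:
                yield next(iterators[i])
            except StopIteration:
                active.remove(i)
```

For example, `list(interleave([[1, 2, 3], [10, 20], [100]]))` gives `[1, 10, 100, 2, 20, 3]`. Because the weighted mode uses a seeded `random.Random`, the mixing order is reproducible, and each source keeps its own internal order.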

Mddct · Nov 20 '23 12:11

Also, one more point: recognize.py currently depends on a dataset that requires labels. The refactored dataset needs to support pure inference as well, e.g.:

dataset(split='train') # speech and label
dataset(split='cv') # speech and label
dataset(split='infer') # only speech

For integration with large models:

textdataset # only text
speechdataset # only speech

...
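One way to get the `split` behaviour above is to make label handling depend on the split when each line of the data list is parsed. A minimal sketch, assuming wenet's JSON-lines list format with `key`, `wav` and `txt` fields; `parse_line` and `VALID_SPLITS` are hypothetical names:

```python
import json
from typing import Any, Dict

VALID_SPLITS = ("train", "cv", "infer")


def parse_line(line: str, split: str = "train") -> Dict[str, Any]:
    """Parse one JSON line of a data list (hypothetical helper).

    train/cv: the label ("txt") is required.
    infer:    only the speech fields are kept; a label, if present, is ignored.
    """
    if split not in VALID_SPLITS:
        raise ValueError(f"unknown split: {split}")
    obj = json.loads(line)
    example = {"key": obj["key"], "wav": obj["wav"]}
    if split != "infer":
        if "txt" not in obj:
            raise KeyError(f"{obj['key']}: split={split} requires a label")
        example["txt"] = obj["txt"]
    return example
```

A text-only or speech-only dataset for LLM integration could follow the same pattern, with the split (or modality) deciding which fields are required.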

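Mddct · Nov 25 '23 13:11

The current bottlenecks are:

1. Data is fetched synchronously (even with prefetch).
2. Some steps could run in parallel: for example, shards/wavs that have already been downloaded into memory could have their fbank features computed in parallel. Right now we depend heavily on num_workers, and a large num_workers causes bus errors (also, more workers in parallel means more tar files open at once; what we need is concurrency below the tar level). With that in place, the total parallelism could be num_workers * parallel_num.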

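Option 1: integrate Hugging Face datasets:

1. Pros:

  • Native functions:
    • map (parallel_map)
    • filter
    • sort (this one has a pitfall: there is no buffer size; I haven't looked at the implementation in detail yet)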
    • shuffle
    • interleave
    • etc

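With these native functions, the dataset code would be greatly simplified and combining it with other modalities would be easy. It should reduce GPU spikes, but not as much as tf.data does (production and consumption are not decoupled).

2. Cons:

  • prefetch relies on PyTorch DataLoader's prefetch
  • no tar support (Arrow is supported natively). Supporting tar means writing our own _generate_examples, but that yield runs in Python; if it is slow, it also limits the parallelizable steps mentioned above (native Arrow is implemented in C++ underneath).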

Option 2: build our own

Abstract the existing processor functions, fill in the missing ones and add new ones:

  • parallel_map (multiprocessing)
  • filter
  • sort
  • interleave
  • etc
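For the `parallel_map` item, a minimal sketch of an order-preserving map with a bounded number of in-flight tasks, written against the standard `concurrent.futures.Executor` interface (the function and its `max_in_flight` parameter are hypothetical, not existing wenet code):

```python
from collections import deque
from concurrent.futures import Executor, Future
from typing import Callable, Deque, Iterable, Iterator, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def parallel_map(func: Callable[[A], B], items: Iterable[A],
                 executor: Executor, max_in_flight: int = 8) -> Iterator[B]:
    """Apply func to items in parallel, yielding results in input order.

    At most max_in_flight tasks are pending at once, so a long (or
    endless) stream is never pulled into memory all at once.
    """
    if max_in_flight < 1:
        raise ValueError("max_in_flight must be >= 1")
    pending: Deque[Future] = deque()
    for item in items:
        pending.append(executor.submit(func, item))
        if len(pending) >= max_in_flight:
            # wait for the oldest task so output order matches input order
            yield pending.popleft().result()
    while pending:
        yield pending.popleft().result()
```

For CPU-bound work such as fbank, a `ProcessPoolExecutor` would be passed in (and `func` must then be a picklable, module-level function). One caveat worth checking: PyTorch DataLoader workers are started as daemonic processes, and Python's multiprocessing does not allow a daemonic process to create child processes, so a process pool cannot simply be nested inside each DataLoader worker.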

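Cons:

  • new features have to keep being added by hand

Pros:

  • concise
  • supports all the existing IO formats
  • other modalities can be supported as well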

ref: https://huggingface.co/docs/datasets/stream

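One more note: in the large-model era, many setups have dropped features such as dynamic batching, and HF does not support it either (tf.data does). For example, Whisper's encoder always takes 30 s of input, and LLMs often pad directly to a preset maximum length (not the per-batch maximum).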
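If the new IO keeps dynamic batching, the core idea fits in a few lines. A minimal sketch, assuming the budget is expressed as a maximum number of padded frames per batch (`max_frames_in_batch` is used here as an illustrative parameter name, and `length_fn` is a hypothetical callback returning an item's frame count):

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def dynamic_batch(items: Iterable[T], length_fn: Callable[[T], int],
                  max_frames_in_batch: int) -> Iterator[List[T]]:
    """Group items so that (longest item) * (batch size), i.e. the number
    of frames after padding, stays within max_frames_in_batch.

    An item longer than the budget still forms a batch on its own.
    """
    if max_frames_in_batch <= 0:
        raise ValueError("max_frames_in_batch must be positive")
    batch: List[T] = []
    longest = 0
    for item in items:
        n = length_fn(item)
        new_longest = max(longest, n)
        if batch and new_longest * (len(batch) + 1) > max_frames_in_batch:
            yield batch
            batch, new_longest = [], n
        batch.append(item)
        longest = new_longest
    if batch:
        yield batch
```

With lengths `[100, 200, 300, 50, 900, 400]` and a budget of 600 this yields `[[100, 200], [300, 50], [900], [400]]`. A fixed-length scheme like Whisper's corresponds to every item having the same padded length, so the batch size becomes constant.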

Mddct · Nov 28 '23 08:11

Can option 2 quickly support Arrow-format data?

xingchensong · Nov 28 '23 11:11

It should be possible; HF also uses Arrow's Python package.

ref: https://arrow.apache.org/docs/python/

Mddct · Nov 28 '23 12:11

OK, then I support option 2.

xingchensong · Nov 28 '23 12:11

The new IO needs to take determinism into account (reproducibility guaranteed at the data level): shuffle needs a seed.
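A minimal sketch of a seeded, buffer-based streaming shuffle. Seeding from both the seed and the epoch is an assumption of this sketch: it keeps runs reproducible while still changing the order between epochs. `buffer_shuffle` is a hypothetical name:

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffer_shuffle(items: Iterable[T], buffer_size: int, seed: int,
                   epoch: int = 0) -> Iterator[T]:
    """Streaming shuffle with a bounded buffer.

    The RNG is seeded from a string built from (seed, epoch); string seeds
    are hashed deterministically by `random`, so the same seed and epoch
    always produce the same order.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(f"{seed}:{epoch}")
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            # emit a random element from the full buffer
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # drain whatever is left at the end of the stream
    rng.shuffle(buffer)
    yield from buffer
```

With multiple workers or ranks, each one would also need to apply the same (seed, epoch) scheme to its own shard for the whole pipeline to be reproducible.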

Mddct · Nov 29 '23 03:11

Testing multiprocessing vs. multithreading:

100 utterances, 9 s each, computing fbank "concurrently":

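  • a single utterance takes about 0.02 s
  • 100 utterances computed sequentially take about 2 s
  • multithreading (nthread=100) also takes about 2 s (CPU-bound work is serialized by the GIL)
  • multiprocessing (nproc=100) takes 0.1 s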
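The comparison above can be reproduced without audio. A minimal sketch using a pure-Python busy loop as a stand-in for per-utterance feature extraction (`cpu_bound_task`, `timed_map` and `compare_executors` are hypothetical names; absolute timings depend on the machine):

```python
import time
from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
from typing import Callable, Dict, List, Sequence, Tuple


def cpu_bound_task(n: int) -> int:
    """Pure-Python busy loop standing in for per-utterance fbank extraction."""
    total = 0
    for i in range(n):
        total += i * i % 7
    return total


def timed_map(executor: Executor, func: Callable[[int], int],
              args: Sequence[int]) -> Tuple[List[int], float]:
    """Run func over args on the executor; return results and wall time."""
    start = time.perf_counter()
    results = list(executor.map(func, args))
    return results, time.perf_counter() - start


def compare_executors(jobs: Sequence[int], workers: int = 8) -> Dict[str, float]:
    """Time sequential, thread-pool and process-pool execution of the jobs.

    Call this from under `if __name__ == "__main__":`, since process pools
    may re-import the calling module in their workers.
    """
    timings: Dict[str, float] = {}
    start = time.perf_counter()
    expected = [cpu_bound_task(n) for n in jobs]
    timings["sequential"] = time.perf_counter() - start
    with ThreadPoolExecutor(max_workers=workers) as ex:
        results, timings["threads"] = timed_map(ex, cpu_bound_task, jobs)
        assert results == expected
    with ProcessPoolExecutor(max_workers=workers) as ex:
        results, timings["processes"] = timed_map(ex, cpu_bound_task, jobs)
        assert results == expected
    return timings
```

Running `print(compare_executors([200_000] * 8))` from a main-guarded script on a standard CPython build typically shows "threads" close to "sequential", since pure-Python loops hold the GIL, while "processes" shrinks with the number of cores. Library kernels that release the GIL may behave differently, so it is worth timing the real fbank function the same way.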

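Mddct · Nov 29 '23 11:11

On code reuse, here is a code snippet:

# We could inherit from IterableDataset here
class Dataset():
  def __init__(self, dataset, func=None, *args, **kwargs):
    self._dataset = dataset
    self.args = args
    self.kwargs = kwargs
    # default to the identity so that Dataset.from_source(source) also works
    self.func = func if func is not None else (lambda elem: elem)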

  @staticmethod
  def from_source(source):
    return Dataset(source)

  def __iter__(self):
    return self

  def __next__(self):
    if not self._dataset:
      raise StopIteration
    data = next(self._dataset)
    return self.func(data)

  def map(self, func, *args, **kwargs):
    return MapperDataset(self, func, *args, **kwargs)

  def filter(self, func, *args, **kwargs):
    return FilterDataset(self, func, *args, **kwargs)

class MapperDataset(Dataset):
    def __init__(self, dataset, func=None, *args, **kwargs):
      self._dataset = dataset
      self.args = args
      self.kwargs = kwargs
      self.func = func

    def __iter__(self):
      return self

    def __next__(self):
      if not self._dataset:
        raise StopIteration
      data = next(self._dataset)
      return self.func(data)

class FilterDataset(Dataset):
    def __init__(self, dataset, func=None, *args, **kwargs):
      self._dataset = dataset
      self.args = args
      self.kwargs = kwargs
      self.func = func

    def __iter__(self):
      return self

    def __next__(self):
      if not self._dataset:
        raise StopIteration
      data = next(self._dataset)
      while not self.func(data):
        data = next(self._dataset)
      return data

source = iter([1,2,3,4])
dataset = Dataset(source, lambda elem: elem)
dataset = dataset.map(lambda elem: {"speech": elem*2})
dataset = dataset.filter(lambda elem_dict: elem_dict['speech'] > 2)
for d in dataset:
  print(d)

# output:
# {'speech': 4}
# {'speech': 6}
# {'speech': 8}

Mddct · Nov 30 '23 06:11

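During evaluation, wenet's training script can only run the whole eval dataset on a single GPU; this could be optimized too. Could we subclass torch's sampler to speed it up with DDP? For example: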

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from catalyst.data.sampler import DistributedSamplerWrapper

dataset = ...
shuffle = ...
sampler = ...

# If DDP on
if torch.distributed.is_initialized():
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    # If using a custom sampler make it distributed
    if sampler is not None:
        sampler = DistributedSamplerWrapper(sampler,
                                            shuffle=shuffle,
                                            num_replicas=world_size,
                                            rank=rank)
    # If no custom sampler then just use the DistributedSampler
    else:
        sampler = DistributedSampler(dataset,
                                     shuffle=shuffle,
                                     num_replicas=world_size,
                                     rank=rank)

# shuffle shouldn't be specified in DataLoader when using a sampler
shuffle = shuffle if sampler is None else None
dataloader = DataLoader(dataset, sampler=sampler, shuffle=shuffle, ...)
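Note that a sampler only applies to map-style datasets: PyTorch's DataLoader rejects a `sampler` for an `IterableDataset`, and wenet's dataset pipeline is iterable-style. For an iterable pipeline the split would instead happen inside the stream, by rank. A minimal sketch in plain Python (`shard_for_rank` and `merge_loss_stats` are hypothetical helpers; in a real DDP job `rank` and `world_size` would come from `torch.distributed.get_rank()` / `get_world_size()`, and the per-rank sums would be combined with an all-reduce):

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def shard_for_rank(items: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Keep every world_size-th item starting at rank (round-robin sharding)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    for index, item in enumerate(items):
        if index % world_size == rank:
            yield item


def merge_loss_stats(per_rank: Iterable[Tuple[float, int]]) -> float:
    """Combine (sum_of_loss, num_utterances) pairs from all ranks into a mean.

    With torch.distributed this would be an all-reduce (SUM) of both numbers;
    here the per-rank values are simply passed in.
    """
    total_loss, total_count = 0.0, 0
    for loss_sum, count in per_rank:
        total_loss += loss_sum
        total_count += count
    return total_loss / max(total_count, 1)
```

Summing loss totals and counts, rather than averaging the per-rank means, matters because shards can differ in size (e.g. 10 utterances over 3 ranks give shards of 4, 3 and 3). By contrast, `DistributedSampler` by default pads the index list so every rank gets the same number of samples, which duplicates a few utterances during evaluation.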

echocatzh · Dec 03 '23 07:12

Yes, currently eval runs over everything on a single GPU and doesn't take advantage of multiple GPUs. This is something to optimize.

robin1001 · Dec 03 '23 08:12

Looking forward to wenet's new scripts, haha.

echocatzh · Dec 03 '23 09:12

modified code:

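class Dataset:

    def __init__(self, source, f=lambda it: it, *args, **kw):
        # f processes the whole upstream iterator (like wenet's processor
        # functions); the default just passes the data through
        assert callable(f)
        self._dataset = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        # forward the epoch upstream if the source supports it
        source = getattr(self, '_dataset', None)
        if hasattr(source, 'set_epoch'):
            source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self._dataset is not None
        assert callable(self.f)
        yield from self.f(iter(self._dataset), *self.args, **self.kw)

    def apply(self, f, *args, **kw):
        # f receives the whole iterator of this dataset
        assert callable(f)
        return Dataset(self, f, *args, **kw)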

    def map(self, func, *args, **kwargs):
        return MapperDataset(self, func, *args, **kwargs)

    def filter(self, func, *args, **kwargs):
        return FilterDataset(self, func, *args, **kwargs)

    def sort(self, func, *args, **kwargs):
        return SortDataset(self, func, *args, **kwargs)

    def zip(self, *datasets):
      return ZipDataset(self, *datasets)



class MapperDataset(Dataset):
    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for data in self._dataset:
            yield self.func(data, *self.args, **self.kwargs)

class FilterDataset(Dataset):
    def __init__(self, dataset, func=None, *args, **kwargs):
        self._dataset = dataset
        self.args = args
        self.kwargs = kwargs
        self.func = func

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for data in self._dataset:
            if self.func(data, *self.args, **self.kwargs):
                yield data

class SortDataset(Dataset):
    def __init__(self, dataset, key=None, reverse=False, buffer_size=None):
        self._dataset = dataset
        self.key = key
        self.reverse = reverse
        self.buffer_size = buffer_size

    def __iter__(self):
        return self._generator()

    def _generator(self):
        buffer = []
        for data in self._dataset:
            buffer.append(data)
            if self.buffer_size is not None and len(buffer) >= self.buffer_size:
                sorted_buffer = sorted(buffer, key=self.key, reverse=self.reverse)
                for sorted_data in sorted_buffer:
                    yield sorted_data
                buffer.clear()
        if buffer:
            sorted_buffer = sorted(buffer, key=self.key, reverse=self.reverse)
            for sorted_data in sorted_buffer:
                yield sorted_data

class ZipDataset(Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __iter__(self):
        return self._generator()

    def _generator(self):
        iterators = [iter(dataset) for dataset in self.datasets]
        while True:
            try:
                data = [next(iterator) for iterator in iterators]
                yield tuple(data)
            except StopIteration:
                return

class PaddingBatchDataset(Dataset):
    def __init__(self, dataset, batch_size, padding_fn, length_fn):
        self._dataset = dataset
        self.batch_size = batch_size
        # padding_fn(data, padding_length) -> padded data
        self.padding_fn = padding_fn
        # length_fn(data) -> length of a single sample
        self.length_fn = length_fn

    def __iter__(self):
        return self._generator()

    def _generator(self):
        batch = []
        for data in self._dataset:
            batch.append(data)
            if len(batch) == self.batch_size:
                yield self._pad_batch(batch)
                batch = []
        if batch:
            yield self._pad_batch(batch)

    def _pad_batch(self, batch):
        # pad every sample up to the longest one in this batch
        max_length = max(self.length_fn(data) for data in batch)
        return [self.padding_fn(data, max_length - self.length_fn(data))
                for data in batch]
  
# Create the data source
def generator(data):
  for d in data:
    yield d

source = generator([1,2,3,4,1])

# Create a Dataset instance
speech_dataset = Dataset(source)

# preprocess
speech_dataset = speech_dataset.map(lambda elem: {"speech": elem * 2})
speech_dataset = speech_dataset.filter(lambda elem_dict: elem_dict['speech'] >= 2)
speech_dataset = speech_dataset.sort(lambda elem_dict: elem_dict['speech'], buffer_size=2)
# fbank
speech_dataset = speech_dataset.map(lambda elem_dict: {'fbank': elem_dict['speech'] + 1, 'speech': elem_dict['speech']})

llm_dataset = Dataset(generator([10,20,30,40,50,60]))
# e.g. tokenize
llm_dataset = llm_dataset.map(lambda elem: {"tokens": elem + 1 , "text": elem})

task_dataset = speech_dataset.zip(llm_dataset)
task_dataset = task_dataset.sort(lambda elem: elem[1]['tokens'])

# Iterate and print the results
for data in task_dataset:
    print(data)
# output:
# ({'fbank': 3, 'speech': 2}, {'tokens': 11, 'text': 10})
# ({'fbank': 5, 'speech': 4}, {'tokens': 21, 'text': 20})
# ({'fbank': 7, 'speech': 6}, {'tokens': 31, 'text': 30})
# ({'fbank': 9, 'speech': 8}, {'tokens': 41, 'text': 40})
# ({'fbank': 3, 'speech': 2}, {'tokens': 51, 'text': 50})

Mddct · Dec 06 '23 11:12

Could we just use lhotse directly? NeMo has also integrated lhotse's shard mode.

kobenaxie · Jan 15 '24 11:01

import io
import logging
import tarfile
from pathlib import Path

import torchaudio
from lhotse.serialization import open_best

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


def iterate_tarfile_pairwise(
    tar_file: tarfile.TarFile,
):
    # Assumes each utterance is stored as exactly two consecutive members:
    # the .txt label first, then the audio file.
    result = []
    for tarinfo in tar_file:
        if len(result) == 2:
            yield tuple(result)
            result = []
        result.append(parse_tarinfo(tarinfo, tar_file))

    if len(result) == 2:
        yield tuple(result)

    if len(result) == 1:
        raise RuntimeError(
            "Uneven number of files in the tarfile (expected to iterate pairs of text and binary data)."
        )


def parse_tarinfo(
    tarinfo: tarfile.TarInfo,
    tar_file: tarfile.TarFile,
):
    """
    Parse a tarinfo object and return the data it points to
    (key and text for .txt members, waveform and sample rate for audio).
    """
    path = Path(tarinfo.path)
    suffix = path.suffix.strip(".")

    raw_data = tar_file.extractfile(tarinfo)
    if suffix == "txt":
        txt = raw_data.read().decode("utf-8").strip()
        # use the stem ("utt1"), not the full name ("utt1.txt"), as the key
        return (path.stem, txt)
    elif suffix in AUDIO_FORMAT_SETS:
        # members of a streamed ("r|*") tar may not be seekable, so buffer them
        waveform, sample_rate = torchaudio.load(io.BytesIO(raw_data.read()))
        return (waveform, sample_rate)
    else:
        raise RuntimeError(
            f"Unsupported file format: {suffix}"
        )


def parse_tar(data):
    for sample in data:
        assert "src" in sample, sample.keys()
        url = sample["src"]
        try:
            with tarfile.open(fileobj=open_best(url, mode="rb"), mode="r|*") as tar:
                for (key, txt), (waveform, sample_rate) in iterate_tarfile_pairwise(tar):
                    yield {
                        "key": key,
                        "wav": waveform,
                        "sample_rate": sample_rate,
                        "txt": txt,
                    }
        except Exception as ex:
            logging.warning(f"Failed to open {url}: {ex}")

As a first step, simplify the tar-parsing logic; then url_opener and tar_file_and_group can be replaced with parse_tar.
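`iterate_tarfile_pairwise` assumes exactly two members per utterance, with the text first. A minimal standard-library sketch that instead groups consecutive members by their file-name prefix, so any number of files per key works; `group_tar_members` is a hypothetical name, and it assumes all members of one key are stored next to each other in the tar:

```python
import tarfile
from typing import BinaryIO, Dict, Iterator


def group_tar_members(fileobj: BinaryIO) -> Iterator[Dict[str, object]]:
    """Stream a tar archive and group consecutive members sharing a prefix.

    "utt1.txt" and "utt1.wav" become {"key": "utt1", "txt": ..., "wav": b"..."}.
    Text members are decoded; everything else is kept as raw bytes.
    """
    example: Dict[str, object] = {}
    prev_key = None
    with tarfile.open(fileobj=fileobj, mode="r|*") as tar:
        for member in tar:
            if not member.isfile():
                continue
            key, _, suffix = member.name.rpartition(".")
            if prev_key is not None and key != prev_key:
                # a new prefix starts: emit the finished example
                example["key"] = prev_key
                yield example
                example = {}
            data = tar.extractfile(member).read()
            example[suffix] = data.decode("utf8").strip() if suffix == "txt" else data
            prev_key = key
    if prev_key is not None:
        example["key"] = prev_key
        yield example
```

Audio decoding can then be a separate per-element map step, which keeps tar reading and feature extraction independently parallelizable.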

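kobenaxie · Jan 17 '24 03:01

lhotse is essentially no different from what wenet has now:

  • To speed up CPU-bound transforms over many items, multithreading doesn't help and multiprocessing is needed (e.g. once a tar is already in memory: today the concurrency is at the shard level, not inside a shard, and at that point several processes could extract features for several wavs). Both have this problem.

  • Wrappers to adapt to various tasks: a simple abstraction on top of the existing wenet code is enough.

Mddct · Jan 17 '24 05:01

PyTorch now officially provides chained data pipelines: https://github.com/pytorch/pytorch/tree/main/torch/utils/data/datapipes (not torchdata), and wenet has already been upgraded to 2.x, so I lean toward the official torch implementation. Below is some preliminary code. @xingchensong @robin1001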

import io
import json
import logging
import tarfile
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterDataPipe
from torch.utils.data import datapipes
import torchaudio
from torchaudio.compliance.kaldi import fbank

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


class WenetSourceDataPipe(IterDataPipe):

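    def __init__(self, dp, data_type='raw', **fmtparams):
        # ShardingFilter works on the list files: each DataLoader worker keeps
        # only its share of the files in `dp` (lines inside a file are not split)
        self.dp = datapipes.iter.ShardingFilter(
            datapipes.iter.FileOpener(dp, mode='r'))
        self.data_type = data_type
        self.fmtparams = fmtparams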

    def __iter__(self):
        for _, stream in self.dp:
            for line in stream:
                line = line.strip('\n')
                if self.data_type == 'raw':
                    json_obj = json.loads(line)
                    with open(json_obj['wav'], 'rb') as f:
                        json_obj['wav'] = f.read()
                    yield json_obj
                else:
                    yield {'stream': open(line, 'rb')}


class TarFileGroupSourceDataPipe(IterDataPipe):

    def __init__(self, dp) -> None:
        super().__init__()
        self.dp = dp

    def __iter__(self):
        for sample in self.dp:
            try:
                # stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
                with tarfile.open(fileobj=sample['stream'],
                                  mode="r:*") as stream:
                    prev_prefix = None
                    example = {}
                    valid = True
                    for tarinfo in stream:
                        name = tarinfo.name
                        pos = name.rfind('.')
                        assert pos > 0
                        prefix, postfix = name[:pos], name[pos + 1:]
                        if prev_prefix is not None and prefix != prev_prefix:
                            example['key'] = prev_prefix
                            if valid:
                                yield example
                            example = {}
                            valid = True
                        with stream.extractfile(tarinfo) as file_obj:
                            try:
                                if postfix == 'txt':
                                    example['txt'] = file_obj.read().decode(
                                        'utf8').strip()
                                elif postfix in AUDIO_FORMAT_SETS:
                                    example['wav'] = file_obj.read()
                                else:
                                    example[postfix] = file_obj.read()
                            except Exception as ex:
                                valid = False
                                logging.warning(
                                    'error to parse {}'.format(name))
                            prev_prefix = prefix
                    if prev_prefix is not None and valid:
                        example['key'] = prev_prefix
                        yield example
            except Exception as ex:
                logging.warning(
                    'In tar_file_and_group: {} when processing '.format(
                        ex))  #, sample['src']))
            finally:
                # the tarfile object is already closed by the `with` block
                # (and does not exist at all if tarfile.open failed)
                if 'process' in sample:
                    sample['process'].communicate()
                sample['stream'].close()


def decode_wav(elem):
    wav = elem['wav']
    key = elem['key']
    txt = elem['txt']
    with io.BytesIO(wav) as file_obj:
        waveform, sr = torchaudio.load(file_obj)
    return {"key": key, "txt": txt, 'waveform': waveform, "sample_rate": sr}


def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):

    sample_rate = data['sample_rate']
    waveform = data['waveform']
    waveform = waveform * (1 << 15)
    mat = fbank(waveform,
                num_mel_bins=num_mel_bins,
                frame_length=frame_length,
                frame_shift=frame_shift,
                dither=dither,
                energy_floor=0.0,
                sample_frequency=sample_rate)
    data['feat'] = mat
    return data


def padding(data):
    assert isinstance(data, list)
    sample = data
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
                                 dtype=torch.int32)
    sorted_feats = [sample[i]['feat'] for i in order]
    sorted_keys = [sample[i]['key'] for i in order]
    padded_feats = pad_sequence(sorted_feats,
                                batch_first=True,
                                padding_value=0)
    batch = {
        "keys": sorted_keys,
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
    }
    return batch


def get_dataloader(data_type, files):
    # shard by files
    dataset = WenetSourceDataPipe(files, data_type=data_type)
    if data_type == 'shard':
        # split each tar shard into per-utterance examples
        dataset = TarFileGroupSourceDataPipe(dataset)

    dataset = dataset.map(decode_wav)
    dataset = dataset.map(compute_fbank)
    dataset = dataset.batch(wrapper_class=padding, batch_size=2)
    dataloader = DataLoader(dataset,
                            batch_size=None,
                            num_workers=4,
                            persistent_workers=True)

    return dataloader


if __name__ == '__main__':
    raw_dataloader = get_dataloader('raw',
                                    ['test/resources/dataset/data.list'])
    tar_dataloader = get_dataloader(
        'shard', ['test/resources/dataset/data.shards.list'])

    print("--------" + "wenet raw data type" + '---------\n')
    for raw_batch in raw_dataloader:
        print(raw_batch)

    print("\n--------" + "wenet shard data type" + '---------\n')
    for shard_batch in tar_dataloader:
        print(shard_batch)

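Next steps for the refactor: 1) a dataset source (supporting automatic sharding, by line or by file); 2) processors keep the per-"elem" logic that currently lives inside the for-loops, but are called through map, filter, etc.

Advantages:

  • For example, Whisper's hybrid tokenizer can build its own task: 1) dataset source, 2) feats, 3) hybrid tokenizer (only the per-"elem" function has to be written, then mapped), 4) batch, while still reusing the current yaml, e.g.
dataset: whisper_dataset  # previously ASRDataset
  • Same for TTS/LLM.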
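A minimal sketch of how a yaml key such as `dataset: whisper_dataset` could select a chain of per-element functions applied lazily with map/filter. Everything here (`RECIPES`, `register`, `build_pipeline`, the `asr_dataset` recipe and its toy feature/tokenizer functions) is hypothetical and only illustrates the composition idea:

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple

Elem = Dict[str, Any]
# A recipe is an ordered list of ("map" | "filter", per-element function).
Recipe = List[Tuple[str, Callable[[Elem], Any]]]

RECIPES: Dict[str, Recipe] = {}


def register(name: str, recipe: Recipe) -> None:
    """Make a recipe selectable by name, e.g. from `dataset:` in a yaml."""
    RECIPES[name] = recipe


def build_pipeline(name: str, source: Iterable[Elem]) -> Iterator[Elem]:
    """Lazily chain the registered per-element steps over a source."""
    stream: Iterable[Elem] = source
    for kind, fn in RECIPES[name]:
        if kind == "map":
            stream = map(fn, stream)
        elif kind == "filter":
            stream = filter(fn, stream)
        else:
            raise ValueError(f"unknown step kind: {kind}")
    return iter(stream)


# Toy per-element functions standing in for real feature extraction and
# tokenization (hypothetical; "wav" is just a list of samples here).
def toy_feats(elem: Elem) -> Elem:
    return {**elem, "feat_len": len(elem["wav"]) // 160}


def char_tokenize(elem: Elem) -> Elem:
    return {**elem, "tokens": list(elem["txt"])}


register("asr_dataset", [
    ("map", toy_feats),
    ("filter", lambda elem: elem["feat_len"] > 0),
    ("map", char_tokenize),
])

config = {"dataset": "asr_dataset"}  # e.g. parsed from train.yaml
```

A new task (e.g. a Whisper-style recipe with its own tokenizer) would only register a different list of element functions, while the source and batching stages stay shared.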

Mddct · Jan 20 '24 17:01


The file handle `stream` in WenetSourceDataPipe can be wrapped with torch's StreamWrapper, which avoids having to call stream.close() manually elsewhere:

from torch.utils.data.datapipes.utils.common import StreamWrapper
stream = StreamWrapper(open(line, 'rb'))

kobenaxie · Jan 23 '24 02:01