Audio?

Open cinjon opened this issue 1 year ago • 4 comments

🚀 Feature Request

Are there any examples with StreamingDataset using audio? I'm trying to see how I'd do that with the MDSWriter. I think straight bytes is best, but not sure.

cinjon avatar Aug 24 '23 21:08 cinjon

Hi @cinjon, what sort of audio files do you have? One way is to write your own MDS encoder/decoder class that converts the audio to bytes when writing the file and converts the bytes back to audio when reading the data, so that you don't have to convert it manually. Something like this:

import shutil

from streaming.base.format.mds.encodings import Encoding, _encodings, mds_encode, mds_decode
from streaming import MDSWriter, StreamingDataset

class StrInt(Encoding):
    """Store int as variable-length digits str."""

    def encode(self, obj: int) -> bytes:
        self._validate(obj, int)
        text = str(obj)
        return text.encode('utf-8')

    def decode(self, data: bytes) -> int:
        text = data.decode('utf-8')
        return int(text)

_encodings['strint'] = StrInt

if __name__ == '__main__':
    print(set(_encodings))

    data = [{'x': i} for i in range(100)]

    columns = {
        'x': 'strint',
    }

    out_root = 'dirname'

    with MDSWriter(out=out_root, columns=columns) as out:
        for sample in data:
            out.write(sample)

    dataset = StreamingDataset(local=out_root, num_canonical_nodes=1)

    for sample in dataset:
        print(f'sample: {sample["x"]}\ttype: {type(sample["x"])}')

    # Clean up
    shutil.rmtree(out_root)

    # Validation
    sample = 200
    foo1 = mds_encode('strint', sample)
    print(f'value: {foo1}\ttype: {type(foo1)}') # bytes

    foo2 = mds_decode('strint', foo1)
    print(f'value: {foo2}\ttype: {type(foo2)}') # int
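For audio specifically, the same encode/decode pair could wrap a bytes round trip like the one below. This is only a minimal sketch using the standard-library `wave` module; the helper names `pcm_to_wav_bytes` and `wav_bytes_to_pcm` are hypothetical, and it assumes uncompressed PCM audio. The resulting bytes could also be stored directly in a column declared as 'bytes'.

```python
import io
import wave
from typing import Tuple


def pcm_to_wav_bytes(pcm: bytes, sample_rate: int, channels: int = 1,
                     sample_width: int = 2) -> bytes:
    """Wrap raw PCM frames in a WAV container and return the file as bytes."""
    buf = io.BytesIO()
    with wave.open(buf, 'wb') as f:
        f.setnchannels(channels)
        f.setsampwidth(sample_width)  # bytes per sample, 2 means 16-bit
        f.setframerate(sample_rate)
        f.writeframes(pcm)
    return buf.getvalue()


def wav_bytes_to_pcm(data: bytes) -> Tuple[bytes, int, int]:
    """Return (raw PCM frames, sample rate, channel count) from WAV bytes."""
    with wave.open(io.BytesIO(data), 'rb') as f:
        return f.readframes(f.getnframes()), f.getframerate(), f.getnchannels()


# Hypothetical sample: one second of 16-bit mono silence at 16 kHz.
pcm = bytes(2 * 16000)
sample = {'wav': pcm_to_wav_bytes(pcm, sample_rate=16000)}
```

A custom Encoding subclass would call the first helper in encode() and the second in decode().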

karan6181 avatar Aug 24 '23 22:08 karan6181

@cinjon, did the above solution help you in any way? We'd also love to hear if you have figured out an alternative solution. Thanks!

karan6181 avatar Sep 11 '23 15:09 karan6181

Hey, sorry for the delay. We went a different route and are reconsidering this now.

What you wrote here, plus the CIFAR tutorial, seems to cover most of the bases: https://docs.mosaicml.com/projects/streaming/en/stable/examples/cifar10.html. One thing I'm unclear on though is that for each of my audio files I have many data points; how do I "expand" the resulting StreamingDataset to accommodate this?
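One possible way to handle several data points per audio file is to flatten them at write time, so that every data point becomes its own sample. The sketch below is only an illustration under that assumption: `expand_to_samples` and the field names are hypothetical, and each yielded dict would then be passed to the writer's write() call.

```python
from typing import Dict, Iterator, List, Tuple


def expand_to_samples(key: str, wav: bytes,
                      events: List[Tuple[float, float, int]]) -> Iterator[Dict]:
    """Yield one flat sample per (start_time, end_time, label) data point."""
    for start_time, end_time, label in events:
        yield {
            'key': key,
            'start_time': start_time,
            'end_time': end_time,
            'label': label,
            'wav': wav,  # the same audio bytes are repeated for every data point
        }


# Hypothetical file with two labelled data points.
samples = list(expand_to_samples('clip_0', b'\x00\x01',
                                 [(0.0, 6.0, 1), (6.0, 12.0, 0)]))
```

Repeating the full file for every data point costs storage, so it may be preferable to store only the segment each data point needs.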

cinjon avatar Oct 31 '23 01:10 cinjon

@karan6181 I put together a setup to do this. It works, but it's still quite slow, even with 11 cpu cores going into the DataLoader and pin_memory=True. Any idea how to speed it up? I'm getting ~10s time to build a batch of 256.

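import io
import random
from typing import Any

import streaming
import torch
import torchaudio

# window_in_secs, predict_secs and target_sr are assumed to be defined at module level.
class MyStreamingDataset(streaming.StreamingDataset):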
  def __init__(self, local, remote, shuffle):
    super().__init__(local=local, remote=remote, shuffle=shuffle)

  def __getitem__(self, idx: int) -> Any:
    # columns = {
    #     'start_time': 'float32',
    #     'key': 'str',
    #     'end_time': 'float32',
    #     'label': 'int8',
    #     'wav': 'bytes',
    # }
    obj = super().__getitem__(idx)
      
    end_time = obj['end_time']
    start_time = obj['start_time']
    label = obj['label']    
    wav = io.BytesIO(obj['wav'])

    # window_in_secs = 5, and the loaded wav is 6s long.
    relative_start_time = end_time - window_in_secs - start_time
    if label:
      # Do a positive sample, can only use a small part of the sample.
      max_reduction = min(relative_start_time, predict_secs)
      this_start_time = relative_start_time - max_reduction * random.random()
      offset = int(target_sr * this_start_time)
      label = torch.tensor(1, dtype=torch.int64)
    else:
      # Do a negative sample. Here, the entire sample is fair game.
      max_reduction = relative_start_time
      this_start_time = random.random() * relative_start_time
      offset = int(target_sr * this_start_time)
      label = torch.tensor(0, dtype=torch.int64)

    num_frames = window_in_secs * target_sr
    # NOTE: This loading step takes .01 seconds by itself. That's sadface, but amongst 12 workers, this should come to ~.21 seconds of the whole operation.
    wav, sr = torchaudio.load(wav, frame_offset=offset, num_frames=num_frames)
    wav = wav.mean(axis=0, keepdims=True)
    return wav, label
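The thread does not record a resolution to the speed question. One option that might be worth measuring (an untested assumption, not something confirmed by the maintainers) is to store raw PCM in the 'bytes' column and cut the window by byte offset, so that no audio container has to be parsed per sample. A minimal sketch, with the hypothetical helper `slice_pcm`:

```python
def slice_pcm(pcm: bytes, sample_rate: int, start_sec: float, duration_sec: float,
              channels: int = 1, sample_width: int = 2) -> bytes:
    """Return the frames covering [start_sec, start_sec + duration_sec) from raw PCM."""
    frame_size = channels * sample_width  # bytes per frame
    first = int(start_sec * sample_rate) * frame_size
    last = first + int(duration_sec * sample_rate) * frame_size
    return pcm[first:last]


# Hypothetical 6-second, 16-bit mono clip at 16 kHz; take a 5-second window at 0.5 s.
clip = bytes(6 * 16000 * 2)
window = slice_pcm(clip, sample_rate=16000, start_sec=0.5, duration_sec=5)
```

Whether this actually helps depends on where the 10 seconds per batch are spent, so it would be worth profiling a single worker before changing the storage format.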

cinjon avatar Nov 01 '23 18:11 cinjon

Hey @cinjon, were you able to resolve this? What was your approach?

snarayan21 avatar May 29 '24 19:05 snarayan21

@cinjon Closing this due to inactivity. Please feel free to re-open. Thank You!

karan6181 avatar Jul 23 '24 03:07 karan6181