
Cannot read from parquet files that contain binary

Open vedantroy opened this issue 1 year ago • 2 comments

🐛 Describe the bug

I create some parquet files with the following:

# io, torch as th, pyarrow as pa, pyarrow.parquet as pq, and tqdm are imported above
def save_tensor(t):
    buf = io.BytesIO()
    th.save(t, buf)
    return buf.getvalue()

# dl, total_files, args, and out_dir come from the surrounding script
for idx, batch in enumerate(tqdm(dl, total=total_files / args.batch_size)):
    df = pa.table({"img": [save_tensor(item) for item in batch]})
    pq.write_table(df, out_dir / f"{idx}.parquet")

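A quick check with plain pyarrow (a hedged sketch, assuming one of the files written by the loop above, e.g. 0.parquet) shows the column is stored with Arrow type binary, which is the type the loader complains about:

import pyarrow.parquet as pq

# read only the Parquet metadata; the serialized tensors end up in an Arrow `binary` column
schema = pq.read_schema("0.parquet")
print(schema)  # -> img: binary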
when I try to read the parquet files, with the following pipe:

    datapipe = dp.iter.FSSpecFileLister(dir)
    datapipe = datapipe.load_parquet_as_df()

I get the error:

NotImplementedError: Unsupported Arrow type: binary
This exception is thrown by __iter__ of ParquetDFLoaderIterDataPipe(columns=None, device='', dtype=None, source_dp=FSSpecFileListerIterDataPipe, use_threads=False)

Versions

Collecting environment information...
PyTorch version: 1.12.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-41-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080 Ti
Nvidia driver version: 515.48.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.4
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.12.0
[pip3] torch-optimizer==0.1.0
[pip3] torcharrow==0.1.0
[pip3] torchdata==0.4.0
[pip3] torchmetrics==0.7.3
[pip3] torchvision==0.13.0
[conda] cudatoolkit 11.7.0 hd8887f6_10 conda-forge
[conda] numpy 1.21.4 pypi_0 pypi
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] torch 1.12.0 pypi_0 pypi
[conda] torch-optimizer 0.1.0 pypi_0 pypi
[conda] torcharrow 0.1.0 pypi_0 pypi
[conda] torchdata 0.4.0 pypi_0 pypi
[conda] torchmetrics 0.7.3 pypi_0 pypi
[conda] torchvision 0.13.0 pypi_0 pypi

vedantroy avatar Jul 31 '22 08:07 vedantroy

cc: @wenleix

ejguan avatar Aug 01 '22 14:08 ejguan

Yeah, TorchArrow doesn't support the binary type yet. May I ask what the use case is? (e.g. is the binary data used to represent images or other content you plan to turn into Tensors, or do you plan to use it as SQL VarBinary?)

wenleix avatar Aug 18 '22 04:08 wenleix

:+1: to this issue.

@wenleix my use case is as follows:

  1. I have a large dataset of image/text pairs to process from S3.
  2. I have implemented all the cleaning and preprocessing for images and text on top of PySpark, as it is the only distributed environment I can use in my company. Moreover, Spark is extremely efficient at accessing and reading the images from the distributed file system. Downloading all the images and text onto the cluster's local storage is not a scalable solution for my use case.
  3. Spark/PySpark DataFrames do not let you store multi-dimensional arrays of float values (Spark tends to go out of memory when handling columns with large arrays, so flattening everything out is not the best option).
  4. Following Petastorm's solution, I converted the images and text tokens from numpy arrays to bytearrays, as suggested by @vedantroy (see the sketch after this list).
  5. At model training time, instead of using Petastorm's dataloaders, which are extremely slow, I wanted to construct a DataPipe that reads a single parquet file at a time (as if each parquet file were a single shard of my dataset) and applies the needed collate function to convert each bytearray into a torch tensor. Note that the Petastorm dataloader does not support the iterator interface and is thus not well suited to large datasets where a single training epoch takes multiple days.

I think this will make my training process much faster, and I can see this integration with Spark being quite powerful. In my understanding, the limitation is the set of dtypes supported by torcharrow. Is it possible to add the bytearray type to the list of core dtypes supported by torcharrow? In the end, I assume it would be encoded similarly to the string type.
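For reference, a minimal sketch of the encoding mentioned in point 4 (assuming numpy arrays serialized with np.save, in the spirit of Petastorm's NdarrayCodec; the function names here are illustrative):

import io
import numpy as np

def ndarray_to_bytes(arr: np.ndarray) -> bytes:
    # serialize shape, dtype, and data into a byte string that fits a Parquet binary column
    buf = io.BytesIO()
    np.save(buf, arr)
    return buf.getvalue()

def bytes_to_ndarray(raw: bytes) -> np.ndarray:
    # inverse operation, applied in the collate function at training time
    return np.load(io.BytesIO(raw))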

andompesta avatar Jan 06 '23 17:01 andompesta

We don't have plans to add a varbinary data type to TorchArrow at the moment. Have you considered using an IterDataPipe over pyarrow.Table or pyarrow.RecordBatch and then converting them (e.g. via the Python buffer protocol) into PyTorch Tensors for further processing?
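For example, a minimal sketch of that idea, reading with plain pyarrow and converting to tensors via numpy (the file name, column name, and batch size are illustrative, and the values are assumed to be np.save-encoded arrays as described above):

import io
import numpy as np
import torch
import pyarrow.parquet as pq

table = pq.read_table("shard.parquet")  # plain pyarrow, no TorchArrow involved
for batch in table.to_batches(max_chunksize=32):
    raw_values = batch.to_pydict()["img"]              # list of bytes objects
    arrays = [np.load(io.BytesIO(v)) for v in raw_values]
    tensors = torch.from_numpy(np.stack(arrays))       # assumes equal shapes within a batch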

wenleix avatar Jan 07 '23 21:01 wenleix

Thanks for the answer @wenleix, but I don't think I understood your suggestion. Could you provide an example?

andompesta avatar Jan 08 '23 10:01 andompesta

OK, not sure if anyone is interested, but I came up with the following solution based on your advice:


from torchdata.datapipes.iter import (
    IterDataPipe,
    FSSpecFileLister,
)
from torchdata.datapipes import functional_datapipe
from torchdata.dataloader2 import (
    DataLoader2,
    PrototypeMultiProcessingReadingService,
)
import pyarrow.parquet as pq
from numpy import load, stack
from io import BytesIO
from typing import List, Union


@functional_datapipe("parquet_reader")
class ParquetReaderIter(IterDataPipe):

    def __init__(
        self,
        source_datapipe: IterDataPipe,
    ) -> None:
        super().__init__()
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for files_chunk in self.source_datapipe:
            # from parquet file to pyarrow Table
            chunk_table = pq.read_table(files_chunk, memory_map=True)
            yield chunk_table


@functional_datapipe("arrow_batch")
class ArrowBatchIter(IterDataPipe):

    def __init__(
        self,
        source_datapipe: IterDataPipe,
        batch_size: int,
        shuffle: bool = False,
    ) -> None:
        super().__init__()
        self.source_datapipe = source_datapipe
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        for table in self.source_datapipe:
            # from pyarrow Table to list of batches
            batches = table.to_batches(max_chunksize=self.batch_size)
            yield from batches


@functional_datapipe("batch_decoder")
class BatchDecoderIter(IterDataPipe):

    def __init__(
        self,
        source_datapipe: IterDataPipe,
    ) -> None:
        super().__init__()
        self.source_datapipe = source_datapipe

    @staticmethod
    def decode(value):
        # decode values according to Petastorm's NdarrayCodec (np.save format)
        memfile = BytesIO(value)
        return load(memfile)

    def __iter__(self):
        for batch in self.source_datapipe:
            # one possible collate: decode every column (all stored as np.save-encoded
            # bytes) into a stacked numpy array, assuming equal shapes within a column
            columns = batch.to_pydict()
            yield {
                name: stack([self.decode(value) for value in values])
                for name, values in columns.items()
            }


def get_data_pipeline(
    dataset_location: str,
    batch_size: int,
    shuffle: bool,
    chunk_prefetch_buffer: int = 10,
    dataset_file_masks: Union[str, List[str]] = "*.parquet",
) -> IterDataPipe:
    pipeline = (
        FSSpecFileLister(
            dataset_location,
            dataset_file_masks,
        )
        # split parquet files across workers into shards
        .sharding_filter()
        # read each parquet file into an Arrow Table
        .parquet_reader()
        # apply pre-fetching to read a new chunk while training
        .prefetch(chunk_prefetch_buffer)
        # apply batching
        .arrow_batch(batch_size)
    )

    if shuffle:
        # add shuffling if needed
        pipeline = pipeline.shuffle()

    # apply collate function
    pipeline = pipeline.batch_decoder()

    return pipeline


def get_multi_processing_dataloader(
    dataset_location: str,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    **kwargs,
) -> DataLoader2:
    reading_service = PrototypeMultiProcessingReadingService(
        num_workers=num_workers,
        multiprocessing_context="fork",
    )

    pipeline = get_data_pipeline(
        dataset_location=dataset_location,
        batch_size=batch_size,
        shuffle=shuffle,
        **kwargs,
    )

    dataloader = DataLoader2(
        pipeline,
        reading_service=reading_service,
    )

    return dataloader
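Usage then looks roughly like this (the location and sizes are hypothetical; with the decode step sketched above, each batch comes out as a dict of stacked numpy arrays, one entry per parquet column):

dataloader = get_multi_processing_dataloader(
    dataset_location="s3://my-bucket/dataset/",  # hypothetical path
    batch_size=256,
    shuffle=True,
    num_workers=4,
)

for batch in dataloader:
    # e.g. wrap batch["img"] with torch.from_numpy before feeding the model
    ...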

andompesta avatar Jan 16 '23 22:01 andompesta