
Speeding up loading data from spark

Open jmpanfil opened this issue 3 years ago • 3 comments

I've been using petastorm to train PyTorch models from Spark DataFrames (loosely following this guide). I'm wondering whether there are ways to speed up data loading.

Here's a basic overview of my current flow. df_train is a Spark DataFrame with three columns: x1 (float), x2 (binary, 0 or 1), and y (float). I'm using PySpark.

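```python
import torch
from petastorm.spark import SparkDatasetConverter, make_spark_converter

# spark, df_train, bs (batch size) and device are defined earlier in the notebook
x_feat = ['x1', 'x2']
y_name = 'y'
# Directory where the converter caches the DataFrame as Parquet
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, "file:///dbfs/tmp/petastorm/cache")
converter_train = make_spark_converter(df_train)

with converter_train.make_torch_dataloader(batch_size=bs) as train_dataloader:
    train_dataloader_iter = iter(train_dataloader)
    steps_per_epoch = len(converter_train) // bs
    for step in range(steps_per_epoch):
        pd_batch = next(train_dataloader_iter)
        # Each column arrives as a 1-D tensor; stack them into shape (batch, n_features)
        pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)
        inputs = pd_batch['features'].to(device)
        labels = pd_batch[y_name].to(device)
        ...  # modeling and stuff
```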

My concern is that the per-batch torch operations (stack and transpose) might not be optimal. As an alternative, I first created an array column for x1 and x2 in my Spark DataFrame. I was surprised to find that each epoch was more than twice as slow as with the approach above.

```python
from pyspark.sql.functions import array

df_train = df_train.withColumn("features", array("x1", "x2")).select("features", "y")
# remainder same as above except torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1) was removed
```

Are there any improvements I can make here?

jmpanfil, Mar 03 '22 21:03

There are many pieces at work here: SparkDatasetConverter stores the Spark DataFrame as a Parquet file, and make_torch_dataloader then brings up a pool of workers to read that Parquet file. I assume the slowness you are referring to comes from the next() call? Can you please confirm?
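One way to answer that question is to time the next() call separately from the per-batch torch work. The helper below is a hypothetical sketch (profile_batches is not a petastorm API); it uses only the standard library and accepts any iterator, so it can wrap iter(train_dataloader) directly.

```python
import time
from typing import Any, Callable, Iterator, Tuple


def profile_batches(
    batch_iter: Iterator[Any],
    steps: int,
    process: Callable[[Any], Any],
) -> Tuple[float, float]:
    """Return (seconds spent in next(), seconds spent in process()) over `steps` batches.

    Hypothetical helper for illustration; not part of petastorm.
    """
    fetch_s = 0.0
    process_s = 0.0
    for _ in range(steps):
        t0 = time.perf_counter()
        batch = next(batch_iter)  # time spent waiting on the reader/loader
        t1 = time.perf_counter()
        process(batch)  # per-batch work, e.g. the stack/transpose step
        t2 = time.perf_counter()
        fetch_s += t1 - t0
        process_s += t2 - t1
    return fetch_s, process_s


if __name__ == "__main__":
    # Stand-in batches and work function so the sketch runs on its own
    fake_batches = iter([{"x1": [0.5] * 4, "x2": [1] * 4} for _ in range(10)])
    fetch, work = profile_batches(fake_batches, steps=10, process=lambda b: sum(b["x1"]))
    print(f"next(): {fetch:.6f}s, processing: {work:.6f}s")
```

With the real loader, pass iter(train_dataloader) and a function that performs the stacking. If that function also moves tensors to a GPU, call torch.cuda.synchronize() at its end, because CUDA kernels run asynchronously and would otherwise be under-counted.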

make_torch_dataloader accepts petastorm_reader_kwargs, which are passed on to make_batch_reader; see the make_batch_reader documentation for the full list. You can try tweaking parameters documented there, such as reader_pool_type and workers_count, to experiment with different parallelization settings (thread vs. process pool, number of workers).
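A minimal sketch of such a parameter sweep, assuming (as described above) that make_torch_dataloader accepts workers_count and forwards reader_pool_type to make_batch_reader. benchmark_reader_configs is a hypothetical helper, and the pool types and worker counts are illustrative values, not recommendations. The loader is assumed to yield at least steps + 1 batches.

```python
import itertools
import time
from typing import Any, Dict, Iterable, List


def benchmark_reader_configs(
    converter: Any,
    batch_size: int,
    steps: int,
    pool_types: Iterable[str] = ("thread", "process"),
    worker_counts: Iterable[int] = (2, 4, 8),
) -> List[Dict[str, Any]]:
    """Time `steps` batches for each (reader_pool_type, workers_count) pair.

    Hypothetical helper; returns results sorted from fastest to slowest.
    """
    results = []
    for pool_type, workers in itertools.product(pool_types, worker_counts):
        # reader_pool_type is assumed to be forwarded to make_batch_reader
        with converter.make_torch_dataloader(
            batch_size=batch_size,
            workers_count=workers,
            reader_pool_type=pool_type,
        ) as loader:
            batches = iter(loader)
            next(batches)  # warm-up: exclude worker start-up from the timing
            start = time.perf_counter()
            for _ in range(steps):
                next(batches)
            elapsed = time.perf_counter() - start
        results.append(
            {
                "reader_pool_type": pool_type,
                "workers_count": workers,
                "sec_per_batch": elapsed / steps,
            }
        )
    return sorted(results, key=lambda r: r["sec_per_batch"])
```

Usage would look like `for row in benchmark_reader_configs(converter_train, batch_size=bs, steps=100): print(row)`. Each configuration opens and closes its own loader, so the worker pools do not overlap between measurements.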

Hope this helps.

selitvin, Mar 08 '22 23:03

Hi, thanks for your help! My main concern is that calling torch.stack after every next() call is inefficient, and that I'm missing an obvious way to use SparkDatasetConverter without it. That's why I tried creating an array column in my DataFrame first, but that turned out to be slower.

I will dive into the full documentation that you sent and play around with some parameters.
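On the torch.stack concern: the stack-then-transpose line in the training loop can be collapsed into a single torch.stack(..., dim=1) call. The sketch below assumes, as the indexing in that loop implies, that each batch maps column names to 1-D tensors; stack_features and the sample batch are hypothetical. This still stacks once per batch; it only drops the transpose and yields a contiguous tensor, so whether it helps measurably is worth profiling.

```python
import torch


def stack_features(batch: dict, feature_names: list) -> torch.Tensor:
    """Build a (batch_size, n_features) float tensor from per-column 1-D tensors."""
    # dim=1 places the columns side by side, so no transpose is needed afterwards;
    # .float() gives every column the same dtype (x2 may arrive as integers)
    return torch.stack([batch[name].float() for name in feature_names], dim=1)


# Hypothetical batch shaped like the ones the loop receives: one 1-D tensor per column
batch = {
    "x1": torch.tensor([0.5, 1.5, 2.5]),
    "x2": torch.tensor([0, 1, 1]),
    "y": torch.tensor([1.0, 2.0, 3.0]),
}
features = stack_features(batch, ["x1", "x2"])
old_way = torch.transpose(torch.stack([batch[x].float() for x in ["x1", "x2"]]), 0, 1)

print(features.shape)  # torch.Size([3, 2])
print(torch.equal(features, old_way))  # True
print(features.is_contiguous(), old_way.is_contiguous())  # True False
```

The transposed version is a non-contiguous view of the stacked tensor, while stacking along dim=1 writes the values directly in row-major (batch, feature) order.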

jmpanfil, Mar 10 '22 02:03

@jmpanfil how did your experimentation with parameters go?

Data-drone, May 25 '22 07:05