Add a ThreadPool which respects the order of Parquet dataset pieces.
This PR offers a solution for #551, where the standard ThreadPool implementation can return dataset pieces out of order.
Contributions
- An `OrderedThreadPool` implementation, which internally keeps track of results and the piece indexes returned by the ventilator, and only returns pieces in exact order (see the sketch after this list).
- `OrderedVentilatedItemProcessedMessage`, which reports the dataset piece index to the `OrderedThreadPool`.
- `OrderedWorkerThread`. The only difference between this class and the original `WorkerThread` is that it returns indexed `OrderedVentilatedItemProcessedMessage` objects.
- Updates to `make_reader` and `make_batch_reader` to allow for instantiation of `OrderedThreadPool` objects from a string with `reader_pool_type="orderedthread"`.
- Updates to the docstrings of `make_reader` and `make_batch_reader` to include the ordered option.
- Dataset order testing in `test_parquet_reader.py`.
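To illustrate the reordering idea, here is a minimal, self-contained sketch of buffering out-of-order results by piece index and releasing them strictly in order. The class and attribute names (`OrderedResultBuffer`, `_pending`, `_next_index`) are hypothetical and are not the actual `OrderedThreadPool` internals; the real pool additionally handles the ventilator, worker threads, and queue plumbing.

```python
class OrderedResultBuffer(object):
    """Minimal sketch: release worker results strictly in piece-index order.

    This only illustrates the reordering idea; it is not the actual
    OrderedThreadPool implementation from this PR.
    """

    def __init__(self):
        self._pending = {}      # piece index -> result, for results that arrived early
        self._next_index = 0    # index of the next piece we are allowed to emit

    def add(self, piece_index, result):
        """Store a result that a worker produced for the given piece index."""
        self._pending[piece_index] = result

    def pop_ready(self):
        """Yield buffered results in exact piece order, stopping at the first gap."""
        while self._next_index in self._pending:
            yield self._pending.pop(self._next_index)
            self._next_index += 1


# Example: results arrive out of order but are emitted in order.
buf = OrderedResultBuffer()
buf.add(1, 'piece-1')
print(list(buf.pop_ready()))  # [] -- piece 0 has not arrived yet
buf.add(0, 'piece-0')
print(list(buf.pop_ready()))  # ['piece-0', 'piece-1']
```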
Worked Example
We provide a modified version of the minimal code example in #551, which can be used to verify the solution.
```python
import os
import pathlib

import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType

from petastorm import make_reader
from petastorm.codecs import ScalarCodec
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row

output_directory = pathlib.Path('./_generated_demo_data')
output_url = output_directory.resolve().as_uri()

session_builder = SparkSession \
    .builder \
    .appName('Demo')
spark = session_builder.getOrCreate()
sc = spark.sparkContext

schema = Unischema('DemoSchema', [
    UnischemaField('timestamp', np.uint64, (), ScalarCodec(LongType()), False),
])

if not os.path.exists(output_directory):
    # Generate a petastorm dataset with timestamps in order
    with materialize_dataset(spark, output_url, schema, row_group_size_mb=1):
        generator = enumerate(range(1000000))
        rows_dict_generator = map(lambda x: {'timestamp': x[0]}, generator)
        rows_spark_generator = map(lambda x: dict_to_spark_row(schema, x), rows_dict_generator)
        rows_rdd = sc.parallelize(rows_spark_generator)
        spark.createDataFrame(rows_rdd, schema.as_spark_schema()) \
            .coalesce(1) \
            .write \
            .mode('overwrite') \
            .parquet(output_url)

# Read the generated petastorm dataset and check timestamp ordering
last_timestamp = -float("inf")
with make_reader(output_url,
                 schema_fields=['timestamp'],
                 reader_pool_type="orderedthread",
                 shuffle_row_groups=False) as reader:
    for row in reader:
        # ensure timestamp ordering (a decrease would also indicate an unexpected num_epochs wrap-around)
        if row.timestamp < last_timestamp:
            raise Exception('Timestamps in petastorm are not in order!')
        last_timestamp = row.timestamp
```
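The same `reader_pool_type="orderedthread"` string is also accepted by `make_batch_reader`, per the contributions above. Below is a minimal sketch of the batch-reader path, assuming a plain (non-Petastorm) Parquet dataset with a `timestamp` column at a hypothetical `parquet_url`; since `make_batch_reader` yields column batches rather than single rows, ordering is checked within and across batches.

```python
import numpy as np
from petastorm import make_batch_reader

# Sketch only: parquet_url is a hypothetical URL to a plain Parquet dataset
# with a 'timestamp' column.
last_timestamp = -float("inf")
with make_batch_reader(parquet_url,
                       reader_pool_type="orderedthread",
                       shuffle_row_groups=False) as reader:
    for batch in reader:
        timestamps = np.asarray(batch.timestamp)
        # The batch must start at or after the last seen timestamp and be sorted internally.
        if timestamps[0] < last_timestamp or not np.array_equal(timestamps, np.sort(timestamps)):
            raise Exception('Timestamps are not in order!')
        last_timestamp = timestamps[-1]
```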
As someone who needs repeatability, I'd really like to see this merged. Do you have any insight into the performance implications? Clearly this should be faster than setting `workers_count=1` and slower than using the current unordered ThreadPool.
I would assume that any performance differences would only be noticeable for the first few items, and once all queues are filled and model training becomes the bottleneck there should be almost no difference compared to the unordered ThreadPool. If that is correct, then I think this should actually become the default.
The additional overhead seems to be functionally negligible for most practical use cases where you're model-bottlenecked. @jrauch-pros, you are correct that there is a performance cost at initialization; see below:
```python
import os
import shutil
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from petastorm import make_batch_reader
from petastorm.tests.test_common import create_test_scalar_dataset

sns.set()

tmp_pq = "/tmp/tmp.parquet"
url = "file://" + tmp_pq
file_counts = [2, 5, 10, 20]

results = []
for num_files in file_counts:
    if os.path.exists(tmp_pq):
        shutil.rmtree(tmp_pq)
    _ = create_test_scalar_dataset(url, max(file_counts), num_files=num_files, partition_by=['id'])
    for pool_type in ["thread", "orderedthread"]:
        times = []
        reader = make_batch_reader(url, reader_pool_type=pool_type)
        times.append(time.time())
        for row in reader:
            times.append(time.time())
        times = np.asarray(times)
        durs = times[1:] - times[:-1]
        for row, d in enumerate(durs):
            results.append(dict(num_files=str(num_files), pool_type=pool_type, row=row, time=d))
shutil.rmtree(tmp_pq)

results = pd.DataFrame(results)
f, ax = plt.subplots()
_ = sns.lineplot(results, x="row", y="time", hue="num_files", style="pool_type", ax=ax)
plt.yscale('log')
plt.legend(loc='upper right')
f.savefig("ordered_thread_pool_performance.pdf")
```