Performance Comparison between native AWSSDK and FSSpec (boto3) based DataPipes
🐛 Describe the bug
After `AWSSDK` is integrated with TorchData, we now have two categories of `DataPipe`s to access and load data from an AWS S3 bucket:

- `DataPipe` using `fsspec`: it relies on the `s3fs` module to list/load data from an S3 bucket.
- `DataPipe` using `AWSSDK`: it relies on pybind bindings to the `AWSSDK_CPP` module.

I want to carry out a performance comparison of the `Lister` and `Opener`/`Loader` `DataPipe`s between these two approaches.

- For the `Lister`s, I used the same root path of `"s3://ai2-public-datasets/charades"` and validated that they returned the same values during iteration.
Testing script
```python
import numpy as np
import timeit

s3_path = "s3://ai2-public-datasets/charades"


def s3_fl_time():
    SETUP_CODE = """
from torchdata.datapipes.iter import IterableWrapper, S3FileLister
from __main__ import s3_path
dp = S3FileLister(IterableWrapper([s3_path]), region="us-west-2")
"""
    TEST_CODE = """
_ = list(dp)
"""
    times = timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=10, number=5)
    print(f"S3FileLister: Mean({np.average(times)}), STD({np.std(times)})")


def fsspec_fl_time():
    SETUP_CODE = """
from torchdata.datapipes.iter import IterableWrapper, FSSpecFileLister
from __main__ import s3_path
dp = FSSpecFileLister(IterableWrapper([s3_path]), anon=True)
"""
    TEST_CODE = """
_ = list(dp)
"""
    times = timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=10, number=5)
    print(f"FSSpecFileLister: Mean({np.average(times)}), STD({np.std(times)})")


if __name__ == "__main__":
    s3_fl_time()
    fsspec_fl_time()
```
And the result is:

```
S3FileLister: Mean(1.7595681754999994), STD(0.20364943594288445)
FSSpecFileLister: Mean(0.19180457339999962), STD(0.5630912985701465)
```

`FSSpecFileLister` performs about 10x better than `S3FileLister`.
- Due to the different behaviors of `S3FileLoader` and `FSSpecFileOpener`, besides iterating over these two `DataPipe`s, I also carried out an extra experiment that adds a `read` of the file returned by these `DataPipe`s. I only used two datasets hosted on the S3 bucket for testing, simply to save time running the tests.
Testing script
```python
import numpy as np
import timeit

s3_file_path = [
    "s3://ai2-public-datasets/charades/Charades.zip",
    "s3://ai2-public-datasets/charades/CharadesEgo.zip",
]


def s3_fo_time():
    SETUP_CODE = """
from torchdata.datapipes.iter import IterableWrapper, S3FileLister, S3FileLoader
from __main__ import s3_file_path
dp = S3FileLoader(S3FileLister(IterableWrapper(s3_file_path), region="us-west-2"), region="us-west-2")
"""
    TEST_CODE = """
_ = list(dp)
"""
    times = timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=10, number=5)
    print(f"S3FileLoader: Mean({np.average(times)}), STD({np.std(times)})")


def fsspec_fo_time():
    SETUP_CODE = """
from torchdata.datapipes.iter import IterableWrapper, FSSpecFileLister, FSSpecFileOpener
from __main__ import s3_file_path
dp = FSSpecFileOpener(FSSpecFileLister(IterableWrapper(s3_file_path), anon=True), mode="rb", anon=True)
"""
    TEST_CODE = """
_ = list(dp)
"""
    times = timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=10, number=5)
    print(f"FSSpecFileOpener: Mean({np.average(times)}), STD({np.std(times)})")


def s3_fo_read_time():
    SETUP_CODE = """
from torchdata.datapipes.iter import IterableWrapper, S3FileLister, S3FileLoader
from __main__ import s3_file_path
dp = S3FileLoader(S3FileLister(IterableWrapper(s3_file_path), region="us-west-2"), region="us-west-2").map(lambda x: x.read(), input_col=1)
"""
    TEST_CODE = """
_ = list(dp)
"""
    times = timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=10, number=5)
    print(f"S3FileLoader: Mean({np.average(times)}), STD({np.std(times)})")


def fsspec_fo_read_time():
    SETUP_CODE = """
from torchdata.datapipes.iter import IterableWrapper, FSSpecFileLister, FSSpecFileOpener
from __main__ import s3_file_path
dp = FSSpecFileOpener(FSSpecFileLister(IterableWrapper(s3_file_path), anon=True), mode="rb", anon=True).map(lambda x: x.read(), input_col=1)
"""
    TEST_CODE = """
_ = list(dp)
"""
    times = timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=10, number=5)
    print(f"FSSpecFileOpener: Mean({np.average(times)}), STD({np.std(times)})")


if __name__ == "__main__":
    s3_fo_time()
    fsspec_fo_time()
    s3_fo_read_time()
    fsspec_fo_read_time()
```
And the result is:

```
# Without `read`
S3FileLoader: Mean(23.793047750200007), STD(5.782844565863793)
FSSpecFileOpener: Mean(2.461926894699997), STD(0.34594020726696345)

# With `read`
S3FileLoader: Mean(31.570115949799998), STD(5.767492995195747)
FSSpecFileOpener: Mean(25.325279079399998), STD(5.052614560529884)
```
Comparing the results without `read`, I believe `S3FileLoader` triggers loading the data during iteration, while `FSSpecFileOpener` doesn't read any data from the remote store. So it makes more sense to compare these two `DataPipe`s with the `read` operation attached. Even then, `FSSpecFileOpener` still beats `S3FileLoader` by about 25% performance-wise.
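To illustrate the laziness hypothesis above, here is a minimal sketch (not part of the benchmark) that assumes `FSSpecFileOpener` yields `(url, stream)` pairs and that no bytes are fetched until `read` is called on the stream:

```python
# Minimal sketch: iterating FSSpecFileOpener is assumed to only open handles;
# the actual network read happens when .read() is called on the stream.
from torchdata.datapipes.iter import IterableWrapper, FSSpecFileOpener

dp = FSSpecFileOpener(
    IterableWrapper(["s3://ai2-public-datasets/charades/Charades.zip"]),
    mode="rb",
    anon=True,
)
for url, stream in dp:       # cheap: yields (url, file-like stream)
    header = stream.read(4)  # bytes are actually fetched here
    print(url, header)
```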
Due to this performance regression with `AWSSDK`, it becomes hard for me to recommend that users use the native `S3FileLister` or `S3FileLoader`.
cc: @ydaiming
Versions
main branch
I only executed these scripts on my Mac, as our AWS cluster doesn't allow me to access S3.
Thanks @NivekT and @ejguan for the context in https://github.com/pytorch/data/pull/847!
I've run some preliminary benchmarks on our datasets comparing s3io (aka s3-plugin - the pybinded c++ client) versus fsspec. TL;DR - at a high enough batch_size + num_workers (dataloader workers), the throughput is comparable (although s3io is still ~16% faster) at ~2.4M samples/sec versus ~2.0M samples/sec. Where the difference really shows is when you strip away all the parallelism gimmicks; in that case s3io is ~2x faster than fsspec.
Below are the benchmark results:
Experiment Parameters:
EDIT: Note that the qps accounts for cross-region reads - our datasets are located in us-east-1 and my ec2 desktop is in us-west-2 so we can't directly compare throughput with the benchmark above.
- Metric is qps (aka "lines/second", e.g. `batch_size=32` has `32` lines).
- Each line is about 1.6kb in size (so 100k lines ~ 1.6kb * 100k / 1024 ~ 156MB).
- To make the results comparable, I ran the benchmark on smaller-sized shards (each shard is 156MB of json text data).
- Measurements are taken over 1000 batches.
- `num_workers` and `batch_size` are passed directly to the `DataLoader(..., num_workers, batch_size)` (e.g. `num_workers=0` will run the benchmark on the main process).
- Measures pure line reads (e.g. `list().load().readlines()`); no other transforms are applied on the datapipe (no shuffling, no sampling). See the sketch right after this list.
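The full benchmark script is posted further down in this thread; for context, here is a minimal sketch of the datapipe chain being timed (the shard URL is a placeholder):

```python
# Sketch of the measured chain: open a shard and read lines, with no shuffling or
# other transforms. "<BUCKET>/<SHARD>.json" is a placeholder, not a real path.
from torchdata.datapipes.iter import IterableWrapper

shards = IterableWrapper(["s3://<BUCKET>/<SHARD>.json"])

s3io_lines = shards.load_files_by_s3().readlines(return_path=False)        # s3-plugin
fsspec_lines = shards.open_files_by_fsspec().readlines(return_path=False)  # fsspec
```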
Notes:
- To look at the most vanilla throughput, take a look at the lines with `num_workers=0`; this does away with any dataloader multiprocessing variances.
- Sometimes the p0 qps is super low; this happens on the shard boundaries since we do not prefetch.
- One interesting observation that I haven't had time to dig into (will update once I do):
  - when `num_workers > 0` there seems to be some multiprocessing interference, since larger `batch_size`s give me better qps.
  - you can also see this where the qps for `num_workers = 0` is better than for `num_workers = 1` for each `batch_size`.
```
s3pipe num workers batch size avg p0 p25 p50 p75 p100
------ ----------- ---------- ------- ------- ------- ------- ------- -------
s3io 0 16 201274 7 196969 202350 212648 225941
s3io 0 32 232753 15 227345 233010 245591 262689
s3io 0 64 266939 31 259390 265831 282259 295217
s3io 0 128 266902 64 258909 265598 278536 297637
s3io 0 256 273715 123 266704 272834 283543 300065
s3io 0 512 276921 243 274824 279354 287159 302561
s3io 0 1024 277564 370 279296 281789 283976 293395
s3io 1 16 120254 7 77926 138233 151336 166989
s3io 1 32 174170 11 99048 176912 253507 311967
s3io 1 64 243339 11 115773 211886 380077 487007
s3io 1 128 314150 55 123450 178838 516737 699358
s3io 1 256 324263 97 131158 186659 510828 659515
s3io 1 512 324929 252 130781 231320 584660 736449
s3io 1 1024 352227 496 133334 249035 590060 796263
s3io 1 2048 319459 973 151944 224123 510122 642646
s3io 2 16 122463 7 117732 123926 126787 503288
s3io 2 32 226006 15 206407 233157 253373 477570
s3io 2 64 301502 29 257777 294349 344868 726529
s3io 2 128 343364 48 284174 335404 404285 1100989
s3io 2 256 503541 118 250606 325955 560916 2983579
s3io 2 512 796450 210 204987 297906 560210 4305015
s3io 2 1024 377733 324 262619 326772 498429 3525406
s3io 4 16 159269 3 123606 134514 143295 523131
s3io 4 32 223498 14 150992 192741 255065 856051
s3io 4 64 463409 34 231618 323635 626472 1638626
s3io 4 128 655444 62 302806 418001 524173 1683079
s3io 4 256 1499592 125 243778 1747684 2591447 3508770
s3io 4 512 2207271 230 222010 2612782 3453065 4408318
s3io 4 1024 2229304 77 143987 3108584 3777646 4344818
s3io 8 16 235531 2 90224 173068 427933 578910
s3io 8 32 420268 5 153782 293532 774494 1219057
s3io 8 64 924232 28 405442 976360 1510673 2001937
s3io 8 128 1071561 63 484375 1119832 1513741 2549037
s3io 8 256 1759111 125 300977 1843266 3137106 3718813
s3io 8 512 2025294 251 235203 2432248 3463428 4285025
s3io 8 1024 2498502 463 1982913 3019237 3426967 4480638
s3io 8 2048 2572884 889 403433 3143241 3757130 5435042
s3io 16 16 247936 4 91293 193142 414299 672863
s3io 16 32 471964 4 163008 336697 817840 1143923
s3io 16 64 964213 7 429131 982190 1528981 1967771
s3io 16 128 1389164 26 651107 1403767 2260536 2703389
s3io 16 256 1836468 36 860716 2230178 2841971 3483237
s3io 16 512 2122104 108 1508325 2182173 3105082 3995848
s3io 16 1024 2228400 222 1439927 2699041 2982076 4068803
s3io 32 16 274790 1 94686 298580 446279 628982
s3io 32 32 490992 2 191883 577720 761773 1023281
s3io 32 64 815386 3 355476 790754 1300752 1768443
s3io 32 128 1277918 6 592099 1587933 1906805 2409413
s3io 32 256 1843344 15 1076507 2235011 2625143 3213584
s3io 32 512 2253216 26 1279235 2625880 3289343 4376035
s3io 32 1024 2460372 45 1833763 2686775 3350917 4450608
fsspec 0 16 96282 15 87635 96552 105414 149066
fsspec 0 32 102321 105 95495 103008 110594 157072
fsspec 0 64 104073 196 99672 105508 111080 224916
fsspec 0 128 104947 373 104491 108868 113347 223996
fsspec 0 256 103895 326 108626 112417 115900 192087
fsspec 0 512 95430 599 109113 113021 115755 153441
fsspec 0 1024 76661 971 5832 109937 112467 130060
fsspec 1 16 64368 14 37367 61772 91475 120297
fsspec 1 32 78291 33 38756 56498 121372 171137
fsspec 1 64 103375 62 43704 60657 165441 363959
fsspec 1 128 102007 115 45662 56697 163913 247354
fsspec 1 256 113207 212 47100 62771 193018 253203
fsspec 1 512 93298 476 47287 62359 159802 255048
fsspec 1 1024 73394 837 5976 67896 95976 236039
fsspec 2 16 87075 6 52098 63849 79879 292345
fsspec 2 32 117182 27 84631 106916 133593 827686
fsspec 2 64 137213 58 95813 122486 158137 1250787
fsspec 2 128 224531 105 72120 119822 178552 1826067
fsspec 2 256 604227 233 72172 155979 1265114 3056899
fsspec 2 512 652140 517 74296 147225 248449 3637834
fsspec 2 1024 855907 936 47756 113607 2127384 3942269
fsspec 4 16 119951 13 81253 96282 113308 461415
fsspec 4 32 214074 32 110961 139161 262108 881935
fsspec 4 64 417155 53 135571 180890 632781 1785408
fsspec 4 128 782010 119 115999 208009 1520354 2786254
fsspec 4 256 1185332 215 118982 1185783 2246033 3385801
fsspec 4 512 1381044 453 92960 1703973 2329940 3694807
fsspec 4 1024 1443561 902 80505 1705301 2652285 4115012
fsspec 8 16 169676 14 70628 152887 232681 522824
fsspec 8 32 411001 24 104540 307435 733555 1180200
fsspec 8 64 817873 57 209504 891595 1374016 1942106
fsspec 8 128 953035 101 155030 720554 1796774 2678495
fsspec 8 256 1475653 241 210541 1469572 2519408 3185707
fsspec 8 512 1639553 442 1111797 1958786 2205231 3638460
fsspec 8 1024 1753882 888 231907 2054775 2574985 3914674
fsspec 16 16 206221 14 106862 162568 292215 541439
fsspec 16 32 450304 27 201878 355513 751119 1150290
fsspec 16 64 544615 55 326337 474516 695706 1589386
fsspec 16 128 1294437 105 554374 1531669 2027241 2578979
fsspec 16 256 1590044 196 881182 1667147 2424060 3155427
fsspec 16 512 1845513 354 1462017 1959899 2503906 3580742
fsspec 16 1024 2029104 844 1936001 2407187 2608265 3723242
fsspec 32 16 261293 10 105641 206391 443640 631938
fsspec 32 32 452097 23 166063 402936 743786 1144495
fsspec 32 64 771898 44 330367 686847 1275442 1634991
fsspec 32 128 1244121 90 628953 1527951 1857241 2426118
fsspec 32 256 1351322 186 782808 1177364 2159871 3052088
fsspec 32 512 1620434 362 1184810 1697163 2221502 3799634
fsspec 32 1024 2005850 648 1327637 2260161 2865855 4136722
```
@kiukchung Thank you so much for helping us benchmark the text use case! This seems contradictory to the benchmarking result we carried out previously (I am focusing more on the `num_workers=0` case, because I want to make sure the baseline works as expected). As @NivekT suggested, we did find `fsspec` having better performance when loading archives of images, even though the size of each archive (shard) is similar to yours at around 100MB.

I have a noob question on the benchmarking settings. Does 1000 batches mean you would only read data from a single shard, since a shard has 100k lines, for smaller batch sizes like 16, 32 and 64? And, since only one shard has been read, the low P0 value seems weird to me.

Do you mind helping us test these two implementations with higher `num_workers`, say 48 and 64? I suspect the S3-plugin would hit a memory bound earlier than `fsspec`.
@ejguan thanks for taking a look and your insights/questions. I’m on PTO today so will update with the answers to your questions on Mon/Tue. Will also clean up the benchmark code and post it here so that you can also run it on your end (I need to remove some dependencies to our internal tools that read the dataset manifest containing the s3 URLs of the shards).
RE: P0 being super low. I should've been clearer; yes, you are correct that for low batch sizes the benchmark will only read one shard. However, the first source datapipe in my chain is a "ListShardsFromManifest", a custom iter datapipe that I implemented, which simply queries the manifest file (a json file in s3) given the dataset name, branch, and region. I believe the low P0 qps comes from the fact that to read the first batch, we first read the manifest (a list + read s3 operation). The manifest file itself is pretty small (no more than 100kb), so most of that latency comes from making those two s3 requests cross-region (datasets are in us-east-1 and my desktop is in us-west-2). I'll try to run the benchmarks in the same region to see if that improves the P0 numbers.
Sorry it took me longer than expected. Here's the benchmarking script:
```python
#!/usr/bin/env python3
"""
requirements.txt:
tabulate
hydra-core
omegaconf
torchdata
torch
"""
import statistics
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf
from tabulate import tabulate
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper


class S3PIPE(str, Enum):
    s3io = "s3io"  # s3-plugin (pybinded c++ client) https://github.com/pytorch/data/blob/main/torchdata/datapipes/iter/load/s3io.py
    fsspec = "fsspec"  # fsspec (aioboto3) https://github.com/pytorch/data/blob/main/torchdata/datapipes/iter/load/fsspec.py


@dataclass
class Config:
    basedir_url: str = "s3://<YOUR S3 DIR>"
    files: List[str] = field(default_factory=lambda: ["<FILE>", "<FILE>", "..."])
    # number of batches to sample for metrics
    num_batches: int = 100
    batch_sizes: List[int] = field(default_factory=lambda: [16, 32, 64, 128, 256])
    num_workers: List[int] = field(
        default_factory=lambda: [0, 1, 2, 4, 8, 16, 32, 48, 64]
    )
    s3pipes: List[S3PIPE] = field(default_factory=lambda: [S3PIPE.s3io, S3PIPE.fsspec])


HEADER = [
    "s3pipe",
    "# workers",
    "batch size",
    "warmup time (ms)",
    "avg (time \u03BCs)",
    "avg (qps)",
    "p0",
    "p25",
    "p50",
    "p75",
    "p100",
]


@hydra.main(version_base=None, config_name="config")
def main(cfg: Config) -> None:
    print(f"Loaded Config:\n-----------\n{OmegaConf.to_yaml(cfg)}")
    print("Measuring samples/second (qps). Starting benchmarks...\n")

    total_benchmarks = len(cfg.num_workers) * len(cfg.batch_sizes) * len(cfg.s3pipes)
    benchmark_idx = 0
    for num_workers in cfg.num_workers:
        for batch_size in cfg.batch_sizes:
            table: List[List[Any]] = []
            for s3pipe in cfg.s3pipes:
                print(
                    f"Running benchmark [{benchmark_idx:0d}/{total_benchmarks}]:"
                    f" s3pipe: {s3pipe:6s} num_workers: {num_workers:2d} batch_size: {batch_size:3d}"
                )
                run_benchmark(s3pipe, batch_size, num_workers, cfg, table)
                benchmark_idx += 1
            print_table(table)


def print_table(table):
    print(tabulate(table, headers=HEADER, stralign="left", floatfmt="8,.0f") + "\n")


def run_benchmark(
    s3pipe: S3PIPE,
    batch_size: int,
    num_workers: int,
    cfg: Config,
    table: List[List[Any]],
):
    s3urls = [f"{cfg.basedir_url}/{f}" for f in cfg.files]
    num_batches = cfg.num_batches

    s3_shard_urls = IterableWrapper(s3urls).cycle()
    if s3pipe == S3PIPE.s3io:
        s3_shards = s3_shard_urls.load_files_by_s3()
    else:  # s3pipe == S3PIPE.fsspec
        s3_shards = s3_shard_urls.open_files_by_fsspec()
    dataset = s3_shards.readlines(return_path=False)

    num_processed_batches = 0
    warmup = True
    times_ns = []
    warmup_time_ns = -1.0

    start = time.perf_counter_ns()
    for _ in DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    ):
        end = time.perf_counter_ns()
        if warmup:
            warmup = False
            warmup_time_ns = end - start
        else:
            num_processed_batches += 1
            times_ns.append(end - start)
        start = end

        if num_processed_batches >= num_batches:
            assert len(times_ns) == num_batches
            break

    qps = [batch_size / (t / 1e9) for t in times_ns]
    qps_avg = statistics.mean(qps)
    qps_p25, qps_p50, qps_p75 = statistics.quantiles(qps, n=4)
    qps_min = min(qps)
    qps_max = max(qps)

    table.append(
        [
            s3pipe.name,
            num_workers,
            batch_size,
            warmup_time_ns / 1e6,  # milliseconds
            statistics.mean(times_ns) / 1e3,  # microseconds
            qps_avg,
            qps_min,
            qps_p25,
            qps_p50,
            qps_p75,
            qps_max,
        ]
    )


if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)
    main()
```
Did some more digging and here are some observations:

- s3io is technically faster, but you can get close with fsspec by using a larger (e.g. 16, 32) dataloader `num_workers`.
- fsspec tends to scale more linearly and predictably as `num_workers` increases (up to a certain point).
- fsspec is stable w.r.t. `batch_size` (e.g. same throughput regardless of different batch sizes).
- fsspec has a bug where it is not thread-safe and will crash when `num_workers` is large (technically anything above `2` should trigger the bug). Filed a bug report here: https://github.com/pytorch/data/issues/906
- s3io has large variance within a benchmark (constant `num_workers` and `batch_size`) due to it not really "reading ahead" the buffer and not using asyncio.
  - Once you hit the buffer boundary it's a blocking call to replenish the buffer. This causes a large variance, especially for small buffer sizes.
  - The max throughput of s3io observed empirically (on a p4d instance) is 400-500Mbps, so starting with a buffer size of 450MB is a good start for this type of host.
- For text files larger than 256MB, setting s3io's `use_multipart_upload=True` shows zero benefit.
- s3io doesn't scale linearly with `num_workers`. This makes sense since it uses a single blocking buffer.
- For pure downloads (`num_workers=0`) s3io maxes out at ~500Mbps versus fsspec's ~55Mbps. But, as mentioned above, there are other factors (size of the shard, buffer size, `num_workers`) that affect the overall qps of the dataloader.
So what is the bottom line?

- Use s3io for larger shards (anything above 512MB) and when you have more shards than `num_workers`. Use fsspec for everything else, as long as this bug is fixed: https://github.com/pytorch/data/issues/906
- When using s3io:
  - Don't set `num_workers` greater than `4` unless your datapipe's bottleneck is somewhere in the downstream pipe chain.
  - Set `buffer_size=k*shard_size` where `k` is some fraction (e.g. `0.5`, `0.75`) of the size of the shards in s3. This controls the number of blocking buffer refills (e.g. `0.5` will incur 1 blocking buffer refill). See the sketch after this list.
  - Shard the urls. For instance, if you have 16 shards in s3 and 4 dataloader workers, make sure you write a shardable iter-datapipe that can distribute 4 shards to each dataloader worker.
- When using fsspec:
  - Make sure to set `default_cache_type="readahead"`. This gives about a 30% throughput improvement over the default cache type `bytes`.
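For concreteness, a rough sketch of the two settings mentioned above. The parameter names (`buffer_size`, `default_cache_type`) are taken from this discussion and are assumptions about the exact torchdata/s3fs signatures, so treat this as illustrative rather than definitive:

```python
# Illustrative only: parameter names follow the recommendations above and are not
# verified against the exact torchdata/s3fs APIs. URLs are placeholders.
from torchdata.datapipes.iter import IterableWrapper

urls = IterableWrapper(["s3://<BUCKET>/<SHARD>.jsonl"])

# s3io: size the download buffer as a fraction k of the shard size to bound the
# number of blocking buffer refills (e.g. k=0.5 -> one refill per shard).
s3io_dp = urls.load_files_by_s3(buffer_size=256 * 1024 * 1024)

# fsspec: request the "readahead" cache type; keyword arguments are assumed to be
# forwarded to the underlying s3fs filesystem.
fsspec_dp = urls.open_files_by_fsspec(mode="rb", default_cache_type="readahead")
```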
@kiukchung This is amazing! Thank you for providing such a detailed benchmarking result and analysis.

> 5. s3io has large variance within a benchmark (constant `num_workers` and `batch_size`) due to it not really "reading-ahead" the buffer and no use of asyncio.

Even without `asyncio`, IMHO, `prefetch_factor` would help the perf by doing some kind of reading ahead. Have you tested enlarging `prefetch_factor`? If we find this helpful, we might do reading ahead directly within the `S3DataPipe`.
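For concreteness, one hypothetical way to try this (the shard URL is a placeholder and the numbers are arbitrary):

```python
# Hypothetical sketch: let each DataLoader worker keep more batches in flight via
# prefetch_factor (default is 2). prefetch_factor only applies when num_workers > 0.
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

dataset = (
    IterableWrapper(["s3://<BUCKET>/<SHARD>.jsonl"])  # placeholder shard URL
    .load_files_by_s3()
    .readlines(return_path=False)
)

loader = DataLoader(dataset, batch_size=256, num_workers=4, prefetch_factor=8)
for batch in loader:
    pass  # measure throughput here as in the benchmark above
```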
All the bottom lines you mentioned are super useful for users. They definitely deserve to be written into our documentation! cc: @NivekT
@kiukchung Thanks for looking into this and sharing the results! Your findings are very helpful and we should incorporate them into our doc. A few questions:

- Does `s3io` support streaming? One common use case may be a tar archive with JPGs inside, and `fsspec` allows reading the JPGs without downloading the whole archive first, which reduces the memory footprint. The code looks something like:

```python
dp.open_files_by_fsspec(mode="rb", anon=True).load_from_tar(mode="r|")
# Note that `mode="r|"` is for streaming
```

- You mentioned that users should keep `num_workers` less than 4 unless the bottleneck is elsewhere. Is there any drawback to using `s3io` with many workers? Maybe memory usage? I anticipate users may be using 8-12 workers for preprocessing/transformation of data.
> Does s3io support streaming? One common use case may be a tar archive with JPGs inside and fsspec allows reading of JPGs without downloading the whole archive first, which reduces memory footprint.

Haven't tried it on tarballs, but "block-streaming" works for text files. I'd assume that since tarballs can be stream-opened, this also works for tars, with a caveat (see the end of this paragraph). `s3io` downloads the contents into a buffer (the buffer size is configurable), so the memory overhead would be as big as the buffer. Unfortunately it does not do double buffering, so the dataloader workers block until the buffer is refilled once in a while. This means that many dataloader workers will not be able to "stream"-extract a tarball, so one would have to write a shardable datapipe that assigns the s3 url for a tar to a single dataloader worker instance (see the sketch below). This would work well in practice if you had a dataset comprising many tarballs, but wouldn't work well for a single large tarball.
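A hypothetical sketch of such a datapipe (the class below is illustrative, not an existing torchdata API): it uses `get_worker_info()` to assign whole shard/tar URLs to DataLoader workers round-robin, so each archive is streamed by exactly one worker.

```python
# Hypothetical helper, not an existing torchdata API: distribute whole shard URLs
# across DataLoader workers round-robin so each tarball is read by a single worker.
from torch.utils.data import get_worker_info
from torchdata.datapipes.iter import IterDataPipe


class ShardUrlsPerWorker(IterDataPipe):
    def __init__(self, source_datapipe: IterDataPipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        for i, url in enumerate(self.source_datapipe):
            if i % num_workers == worker_id:  # worker k gets urls k, k+num_workers, ...
                yield url
```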
Will put up a PR for an `S3Lister` datapipe that is shardable - I've got one for our internal use case, so I need to remove the internal deps for a clean PR that can be upstreamed to torchdata. Will mention this issue on the PR so that we have lineage.
> You mentioned that users should keep num_workers less than 4 unless the bottleneck is elsewhere. Is there any drawback to using s3io with many workers? Maybe memory usage? I anticipate users may be using 8-12 workers for preprocessing/transformation of data.
Yeah, so I was only benchmarking pure reads. In practice the dataloader bottleneck will be in the pipes towards the end of the chain that do a lot of data transforms (e.g. tokenization for NLP). For our use case we have most of the data pre-processed in S3, ready to be fed into the forward() method as soon as it is read (e.g. no pre-processing in the trainer), hence my recommendation of 4 workers.

I mentioned this above, but reiterating here - the `s3io` implementation relies on a single download buffer that is accessed by all dataloader workers, and once the buffer depletes it is a blocking call until the buffer is replenished. In my benchmarks I've found 4 workers to be the equilibrium - enough to keep the buffer moving but not enough to keep the workers waiting too frequently.
@NivekT @ejguan any progress in getting some of the best-practices I commented above documented in the torchdata docs page?
@kiukchung Not yet. Do you want to open a PR to add them? We are currently tied up working on certain features prior to the branch cut.