Filter on dataset too much slowww
I have a dataset with 50M rows. For pre-processing, I need to tokenize it and filter out rows with overly long sequences.
My tokenization took roughly 12 minutes using map() with batch size 1024 and multiprocessing with 96 processes.
When I apply the filter() function, it takes too much time. I need to filter sequences based on a boolean column.
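For reference, the tokenization step looks roughly like this (just a sketch; the tokenizer, column name, and length threshold are placeholders, not my actual setup):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model
def tokenize(batch):
    enc = tokenizer(batch["text"])  # placeholder text column
    # mark rows whose tokenized sequence is short enough to keep
    enc["flag"] = [len(ids) <= 512 for ids in enc["input_ids"]]
    return enc
dataset = dataset.map(tokenize, batched=True, batch_size=1024, num_proc=96)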
Below are the variants I tried.
- filter() with batch size 1024, single process (takes roughly 3 hr)
- filter() with batch size 1024, 96 processes (takes 5-6 hrs ¯\_(ツ)_/¯)
- filter() with loading all data in memory, only a single boolean column (never ends).
Can someone please help?
Below is sample code for a small dataset:
import random
from datasets import load_dataset
dataset = load_dataset('glue', 'mrpc', split='train')
# add a random boolean 'flag' column to filter on
dataset = dataset.map(lambda x: {'flag': random.randint(0, 1) == 1})

def _amplify(data):
    # with input_columns=['flag'], `data` is the flag value itself
    return data

dataset = dataset.filter(_amplify, batch_size=1024, keep_in_memory=False, input_columns=['flag'])
When I use filter on the arrow table directly, it works like butter. But I can't find a way to update the table in the Dataset object.
ds_table = dataset.data.filter(mask=dataset['flag'])
@thomwolf @lhoestq can you guys please take a look and recommend a solution?
Hi ! Currently the filter method reads the dataset batch by batch to write a new, filtered arrow file on disk. Therefore all the reading + writing can take some time. Using a mask directly on the arrow table doesn't do any read or write operation, therefore it's way quicker.
Replacing the old table by the new one should do the job:
dataset._data = dataset._data.filter(...)
Note: this is a workaround and in general users shouldn't have to do that. In particular, if you did some shuffle or select before, it would not work correctly since the indices mapping (index from __getitem__ -> index in the table) would not be valid anymore. But if you haven't done any shuffle, select, shard, train_test_split, etc., then it should work.
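For example, with the flag column from the snippet above, the workaround could look like this (just a sketch, assuming _data is a plain pyarrow table and no indices mapping exists):
import pyarrow as pa
# build a boolean mask from the "flag" column and filter the arrow table directly,
# then swap it in place of the old one (workaround only, see the caveats above)
mask = pa.array(dataset["flag"], type=pa.bool_())
dataset._data = dataset._data.filter(mask)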
Ideally it would be awesome to update the filter function to allow masking this way ! If you would like to give it a shot I will be happy to help :)
Yes, would be happy to contribute. Thanks
Hi @lhoestq @ayubSubhaniya,
If there's no progress on this one, can I try working on it?
Thanks, Gunjan
Sure @gchhablani, feel free to start working on it, this would be very appreciated :) This feature would be really awesome, especially since Arrow allows masking really quickly and without having to rewrite the dataset on disk.
Hi @lhoestq, any updates on this issue? The filter method is still veryyy slow 😕
No update so far, we haven't worked on this yet :/
Though PyArrow is much more stable than 3 years ago so it would be a good time to dive into this
Hi @lhoestq, thanks a lot for the update!
I would like to work on this (if possible). Could you please give me some pointers on how I should approach this? Also, any references would be great!
I just played a bit with it to make sure using table.filter() is fine, but actually it seems to create a new table in memory :/ This is an issue since it can quickly fill the RAM, and the role of datasets is to make sure you can load bigger-than-memory datasets. Therefore I don't think it's a good idea in the end to use table.filter().
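You can see the in-memory copy by checking pyarrow's allocated bytes around the call (rough illustration, reusing the same call as in the OP):
import pyarrow as pa
before = pa.total_allocated_bytes()
new_table = dataset.data.filter(mask=dataset["flag"])
# the difference is roughly the size of the filtered table, i.e. it lives in RAM
print("extra bytes allocated:", pa.total_allocated_bytes() - before)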
Anyway, I just ran OP's code and it runs in 20 ms now on my side thanks to the I/O optimizations we did.
Another way to speed up filter would be to add support for pyarrow expressions, using e.g. arrow formatting + dataset.filter (runs in 10 ms on my side):
import pyarrow.dataset as pds
import pyarrow.compute as pc
# evaluate the expression on each arrow-formatted batch and return the boolean mask
expr = pc.field("flag") == True
filtered = dataset.with_format("arrow").filter(
    lambda t: pds.dataset(t).to_table(columns={"mask": expr})[0].to_numpy(),
    batched=True,
).with_format(None)
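And you can sanity-check the result against the plain Python-level filter (illustrative only):
# the expression-based filter keeps exactly the rows where "flag" is True,
# same as the plain (slower) Python-level filter
baseline = dataset.filter(lambda x: x["flag"])
assert filtered["flag"] == baseline["flag"]
print(f"{len(dataset)} rows -> {len(filtered)} rows")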