xbatcher
Support for valid examples
Is your feature request related to a problem?
There is currently no support for serving batches that satisfy some validity criteria. It would be nice to filter batches based on criteria such as:
- does an example contain a valid value in a target variable?
- does an example contain a valid value at the center of a target variable?
Consider this dataset:
import numpy as np
import xarray as xr
import xbatcher

w = 100
da = xr.DataArray(np.random.rand(2, w, w), name='foo', dims=['variable', 'y', 'x'])
# simulate 10% sparse, expensive target data
percent_nans = .90
number_nans = (w ** 2) * percent_nans
da[0] = xr.where(da[1] < .1, da[1], np.nan)
bgen = xbatcher.BatchGenerator(
    da,
    {'variable': 2, 'x': 10, 'y': 10},
    input_overlap={'x': 0, 'y': 0},
    batch_dims={'x': 100, 'y': 100},
    concat_input_dims=True
)
for batch in bgen:
    pass
If we are serving this to a machine learning process and only care about locations where we have target data, many of these examples will not be valid, i.e. there will be no target value to use for training.
Describe the solution you'd like
I would like to see something like:
w = 100
da = xr.DataArray(np.random.rand(2, w, w), name='foo', dims=['variable', 'y', 'x'])
# simulate 10% sparse, expensive target data
percent_nans = .90
number_nans = (w ** 2) * percent_nans
da[0] = xr.where(da[1] < .1, da[1], np.nan)
bgen = xbatcher.BatchGenerator(
    da,
    {'variable': 2, 'x': 10, 'y': 10},
    input_overlap={'x': 0, 'y': 0},
    batch_dims={'x': 100, 'y': 100},
    concat_input_dims=True,
    valid_example=lambda x: ~np.isnan(x[0][5, 5])  # proposed parameter, not yet in xbatcher
)
for batch in bgen:
    pass
so that every batch satisfies: np.all(~np.isnan(batch[:, 0, 5, 5]))
Describe alternatives you've considered
see: https://discourse.pangeo.io/t/efficiently-slicing-random-windows-for-reduced-xarray-dataset/2447
I typically extract all valid "chips" or "patches" in advance and persist them as a "training dataset" to get all the computation out of the way. The dims would look something like {'i': number of valid chips, 'variable': 2, 'x': 10, 'y': 10}. I could then simply use xbatcher to batch on the i dimension.
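A minimal sketch of building such a training dataset in advance, assuming a (variable, y, x) array like the one above, non-overlapping 10x10 chips, and "target (variable 0) is valid at the chip centre" as the criterion. The helper extract_valid_chips and the dimension name i are illustrative assumptions, not an xbatcher API, and edge chips are not handled (y and x must be multiples of the chip size).

```python
import numpy as np
import xarray as xr


def extract_valid_chips(arr, chip=10):
    """Cut a (variable, y, x) DataArray into non-overlapping chip x chip
    windows and keep only the windows whose target (variable 0) is valid
    at the centre pixel. Returns a DataArray with dims (i, variable, y, x)."""
    nvar, ny, nx = arr.shape
    # Assumption: y and x are exact multiples of `chip` (no edge handling).
    ny_c, nx_c = ny // chip, nx // chip
    # (variable, y_chip, y, x_chip, x) -> (y_chip, x_chip, variable, y, x)
    blocks = arr.values.reshape(nvar, ny_c, chip, nx_c, chip).transpose(1, 3, 0, 2, 4)
    chips = blocks.reshape(ny_c * nx_c, nvar, chip, chip)
    centre = chip // 2
    # The validity predicate, evaluated for all chips at once.
    valid = ~np.isnan(chips[:, 0, centre, centre])
    return xr.DataArray(chips[valid], dims=["i", "variable", "y", "x"], name=arr.name)


w = 100
rng = np.random.default_rng(0)
data = rng.random((2, w, w))
# Roughly 10% of target pixels are valid, as in the example above.
data[0] = np.where(data[1] < 0.1, data[1], np.nan)
arr = xr.DataArray(data, name="foo", dims=["variable", "y", "x"])

training = extract_valid_chips(arr)
print(dict(training.sizes))
```

Every entry along i is then a usable example, so batching reduces to slicing along i (e.g. training.isel(i=slice(0, 32))).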
Additional context
No response
Agree that there should be a way to filter out invalid values. There's a newer duplicate issue at #162 on having a predicate function (I had to look up https://dcl-prog.stanford.edu/function-predicate.html to learn that predicate functions are those that return True or False, i.e. a boolean), similar to the valid_example parameter you are proposing here, but I'll post here on a first-come, first-served basis.
At https://github.com/xarray-contrib/xbatcher/issues/162#issuecomment-1431902345, @cmdupuis3 showed this example code snippet:
Better code sample, which wraps xbatcher and also offers fixed batch sizes:
bgen = xb.BatchGenerator(ds, {'d1': 5, 'd2': 5}, {'d1': 2, 'd2': 2})

def my_gen2(bgen, batch_size=5, predicate=None):
    b = (batch for batch in bgen)
    n = 0
    batch_stack = []
    while n < 400:  # hardcoded n is a kludge; while-loop is necessary
        this_batch = next(b)
        if not predicate or predicate(this_batch):
            batch_stack.append(this_batch)
            n += 1
        else:
            n += 1
            continue
        if len(batch_stack) == batch_size:
            yield xr.concat(batch_stack, 'sample')
            batch_stack = []
This code can be summarized as 3 main steps:
- Use xbatcher.BatchGenerator to generate the chips/patches
- Filter out invalid values based on a predicate True/False condition
- Use xr.concat to create the stack of tensors
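The three steps can be written as one generator without the hardcoded counter. This is a minimal sketch, not the quoted author's code: filtered_batches is a hypothetical helper, the examples are assumed to be DataArrays (e.g. what iterating over an xbatcher.BatchGenerator yields), and a trailing partial batch is dropped so that every batch has a fixed size.

```python
import itertools

import numpy as np
import xarray as xr


def filtered_batches(examples, batch_size, predicate=None, dim="sample"):
    """Yield fixed-size batches from an iterable of DataArray examples.

    Step 1 (slicing) is whatever produced `examples`. Step 2 keeps only the
    examples for which `predicate` returns True. Step 3 stacks `batch_size`
    of them along a new `dim` with xr.concat.
    """
    # Step 2: lazily drop invalid examples (no predicate -> keep everything).
    valid = iter(examples) if predicate is None else filter(predicate, examples)
    while True:
        # Pull the next `batch_size` valid examples from the filtered stream.
        chunk = list(itertools.islice(valid, batch_size))
        if len(chunk) < batch_size:
            # Source exhausted: stop, dropping any trailing partial batch.
            return
        # Step 3: stack the examples along a new leading dimension.
        yield xr.concat(chunk, dim=dim)


# Stand-in for BatchGenerator output: 20 (variable, y, x) examples whose
# target centre pixel is NaN for every odd-numbered example.
examples = []
for k in range(20):
    data = np.ones((2, 10, 10))
    if k % 2:
        data[0, 5, 5] = np.nan
    examples.append(xr.DataArray(data, dims=["variable", "y", "x"]))


def centre_is_valid(ex):
    return bool(ex[0, 5, 5].notnull())


batches = list(filtered_batches(examples, batch_size=4, predicate=centre_is_valid))
```

Consuming the filtered stream with itertools.islice ends cleanly when the source runs out. In the snippet above, a bare next() inside a generator turns source exhaustion into a RuntimeError on Python 3.7+ (PEP 479), which is one reason the hardcoded n was needed.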
The fact that someone has to concat the tensors together after having already used BatchGenerator (which, according to its name, should be for generating batches) indicates that BatchGenerator is sometimes used for only half of the job (the chipping/slicing part). I've had to do the same xbatcher.BatchGenerator + concat workflow at https://zen3geo.readthedocs.io/en/v0.5.0/chipping.html#pool-chips-into-mini-batches, so this isn't an isolated incident.
While we could add a valid_example parameter to filter out NaNs or invalid values, my suggestion is to follow the torchdata compositional style and have a Slicer, a Filter and a Batcher each do one of the 3 steps above. The reasoning is laid out in #172: valid_example would not be the only parameter people would like to add (there's also caching at #109, creating train/val/test splits, shuffling, and so on), and adding all of them would lead to an overly complicated BatchGenerator.
That said, we could quite easily add a valid_example filter parameter now and handle all the extra Slicer/Filter/Batcher machinery in the background, hidden from the user. This is if people are interested in using xbatcher.BatchGenerator as a 'one-liner' that does everything, similar to pandas.read_csv.
@weiji14 thanks for showing interest in this problem!
The term 'predicate function' makes way more sense and I should have used that terminology from the start.
The main issue I see with the three additional steps is that the predicate gets applied to the batches sequentially and we lose the parallel and potentially distributed power of dask, which is critical for decently-scaled ML problems.
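One way to avoid applying the predicate batch by batch is to express it over the whole array at once. For the centre-pixel predicate from the original request and non-overlapping windows, the check reduces to a single strided selection. This is a minimal sketch: valid_window_mask is a hypothetical helper, and it assumes a 2-D (y, x) target whose sizes are multiples of the window size. On a dask-backed DataArray (e.g. one opened with xr.open_zarr), the same expressions stay lazy and are evaluated in parallel when you call .compute(), instead of pulling the target down one batch at a time.

```python
import numpy as np
import xarray as xr


def valid_window_mask(target, window=10):
    """Evaluate "is the centre pixel of each window valid?" for every
    non-overlapping `window` x `window` window in one vectorised step.

    The centre of window (j, k) is pixel
    (j * window + window // 2, k * window + window // 2), so the centres
    are just every `window`-th pixel starting at `window // 2`.
    """
    c = window // 2
    centres = target.isel(y=slice(c, None, window), x=slice(c, None, window))
    # One boolean per window; stays lazy if `target` is dask-backed.
    return centres.notnull()


rng = np.random.default_rng(1)
values = rng.random((100, 100))
values[values > 0.1] = np.nan  # roughly 10% valid target pixels
target = xr.DataArray(values, dims=["y", "x"])
mask = valid_window_mask(target)
print(int(mask.sum()), "of", mask.size, "windows have a valid centre")
```

The resulting mask can then be used to select only the valid windows before any batching happens.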
I sometimes have >1 TB DataArrays with dims (variable, y, x) where only 10% of the xy coordinates are valid. The sparse target variable alone might be 10+ GB, and all of it would have to come down sequentially to apply the predicate. Instead of trying to get BatchGenerator to solve this, I create "Training Datasets" in advance, with the first dimension being the batched dimension. We persist them to zarr or to cluster memory because we also shuffle, which is a relatively expensive op. Then we can iterate over the first dim for batching.
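A minimal sketch of the "shuffle once, then iterate over the first dim" pattern, assuming the training dataset already has a leading sample dimension named i. The helper shuffled_batches, the dimension name and the seed handling are assumptions for illustration. The same code works for a Dataset or a DataArray.

```python
import numpy as np
import xarray as xr


def shuffled_batches(training, batch_size, dim="i", seed=None):
    """Shuffle a precomputed training dataset once along its sample
    dimension `dim`, then serve consecutive slices of it as batches."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(training.sizes[dim])
    # The shuffle is a gather along `dim`; on large dask-backed data this is
    # the expensive step worth persisting (to memory or zarr) before batching.
    shuffled = training.isel({dim: order})
    for start in range(0, shuffled.sizes[dim], batch_size):
        # The last batch may be smaller than `batch_size`.
        yield shuffled.isel({dim: slice(start, start + batch_size)})


training = xr.DataArray(np.arange(23 * 2).reshape(23, 2), dims=["i", "variable"])
batches = list(shuffled_batches(training, batch_size=5, seed=0))
print([b.sizes["i"] for b in batches])  # [5, 5, 5, 5, 3]
```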
Not to open a can of worms, but I think adding a concept like "Training Dataset" to xbatcher to precompute costly predicate functions, reshaping/windowing and shuffling could help decouple the preprocessing from batch serving and be more performant. Then again, anyone can do this in advance and then use the BatchGenerator over the first dim in that dataset.
We still don't do this, because even with all those ops out of the way, BatchGenerator only loads one batch into memory at a time unless the dataset is already persisted (if that can be afforded). That can be fine for a persisted dataset, but it is limiting. This is obviously out of scope, but relates to https://github.com/xarray-contrib/xbatcher/pull/161
Hi @ljstrnadiii, thanks for elaborating on your workflow. Do you have something working now? I'm curious to see what you had to do to get this working in a parallel-performant way.
@cmdupuis3
- for creating chip/patch datasets (i.e. this training dataset, built in advance), I am doing something like this: https://discourse.pangeo.io/t/any-suggestions-for-efficiently-operating-over-windows-of-data/3133/4?u=leonard_strnad
- we use TensorFlow and build TFRecords from this "training dataset" with dask to speed things up
- other methods:
  - if you use PyTorch, or don't want to use TFRecords and want to build a generator to serve up data, maybe something like #161. A DataLoader (PyTorch) or Dataset (tf) could generate examples from a single-threaded op that kicks off batch prefetching on dask from within the submitted train op.
  - you can also use threading with a DataLoader or tf.data.Dataset to read from zarr directly.
The biggest gain for my use case comes from computing the training dataset in advance, with the first dim being the dim to batch over.
Something like:
import xarray as xr

dset = xr.open_zarr(...).to_array()  # (# variable, y, x)
# extract valid training examples with extract_training_examples
training_dataset = dset.map_blocks(extract_training_examples)  # (# valid examples, # variable, ...)
# persist to zarr or to the cluster to get the ops out of the way
# if the training dataset cannot fit into distributed memory:
training_dataset.to_zarr(...)
training_dataset = xr.open_zarr(new_persisted_training_dataset_zarr_path)
# or, if it can fit into memory:
training_dataset = training_dataset.persist()
# then try various methods
Does that add any clarification?
Yeah, that's a lot clearer, thank you!