xbatcher
Can't control batch size when all dims are input dims
Is your feature request related to a problem?
See title. In most cases you can control your batch size by setting the batch_dims option in BatchGenerator. However, if there are no batch dims to begin with, there is effectively no way to control the batch size.
For example, for an xarray Dataset ds with dims lat and lon, a BatchGenerator like
    bgen = xb.BatchGenerator(
        ds,
        {'lon': nlons, 'lat': nlats}  # input_dims keys must be dimension names
    )
offers no option to control batch size.
Describe the solution you'd like
I want to be able to pass an integer to BatchGenerator that tells it the size of the batch I want, in the case described above.
Maybe something like this, but wrapped as a BatchGenerator option.
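One way such an option could behave is to group consecutive patches into lists of a fixed size. The sketch below uses a hypothetical helper, group_fixed_size (not part of xbatcher), that works on any iterable, including a BatchGenerator; the last group may be shorter than the requested size.

```python
import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def group_fixed_size(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of `batch_size` items; the final list may be shorter.

    Hypothetical helper for illustration, not an xbatcher API.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    it = iter(items)  # one shared iterator, so every item is produced exactly once
    while True:
        group = list(itertools.islice(it, batch_size))
        if not group:
            return  # source exhausted
        yield group


print([len(g) for g in group_fixed_size(range(10), 4)])  # → [4, 4, 2]
```

With xbatcher, each group of patches could then be stacked with xr.concat(group, dim='sample') to get a sample x lat x lon batch.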
Describe alternatives you've considered
No response
Additional context
I think this can probably be solved at the same time as #127.
Alternatively, is it possible in this scenario to "rechunk" along the sample dimension (so you'd get something like 32 x lon x lat)?
So I figured out I can use concat_input_dims=True to get to a better state, but one giant batch is also not ideal if we're trying to parallelize things in the future. I tried adding a dummy batch_dim and rechunking that way, but that doesn't work (and it probably shouldn't).
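If one giant batch fits in memory, it can at least be split into fixed-size pieces along its stacked dimension. A minimal sketch, assuming the concatenated batch is an xarray object and that you pass in the name of its stacked dimension (split_along and the dimension name 'sample' below are hypothetical; check which name your xbatcher version actually produces). This does not reduce peak memory, since the full batch is built first.

```python
import numpy as np
import xarray as xr


def split_along(obj, dim: str, batch_size: int):
    """Yield successive slices of `obj` of length `batch_size` along `dim`.

    Hypothetical helper; `obj` can be an xarray DataArray or Dataset.
    The last slice may be shorter than `batch_size`.
    """
    n = obj.sizes[dim]
    for start in range(0, n, batch_size):
        # positional selection of a contiguous range along `dim`
        yield obj.isel({dim: slice(start, start + batch_size)})


# Stand-in for one large batch of 100 patches, each 20 x 20
big = xr.DataArray(np.zeros((100, 20, 20)), dims=["sample", "lat", "lon"])
print([piece.sizes["sample"] for piece in split_along(big, "sample", 32)])  # → [32, 32, 32, 4]
```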
The problem seems to be more general than having all dims as input dims: as soon as the number of input dims equals the total number of dims minus one, the behavior degrades.
Here is an example:
import numpy as np
import xarray as xr
import xbatcher

# .chunk() requires dask to be installed
da = xr.DataArray(np.random.rand(1000, 100, 100, 2, 2), name='foo',
                  dims=['time', 'y', 'x', 'z', 'j']).chunk({'time': 1})
bgen_1D = xbatcher.BatchGenerator(da, {'x': 10})
bgen_2D = xbatcher.BatchGenerator(da, {'x': 10, 'y': 10})
bgen_3D = xbatcher.BatchGenerator(da,
                                  input_dims={'x': 20, 'y': 20, 'z': 1})
bgen_4D = xbatcher.BatchGenerator(da,
                                  input_dims={'x': 20, 'y': 20, 'z': 1, 'j': 1})
When we check the dims we see:
Tbh, the behavior of xbatcher when there is more than one input dim has always seemed a bit mysterious, and could be documented better.
A temporary workaround is to just add an expanded dimension of size 1.
Example:
import numpy as np
import xarray as xr
import xbatcher

da = xr.DataArray(np.random.rand(100, 100), name='foo',
                  dims=['y', 'x'])
bgen_1D = xbatcher.BatchGenerator(da, {'x': 10})
bgen_2D = xbatcher.BatchGenerator(da, {'x': 10, 'y': 10})
# Same generators, but with a leading size-1 dimension added
bgen_1D_ed = xbatcher.BatchGenerator(da.expand_dims('time'), {'x': 10})
bgen_2D_ed = xbatcher.BatchGenerator(da.expand_dims('time'), {'x': 10, 'y': 10})
bgen_2D_ed2 = xbatcher.BatchGenerator(da.expand_dims(['time1', 'time2']), {'x': 10, 'y': 10})

Yeah this is connected to some other general weirdness about the number of input vs. concat dims. I'll try adding dummy dimensions again and see what I get (but it would be nice not to have to hack around this).
I am having some success with writing a generator wrapper like this (bad code alert!):
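import itertools

import xarray as xr

def batcher(bgen):
    i = 0
    while True:
        print("{},\t\t{}".format(32*i, 32*i + 32))
        batches = list(itertools.islice(bgen, 32*i, 32*i + 32))
        if not batches:  # generator exhausted; xr.concat would fail on an empty list
            return
        yield xr.concat(batches, dim='sample').transpose('sample', ...)
        i += 1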
However, this introduces a different kind of slowdown, because each islice call starts iterating over the batch generator again and skips ahead to the requested position. I realized you don't have to use the NN model's batch size here, so the chunk could be larger, and you could find a good compromise between time spent slicing the batch generator and time spent retrieving the next batch in your training loop.
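The cost of re-slicing from the start can be made visible without xbatcher. A minimal sketch, assuming (as the stand-in does) that iterating the batch generator again starts from the first patch: CountingIterable is a hypothetical stand-in that counts how many items it generates, and the two hypothetical wrappers compare re-slicing per batch with pulling from one shared iterator.

```python
import itertools


class CountingIterable:
    """Hypothetical stand-in for a re-iterable batch generator.

    Every call to iter() starts again from item 0, and `produced`
    counts how many items have been generated in total.
    """

    def __init__(self, n: int):
        self.n = n
        self.produced = 0

    def __iter__(self):
        for i in range(self.n):
            self.produced += 1
            yield i


def restart_slicing(src, batch_size):
    """Mimics the wrapper in the thread: re-slice from the start for every batch."""
    i = 0
    while True:
        batch = list(itertools.islice(src, batch_size * i, batch_size * (i + 1)))
        if not batch:
            return
        yield batch
        i += 1


def single_pass(src, batch_size):
    """Keep one iterator alive so each item is generated only once."""
    it = iter(src)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch


a, b = CountingIterable(320), CountingIterable(320)
assert list(restart_slicing(a, 32)) == list(single_pass(b, 32))  # same batches
print(a.produced, b.produced)  # → 2080 320
```

Both produce identical batches, but the restarting version generates items 0..stop-1 again for every batch, so its total work grows quadratically with the number of batches, while the single-pass version generates each item once.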
Also wanted to note that this issue turns xbatcher into a massive memory hog, and it's probably related to #37 as well.
Why is there a deep copy here?
As I noted on the other thread, that is not a deep copy. It's a very shallow copy. Creating a copy of the data array avoids causing side effects to the user's inputs.
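To illustrate the distinction: a shallow copy of an xarray object is a new object, so changing its metadata (for example its name) does not touch the user's input, while it still points at the same underlying array, so no data is duplicated. A minimal sketch with a NumPy-backed DataArray; it calls copy(deep=False) explicitly, which may or may not be the exact call used inside xbatcher.

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=["y", "x"], name="foo")

shallow = da.copy(deep=False)
deep = da.copy(deep=True)

# The shallow copy shares memory with the original; the deep copy does not.
print(np.shares_memory(da.values, shallow.values))  # → True
print(np.shares_memory(da.values, deep.values))  # → False

# Renaming the shallow copy leaves the original object untouched.
shallow.name = "bar"
print(da.name)  # → foo
```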
Along the lines of https://github.com/xarray-contrib/xbatcher/issues/162#issuecomment-1431902345, we can create fixed-size batches for the case where all dims are input dims by using a BatchGenerator wrapper with the following structure:
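import itertools

import numpy as np
import xarray as xr
import xbatcher as xb

da1 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
da2 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
da3 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
ds = xr.Dataset({'da1': da1, 'da2': da2, 'da3': da3})

def batch_generator(bgen, batch_size):
    # Stack up to batch_size consecutive patches along a new 'sample' dim.
    # A single shared iterator means each patch is generated exactly once,
    # and the loop ends cleanly when bgen runs out (no hardcoded count).
    b = iter(bgen)
    while True:
        batch_stack = list(itertools.islice(b, batch_size))
        if not batch_stack:
            return  # underlying BatchGenerator is exhausted
        # the last batch may contain fewer than batch_size patches
        yield xr.concat(batch_stack, 'sample')

bgen = xb.BatchGenerator(
    ds,
    {'d1': 20, 'd2': 20},  # input_dims
    {'d1': 2, 'd2': 2}     # input_overlap
)
gen = batch_generator(bgen, 32)
# Take the first fixed-size batch from the wrapper (not from bgen itself)
a = next(gen)
a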
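As a sanity check on the example above, the number of patches and batches can be worked out by hand. A minimal sketch, assuming xbatcher steps each input dim with stride input_dims minus input_overlap and keeps only windows that fit entirely inside the dimension (my reading of its windowing; verify against your version). n_windows and n_batches are hypothetical helpers. With size 400, window 20 and overlap 2 this gives 22 windows per dim, 484 patches, and at 32 patches per batch, 15 full batches plus one batch of 4.

```python
def n_windows(dim_size: int, window: int, overlap: int = 0) -> int:
    """Number of full-size windows along one dimension (partial edge windows dropped)."""
    stride = window - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than the window")
    if window > dim_size:
        return 0
    return (dim_size - window) // stride + 1


def n_batches(n_patches: int, batch_size: int) -> tuple:
    """Return (number of full batches, size of the trailing partial batch)."""
    return divmod(n_patches, batch_size)


per_dim = n_windows(400, 20, 2)
total = per_dim * per_dim
print(per_dim, total, n_batches(total, 32))  # → 22 484 (15, 4)
```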