
Can't control batch size when all dims are input dims

Open · cmdupuis3 opened this issue on Jan 20 '23 · 9 comments

Is your feature request related to a problem?

As the title says: in most cases you can control your batch size by setting the batch_dims option in BatchGenerator. However, if you don't have any batch dims to start with (i.e., every dim is an input dim), you are effectively unable to control your batch size.

For example, for an xarray Dataset ds with dims nlat and nlon, a BatchGenerator like

    bgen = xb.BatchGenerator(
        ds,
        {'nlon':nlons, 'nlat':nlats}
    )

offers no option to control batch size.
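
For contrast, if the dataset had an extra dimension such as time, something like the following sketch would give that control (ds_with_time is hypothetical; nlons and nlats are placeholders as above):

    # Sketch for contrast: ds_with_time is a hypothetical dataset with dims time, nlat, nlon.
    bgen = xb.BatchGenerator(
        ds_with_time,
        input_dims={'nlon': nlons, 'nlat': nlats},
        batch_dims={'time': 32},  # each batch spans 32 time steps, so batch size is controllable
    )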

Describe the solution you'd like

I want to be able to pass an integer to BatchGenerator that tells it the size of the batch I want, in the case described above.

Maybe something like this, but wrapped as a BatchGenerator option.
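
Purely as a sketch of the shape of the option (batch_size is not an existing BatchGenerator argument, just what I have in mind):

    bgen = xb.BatchGenerator(
        ds,
        input_dims={'nlon': nlons, 'nlat': nlats},
        batch_size=32,  # hypothetical: stack 32 of the nlon x nlat patches into each batch
    )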

Describe alternatives you've considered

No response

Additional context

I think this can probably be solved at the same time as #127

cmdupuis3 · Jan 20 '23

Alternatively, is it possible in this scenario to "rechunk" along the sample dimension (so you'd get like 32 x lon x lat)?

cmdupuis3 · Jan 20 '23

So I figured out I can use concat_input_dims=True to get to a better state, but one giant batch is also not ideal if we want to parallelize things in the future. I tried adding a dummy batch_dim and rechunking that way, but that doesn't work (and it probably shouldn't).
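
For reference, the workaround looks something like this (a sketch, reusing ds from the first example; concat_input_dims is an existing BatchGenerator option):

    # Sketch of the concat_input_dims workaround: all patches get concatenated
    # into a single (giant) batch instead of being yielded one at a time.
    bgen = xb.BatchGenerator(
        ds,
        input_dims={'nlon': nlons, 'nlat': nlats},
        concat_input_dims=True,
    )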

cmdupuis3 · Jan 20 '23

The problem seems to be more general than having all dims as input dims. As soon as the number of input dims equals the total number of dims minus one, the behavior degrades.

Here is an example:

import numpy as np
import xarray as xr
import xbatcher

da = xr.DataArray(np.random.rand(1000, 100, 100, 2, 2), name='foo',
                  dims=['time', 'y', 'x', 'z', 'j']).chunk({'time': 1})

bgen_1D = xbatcher.BatchGenerator(da, {'x': 10})
bgen_2D = xbatcher.BatchGenerator(da, {'x': 10, 'y': 10})
bgen_3D = xbatcher.BatchGenerator(da,
                                  input_dims={'x': 20, 'y': 20, 'z': 1})
bgen_4D = xbatcher.BatchGenerator(da,
                                  input_dims={'x': 20, 'y': 20, 'z': 1, 'j': 1})

When we check the dims we see the following (screenshot of the output omitted).
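
The check can be reproduced with something like this (a sketch; it just prints the dims of the first batch from each generator):

    for name, bg in [('1D', bgen_1D), ('2D', bgen_2D), ('3D', bgen_3D), ('4D', bgen_4D)]:
        print(name, dict(next(iter(bg)).sizes))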

Tbh, the behavior of xbatcher when there is more than one input dim always seems a bit mysterious, and could be documented better.

dhruvbalwada · Jan 30 '23

A temporary workaround is to just create a size-1 expanded dimension.

Example:

da = xr.DataArray(np.random.rand(100, 100), name='foo',
                  dims=['y', 'x'])

bgen_1D = xbatcher.BatchGenerator(da, {'x':10})
bgen_2D = xbatcher.BatchGenerator(da, {'x':10,'y':10})
bgen_1D_ed = xbatcher.BatchGenerator(da.expand_dims('time'), {'x':10})
bgen_2D_ed = xbatcher.BatchGenerator(da.expand_dims('time'), {'x':10,'y':10})
bgen_2D_ed2 = xbatcher.BatchGenerator(da.expand_dims(['time1','time2']), {'x':10,'y':10})
(Screenshot of the output omitted.) Note that expanding dims such that total_dims > input_dims + 1 creates the sample dimension.

dhruvbalwada · Jan 30 '23

Yeah this is connected to some other general weirdness about the number of input vs. concat dims. I'll try adding dummy dimensions again and see what I get (but it would be nice not to have to hack around this).

I am having some success with writing a generator wrapper like this (bad code alert!):

import itertools
import xarray as xr

def batcher(bgen):
    i = 0
    while True:
        print("{},\t\t{}".format(32*i, 32*i + 32))
        # stack the next 32 patches from the generator along a new 'sample' dim
        yield xr.concat(itertools.islice(bgen, 32*i, 32*i + 32), dim='sample').transpose('sample', ...)
        i += 1

However, you get a different kind of slowdown from having to slice the batch generator at different starting points. I realized you don't have to use the NN model's batch size here, so it could be larger, and you could find a good compromise between the time spent slicing the batch generator and the time spent retrieving the next batch in your training loop.

cmdupuis3 · Jan 30 '23

Also wanted to note that this issue turns xbatcher into a massive memory hog, and it's probably related to #37 as well.

cmdupuis3 · Jan 30 '23

Why is there a deep copy here?

cmdupuis3 · Jan 31 '23

> Why is there a deep copy here?

As I noted on the other thread, that is not a deep copy. It's a very shallow copy. Creating a copy of the data array avoids causing side effects on the user's inputs.
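
For illustration (generic xarray copy semantics, not anything xbatcher-specific): a shallow copy creates a new object but shares the underlying data, so it is cheap.

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.zeros((1000, 1000)), dims=['y', 'x'])
    shallow = da.copy(deep=False)  # new DataArray wrapper, same underlying buffer
    deep = da.copy(deep=True)      # duplicates the data

    print(np.shares_memory(shallow.values, da.values))  # True  -> no data copied
    print(np.shares_memory(deep.values, da.values))     # False -> full copy of the data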

rabernat · Jan 31 '23

Along the lines of https://github.com/xarray-contrib/xbatcher/issues/162#issuecomment-1431902345, we can create fixed-size batches for the case where all dims are input dims by using a BatchGenerator wrapper with the following structure:

import xarray as xr
import numpy as np
import xbatcher as xb

da1 = xr.DataArray(np.random.randint(0,9,(400,400)), dims=['d1', 'd2'])
da2 = xr.DataArray(np.random.randint(0,9,(400,400)), dims=['d1', 'd2'])
da3 = xr.DataArray(np.random.randint(0,9,(400,400)), dims=['d1', 'd2'])
ds = xr.Dataset({'da1':da1, 'da2':da2, 'da3':da3})

def batch_generator(bgen, batch_size):
    b = iter(bgen)
    n = 0
    while n < 400:  # hardcoded n is a kludge; while-loop is necessary
        # stack the next `batch_size` patches along a new 'sample' dimension
        batch_stack = [next(b) for i in range(batch_size)]
        yield xr.concat(batch_stack, 'sample')
        n += 1

bgen = xb.BatchGenerator(
    ds,
    {'d1':20, 'd2':20},
    {'d1':2,  'd2':2}
)

gen = batch_generator(bgen, 32)

# grab the first batch from the wrapper to inspect it
a = next(gen)
a
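
Each batch yielded by the wrapper should then carry a new 'sample' dimension of length 32 (the batch_size) on top of the 20x20 'd1'/'d2' patch dims.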

cmdupuis3 · Feb 15 '23