Can't control batch size when all dims are input dims #163

Open
cmdupuis3 opened this issue Jan 20, 2023 · 9 comments

Comments

@cmdupuis3

cmdupuis3 commented Jan 20, 2023

Is your feature request related to a problem?

Title. Basically, in most cases you can control your batch size by setting the batch_dims option in BatchGenerator. However, if you don't have any batch dims to start with, you are effectively unable to control your batch size.

e.g., for an xarray Dataset ds with dims lat and lon, a BatchGenerator like

    bgen = xb.BatchGenerator(
        ds,
        {'nlon':nlons, 'nlat':nlats}
    )

offers no option to control batch size.

Describe the solution you'd like

I want to be able to pass an integer to BatchGenerator that tells it the size of the batch I want, in the case described above.

Maybe something like this, but wrapped as a BatchGenerator option.

Describe alternatives you've considered

No response

Additional context

I think this can probably be solved at the same time as #127

@cmdupuis3
Author

cmdupuis3 commented Jan 20, 2023

Alternatively, is it possible in this scenario to "rechunk" along the sample dimension (so you'd get like 32 x lon x lat)?

@cmdupuis3
Author

cmdupuis3 commented Jan 20, 2023

So I figured out I can use concat_input_dims=True to get to a better state, but one giant batch is also not ideal if we're trying to parallelize stuff in the future. I tried adding a dummy batch_dim and rechunking that way, but that doesn't work (and it probably shouldn't).

@dhruvbalwada

The problem seems to be more general than having all dims as input dims. As soon as the number of input dims equals the total number of dims minus one, the behavior degrades.

Here is an example:

import numpy as np
import xarray as xr
import xbatcher

da = xr.DataArray(np.random.rand(1000, 100, 100, 2, 2), name='foo',
                  dims=['time', 'y', 'x', 'z', 'j']).chunk({'time': 1})

bgen_1D = xbatcher.BatchGenerator(da, {'x': 10})
bgen_2D = xbatcher.BatchGenerator(da, {'x': 10, 'y': 10})
bgen_3D = xbatcher.BatchGenerator(da,
                                  input_dims={'x': 20, 'y': 20, 'z': 1})
bgen_4D = xbatcher.BatchGenerator(da,
                                  input_dims={'x': 20, 'y': 20, 'z': 1, 'j': 1})

When we check the dims we see:
[Screenshot: dims of the batches from each generator above]

Tbh, the behavior of xbatcher when there is more than one input dim always seems a bit mysterious, and could be documented better.

@dhruvbalwada

A temporary solution can be to just create a size 1 expanded dimension.

Example:

da = xr.DataArray(np.random.rand(100, 100), name='foo',
                  dims=['y', 'x'])

bgen_1D = xbatcher.BatchGenerator(da, {'x':10})
bgen_2D = xbatcher.BatchGenerator(da, {'x':10,'y':10})
bgen_1D_ed = xbatcher.BatchGenerator(da.expand_dims('time'), {'x':10})
bgen_2D_ed = xbatcher.BatchGenerator(da.expand_dims('time'), {'x':10,'y':10})
bgen_2D_ed2 = xbatcher.BatchGenerator(da.expand_dims(['time1','time2']), {'x':10,'y':10})

[Screenshot: dims of the batches from each generator above]

Note that expanding dims so that the total number of dims exceeds the number of input dims plus one creates the sample dimension.

@cmdupuis3
Author

Yeah this is connected to some other general weirdness about the number of input vs. concat dims. I'll try adding dummy dimensions again and see what I get (but it would be nice not to have to hack around this).

I am having some success with writing a generator wrapper like this (bad code alert!):

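import itertools
import xarray as xr

def batcher(bgen):
    i = 0
    while True:
        print("{},\t\t{}".format(32*i, 32*i + 32))
        # each islice call re-iterates bgen from the start, so later slices cost more
        chunk = list(itertools.islice(bgen, 32*i, 32*i + 32))
        if not chunk:  # no patches left, stop instead of concatenating nothing
            return
        yield xr.concat(chunk, dim='sample').transpose('sample', ...)
        i += 1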

However, you get a different kind of slowdown from having to slice the batch generator at different points. I realized you don't have to use the NN model's batch size here, so it could be larger, and you could find a good compromise between time spent slicing the batch generator, and time retrieving the next batch in your training loop.
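
The slicing cost described above comes from restarting the underlying generator for every chunk. A minimal sketch of the alternative, consuming a single iterator so each item is produced only once, is below. It uses only the standard library; `chunked` and `CountingSource` are hypothetical names made up for this illustration (not part of xbatcher), and `CountingSource` is just a stand-in for a batch generator that counts how many items it had to produce.

```python
import itertools


def chunked(iterable, size):
    """Yield lists of up to `size` items, walking `iterable` only once."""
    it = iter(iterable)  # one shared iterator, never restarted
    while True:
        chunk = list(itertools.islice(it, size))
        if not chunk:  # source exhausted
            return
        yield chunk


class CountingSource:
    """Hypothetical stand-in for a batch generator; counts items produced."""

    def __init__(self, n):
        self.n = n
        self.produced = 0

    def __iter__(self):
        for k in range(self.n):
            self.produced += 1
            yield k


# Single pass: every item is produced exactly once.
single_pass = CountingSource(100)
chunks = list(chunked(single_pass, 32))

# Re-slicing from the start for each chunk, as in the wrapper above.
resliced = CountingSource(100)
for i in range(4):
    list(itertools.islice(resliced, 32 * i, 32 * i + 32))

print([len(c) for c in chunks], single_pass.produced, resliced.produced)
```

With a single pass the chunk size no longer affects how much work the source repeats, so it can be picked purely for the training loop.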

@cmdupuis3
Author

Also wanted to note that this issue turns xbatcher into a massive memory hog, and it's probably related to #37 as well.

@cmdupuis3
Author

Why is there a deep copy here?

@rabernat
Contributor

Why is there a deep copy here?

As I noted on the other thread, that is not a deep copy. It's a very shallow copy. Creating a copy of the data array avoids causing side effects to the user's inputs.
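
The linked line is not reproduced in this thread, so the following is only a general illustration of the distinction, assuming numpy-backed data: in xarray, `copy(deep=False)` creates a new wrapper object that shares the underlying array, while `copy(deep=True)` duplicates the data. The variable names are made up for the example.

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((4, 4)), dims=['y', 'x'], name='foo')

shallow = da.copy(deep=False)  # new object, same underlying buffer
deep = da.copy(deep=True)      # new object, data duplicated

# Changing metadata on the shallow copy leaves the original untouched.
shallow.name = 'bar'

shares_shallow = np.shares_memory(da.values, shallow.values)
shares_deep = np.shares_memory(da.values, deep.values)
print(da.name, shallow.name, shares_shallow, shares_deep)
```

So a shallow copy protects the caller's object from metadata side effects without paying for a second copy of the data.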

@cmdupuis3
Author

Along the lines of #162 (comment), we can create fixed-size batches for the case of all dims being input dims by using a BatchGenerator wrapper with the following structure:

import itertools

import numpy as np
import xarray as xr
import xbatcher as xb

da1 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
da2 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
da3 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
ds = xr.Dataset({'da1': da1, 'da2': da2, 'da3': da3})

def batch_generator(bgen, batch_size):
    # Walk bgen once, grouping patches into fixed-size batches.
    b = iter(bgen)
    while True:
        batch_stack = list(itertools.islice(b, batch_size))
        if not batch_stack:  # generator exhausted
            return
        # The last batch may hold fewer than batch_size patches.
        yield xr.concat(batch_stack, 'sample')

bgen = xb.BatchGenerator(
    ds,
    input_dims={'d1': 20, 'd2': 20},
    input_overlap={'d1': 2, 'd2': 2},
)

gen = batch_generator(bgen, 32)

# Pull the first fixed-size batch from the wrapper to inspect it.
a = next(gen)
a
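
For a rough idea of how many fixed-size batches to expect from the example above, the count can be estimated with plain arithmetic. This is a sketch under an assumption that is not confirmed in this thread: that each dimension is walked with a stride of patch size minus overlap, and that patches running past the edge are dropped. `patches_per_dim` is a hypothetical helper, not an xbatcher function; compare its result against the number of patches the generator actually yields in your version.

```python
import math


def patches_per_dim(dim_size, patch_size, overlap):
    """Count patches along one dim.

    Assumes stride = patch_size - overlap and that patches extending
    past the end of the dimension are dropped.
    """
    stride = patch_size - overlap
    starts = range(0, dim_size, stride)
    return sum(1 for s in starts if s + patch_size <= dim_size)


# d1 and d2 both have size 400, patch size 20, overlap 2.
n_patches = patches_per_dim(400, 20, 2) ** 2
batch_size = 32
n_batches = math.ceil(n_patches / batch_size)
last_batch_size = n_patches - batch_size * (n_batches - 1)
print(n_patches, n_batches, last_batch_size)
```

Under that assumption the patch count is not a multiple of 32, so the final batch is smaller than the rest and downstream code has to either accept or drop it.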


3 participants