Predicate option for BatchGenerator #162

Open
cmdupuis3 opened this issue Jan 20, 2023 · 3 comments
Labels: duplicate, feature

Comments

cmdupuis3 commented Jan 20, 2023

Is your feature request related to a problem?

When you create a batch generator, what happens when the data contain NaNs? For example, with an ocean dataset such as a map of sea surface temperature, you may iterate through regions where the stencil is fully valid, partially valid, or entirely NaN. Because xbatcher can't filter for these situations, anyone who needs this has to apply the filter inside the batch loop, so batches that start out the same size end up with different numbers of usable samples (load imbalance).
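A minimal sketch of that imbalance, assuming each window is modeled as a flat list of floats rather than an xarray object: when NaN windows are dropped inside the batch loop, batches built with the same number of windows come out with different numbers of valid ones. The function name and data are hypothetical.

```python
import math

def windows_per_batch_after_filtering(batches):
    """Count the windows left in each batch when NaN windows are dropped inside the loop."""
    sizes = []
    for batch in batches:
        # Filtering happens only after the batch was formed, so sizes drift apart
        kept = [w for w in batch if not any(math.isnan(v) for v in w)]
        sizes.append(len(kept))
    return sizes

nan = math.nan
# Two hypothetical batches of three windows each, with two values per window
batches = [
    [[1.0, 2.0], [nan, 1.0], [3.0, 4.0]],
    [[nan, nan], [nan, 2.0], [5.0, 6.0]],
]
print(windows_per_batch_after_filtering(batches))  # → [2, 1]
```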

Describe the solution you'd like

I would like to see an option in BatchGenerator for a selection predicate. You would pass BatchGenerator a function that takes a slice as input and returns either True or False. BatchGenerator would then yield only the slices for which the predicate returned True, so every batch is filled with valid samples and load balance is restored.
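The two NaN criteria that come up in this issue (drop windows containing any NaN, or drop only windows that are entirely NaN) could be written as predicates along these lines. This is a sketch that models a window as a flat sequence of floats; the function names are hypothetical. For an xarray DataArray window, the analogous checks would be built on window.isnull().any() and window.notnull().any().

```python
import math
from typing import Sequence

def has_no_nans(window: Sequence[float]) -> bool:
    """Keep the window only if every value is valid (the stricter filter)."""
    return not any(math.isnan(v) for v in window)

def not_all_nan(window: Sequence[float]) -> bool:
    """Keep the window if at least one value is valid (drops only all-NaN windows)."""
    return any(not math.isnan(v) for v in window)

nan = math.nan
windows = [[1.0, 2.0], [nan, 3.0], [nan, nan]]  # valid, partially valid, all NaN
print([has_no_nans(w) for w in windows])  # → [True, False, False]
print([not_all_nan(w) for w in windows])  # → [True, True, False]
```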

Describe alternatives you've considered

No response

Additional context

I think this is similar to #158

@cmdupuis3 (author)

Just learned about itertools.compress; that might be what is needed here.
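itertools.compress(data, selectors) yields the items of data whose matching selector is truthy. A minimal sketch of how it could act as the selection step, assuming the validity mask is computed up front (for instance from a vectorized NaN check over the whole array); the sample values are made up. When the predicate is applied one sample at a time instead, the built-in filter() gives the same result.

```python
from itertools import compress

samples = [3, 8, 1, 8, 8, 5]
# Precomputed mask; in practice this could come from a vectorized NaN check
mask = [s == 8 for s in samples]

print(list(compress(samples, mask)))  # → [8, 8, 8]
# Selecting from range() instead returns the indices of the kept samples
print(list(compress(range(len(samples)), mask)))  # → [1, 3, 4]
# Equivalent per-sample form using the predicate directly
print(list(filter(lambda s: s == 8, samples)))  # → [8, 8, 8]
```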

cmdupuis3 changed the title from "NaN filters at batch generation" to "Predicate option for BatchGenerator" on Feb 14, 2023.

cmdupuis3 commented Feb 14, 2023

While experimenting with some xbatcher ideas, I came up with a function that implements the kind of predicate that would be desirable here. Here is a minimal example of what this functionality could look like. The code is hacky and fairly slow, but I think the core functionality (creating fixed-size batches with a predicate) could be very useful in some situations, especially for NaN filtering.

@maxrjones @rabernat What do y'all think?

This example script makes a list of 1000 random integers from 0 to 9, then finds the indices of all the eights and returns them in batches of five. The use case I was interested in was fixed-size batches of 32 samples, discarding NxN samples that contained any NaNs (or, alternatively, only those that were entirely NaN).

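import random

def my_gen(batch_set, batch_size=5, predicate=None, sample_dim_name=None):
    """Yield fixed-size lists of indices of the samples that satisfy predicate."""
    # Sample count: size of the named dimension for xarray objects, else len()
    n_samples = batch_set.sizes[sample_dim_name] if sample_dim_name else len(batch_set)
    batch = []
    for n in range(n_samples):
        # Select one sample by named dimension (xarray) or by position (list)
        this_batch = batch_set[{sample_dim_name: n}] if sample_dim_name else batch_set[n]
        if predicate is None or predicate(this_batch):
            batch.append(n)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A trailing partial batch is dropped, so every yielded batch has batch_size items

batch_set = []
for i in range(1000):
    batch_set.append(random.randint(0, 9))

pred = lambda this_batch: this_batch == 8

gen = my_gen(batch_set, predicate=pred)
for g in gen:
    print(g)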
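The same fixed-size, predicate-filtered batching can also be expressed lazily with standard-library iterators. In this sketch, filtered_batches and its drop_last flag are hypothetical names (drop_last=True mirrors the script above, which discards a trailing partial batch); the data are made up.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")

def filtered_batches(
    samples: Iterable[T],
    batch_size: int,
    predicate: Optional[Callable[[T], bool]] = None,
    drop_last: bool = True,
) -> Iterator[List[T]]:
    """Lazily yield lists of batch_size samples that pass predicate."""
    # filter() evaluates the predicate one sample at a time, only as batches are requested
    kept = filter(predicate, samples) if predicate is not None else iter(samples)
    while True:
        batch = list(islice(kept, batch_size))
        if not batch:
            return
        if len(batch) < batch_size and drop_last:
            return  # discard the trailing partial batch to keep every batch the same size
        yield batch

data = [8, 1, 8, 8, 3, 8, 8, 8, 0, 8]
idx = range(len(data))
# Indices of the eights, in batches of three
print(list(filtered_batches(idx, 3, lambda i: data[i] == 8)))
# → [[0, 2, 3], [5, 6, 7]]
print(list(filtered_batches(idx, 3, lambda i: data[i] == 8, drop_last=False)))
# → [[0, 2, 3], [5, 6, 7], [9]]
```

On Python 3.12 and later, itertools.batched could replace the islice loop, though it yields the final partial batch as well, so the size check would still be needed.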

cmdupuis3 commented Feb 15, 2023

Better code sample, which wraps xbatcher and also offers fixed batch sizes:

import itertools

import numpy as np
import xarray as xr
import xbatcher as xb

da1 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
da2 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
da3 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
ds = xr.Dataset({'da1': da1, 'da2': da2, 'da3': da3})

# 5x5 windows (input_dims) that overlap by 2 points along each dimension (input_overlap)
bgen = xb.BatchGenerator(
    ds,
    {'d1': 5, 'd2': 5},
    {'d1': 2, 'd2': 2}
)

def my_gen2(bgen, batch_size=5, predicate=None, max_samples=None):
    """Yield stacks of batch_size xbatcher batches that satisfy predicate."""
    batch_stack = []
    # islice caps how many batches are inspected (None = all) and, unlike
    # calling next() in a while-loop, stops cleanly when bgen is exhausted
    for this_batch in itertools.islice(bgen, max_samples):
        if predicate is None or predicate(this_batch):
            batch_stack.append(this_batch)
        if len(batch_stack) == batch_size:
            # Stack the accepted batches along a new 'sample' dimension
            yield xr.concat(batch_stack, 'sample')
            batch_stack = []

pred2 = lambda batch: np.mod(batch['da1'].sum(), 10) == 0

# Inspect only the first 400 batches, as the original hardcoded loop did
gen = my_gen2(bgen, batch_size=2, predicate=pred2, max_samples=400)
res = []
for g in gen:
    res.append(g)
print(len(res))
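One design consideration for a predicate option: a generator function like my_gen2 can be consumed only once, while a training loop usually iterates over the same data every epoch. Below is a sketch of a re-iterable wrapper, under the assumption that the wrapped samples are themselves re-iterable (a list or range is; so is any object whose __iter__ starts a new pass). PredicateBatches and its combine parameter are hypothetical, not part of xbatcher; for xarray batches, combine could be something like lambda b: xr.concat(b, 'sample').

```python
from itertools import islice
from typing import Callable, Generic, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")

class PredicateBatches(Generic[T]):
    """Hypothetical re-iterable wrapper that filters samples into fixed-size batches.

    Each iter() call starts a fresh pass over `samples`, so the same object can be
    looped over once per epoch, provided `samples` is itself re-iterable.
    """

    def __init__(
        self,
        samples: Iterable[T],
        batch_size: int,
        predicate: Optional[Callable[[T], bool]] = None,
        combine: Callable[[List[T]], object] = list,
    ) -> None:
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.samples = samples
        self.batch_size = batch_size
        self.predicate = predicate
        # Turns a list of accepted samples into one batch object
        self.combine = combine

    def __iter__(self) -> Iterator[object]:
        kept = (s for s in self.samples if self.predicate is None or self.predicate(s))
        while True:
            batch = list(islice(kept, self.batch_size))
            if len(batch) < self.batch_size:
                return  # no samples left, or only a trailing partial batch (dropped)
            yield self.combine(batch)

batches = PredicateBatches(range(10), batch_size=2, predicate=lambda x: x % 3 == 0)
print(list(batches))  # → [[0, 3], [6, 9]]
print(list(batches))  # a second pass starts over → [[0, 3], [6, 9]]
```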
