Dimension order should be set by input_dims #126

Open
maxrjones opened this issue Nov 17, 2022 · 1 comment
Labels
bug Something isn't working

Comments

@maxrjones
Member

What is your issue?

In most cases, the batch generator will permute the dimension order to agree with the order specified in input_dims. Here is an example:

>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher

>>> shape = (10, 50, 100, 200)
>>> ds = xr.Dataset(
...     {
...         "foo": (["time", "y", "x", "z"], np.random.rand(*shape)),
...         "bar": (["time", "y", "x", "z"], np.random.randint(0, 10, shape)),
...     },
...     {
...         "x": (["x"], np.arange(shape[-2])),
...         "y": (["y"], np.arange(shape[-3])),
...     },
... )
>>> print(ds)
<xarray.Dataset>
Dimensions:  (time: 10, y: 50, x: 100, z: 200)
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
Dimensions without coordinates: time, z
Data variables:
    foo      (time, y, x, z) float64 0.6615 0.04028 0.8633 ... 0.4632 0.6561
    bar      (time, y, x, z) int64 8 0 9 4 8 9 2 6 7 7 5 ... 3 2 0 7 2 3 2 1 3 6
>>> print(ds['foo'].shape)
(10, 50, 100, 200)
>>> bg = xbatcher.BatchGenerator(ds, input_dims={'x': 10, 'y': 5})
>>> print(bg[0])
<xarray.Dataset>
Dimensions:  (y: 5, x: 10, sample: 2000)
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9
  * y        (y) int64 0 1 2 3 4
  * sample   (sample) object MultiIndex
  * time     (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 9 9 9 9 9 9 9 9 9 9 9 9
  * z        (sample) int64 0 1 2 3 4 5 6 7 ... 192 193 194 195 196 197 198 199
Data variables:
    foo      (sample, x, y) float64 0.6615 0.8259 0.09629 ... 0.2105 0.09571
    bar      (sample, x, y) int64 8 4 0 6 0 4 4 0 5 4 5 ... 2 3 8 3 4 1 6 1 9 4
>>> print(bg[0]['foo'].shape)
(2000, 10, 5)
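
As a quick arithmetic check on the shape above: with no batch_dims, the dimensions not listed in input_dims (time and z) appear to be stacked into sample, giving 10 × 200 = 2000 entries, and the remaining axes follow the input_dims order (x, then y). The sketch below reproduces (2000, 10, 5) under that assumption. expected_batch_shape is a hypothetical helper for illustration, not part of xbatcher, and it does not model batch_dims or concat_input_dims.

```python
from math import prod


def expected_batch_shape(sizes, input_dims):
    """Hypothetical helper: shape of the first batch, assuming every dimension
    not in input_dims is stacked into a leading 'sample' axis and the input
    dimensions follow in input_dims order (dicts keep insertion order)."""
    n_sample = prod(size for dim, size in sizes.items() if dim not in input_dims)
    return (n_sample, *input_dims.values())


sizes = {"time": 10, "y": 50, "x": 100, "z": 200}
print(expected_batch_shape(sizes, {"x": 10, "y": 5}))  # (2000, 10, 5)
```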

In at least one case, when batch_dims is set and concat_input_dims=False, the original dimension order is retained:

>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher
>>> shape = (10, 50, 100)
>>> ds = xr.Dataset(
...     {
...         "foo": (["time", "y", "x"], np.random.rand(*shape)),
...         "bar": (["time", "y", "x"], np.random.randint(0, 10, shape)),
...     },
...     {
...         "x": (["x"], np.arange(shape[-1])),
...         "y": (["y"], np.arange(shape[-2])),
...     },
... )
>>> # Original dimension order permuted
>>> bg = xbatcher.BatchGenerator(
...     ds,
...     input_dims={"x": 5, "y": 10},
...     batch_dims={"time": 2},
...     concat_input_dims=True,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions:      (y_input: 10, x_input: 5, sample: 1000)
Coordinates:
    x            (sample, x_input) int64 0 1 2 3 4 0 1 ... 98 99 95 96 97 98 99
    y            (sample, y_input) int64 0 1 2 3 4 5 6 ... 43 44 45 46 47 48 49
  * sample       (sample) object MultiIndex
  * input_batch  (sample) int64 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
  * time         (sample) int64 0 1 2 3 4 5 6 7 8 9 0 ... 9 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: y_input, x_input
Data variables:
    foo          (sample, x_input, y_input) float64 0.3198 0.3109 ... 0.5785
    bar          (sample, x_input, y_input) int64 1 8 5 6 9 8 7 ... 6 0 9 4 8 5
>>> print(bg[0]['foo'].shape)
(1000, 5, 10)
>>> # Original dimension order retained
>>> bg = xbatcher.BatchGenerator(
...     ds,
...     input_dims={"x": 5, "y": 10},
...     batch_dims={"time": 2},
...     concat_input_dims=False,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions:  (time: 10, y: 10, x: 5)
Coordinates:
  * x        (x) int64 0 1 2 3 4
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: time
Data variables:
    foo      (time, y, x) float64 0.3198 0.5306 0.3465 ... 0.7873 0.5106 0.9177
    bar      (time, y, x) int64 1 0 2 6 5 8 0 1 2 0 5 ... 1 2 0 2 0 7 5 6 4 8 3
>>> print(bg[0]['foo'].shape)
(10, 10, 5)

We should document the intended behavior for ordering dimensions and test that the output shape is consistent. I would have expected the original dimension order to be retained, in contrast to the most common behavior of the batch generator. @jhamman, can you provide insight into the originally intended behavior?
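
A test for this could compare the trailing dimensions of each batch with the key order of input_dims. Below is a minimal sketch of such a check in plain Python, assuming the intended rule is "input dimensions last, in input_dims order". follows_input_dims is a hypothetical helper; in a real test, batch_dims could be taken from bg[0]["foo"].dims. The "_input" suffix mirrors the x_input/y_input names seen in the concat_input_dims=True output above.

```python
def follows_input_dims(batch_dims, input_dims, concat_input_dims=False):
    """Hypothetical check: True if the trailing dimensions of a batch appear
    in the same order as the keys of input_dims.

    With concat_input_dims=True, the batches shown in this issue rename each
    input dimension to f"{dim}_input", so the expected names are adjusted.
    """
    expected = [f"{d}_input" if concat_input_dims else d for d in input_dims]
    if not expected:
        return True  # nothing to check when input_dims is empty
    return list(batch_dims[-len(expected):]) == expected


input_dims = {"x": 5, "y": 10}
# Dimension tuples copied from the two outputs of the second example
print(follows_input_dims(("sample", "x_input", "y_input"), input_dims, concat_input_dims=True))  # True
print(follows_input_dims(("time", "y", "x"), input_dims))  # False
```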

maxrjones added the question label (Further information is requested) Nov 17, 2022
maxrjones added the bug label (Something isn't working) and removed the question label Dec 2, 2022
maxrjones changed the title from "Should input_dims affect the dimension order for batches?" to "Dimension order should be set by input_dims" Dec 2, 2022
@maxrjones
Member Author

I will treat the edge case in which the output dimension order does not agree with the order specified by input_dims as a bug and submit a fix.
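
One possible rule for such a fix, sketched here in plain Python and not necessarily how xbatcher will implement it: move the input dimensions to the end in input_dims order and keep all other dimensions in their original relative order. On an xarray object the same reordering could be written as ds.transpose(..., *input_dims), which would turn (time, y, x) into (time, x, y) for the concat_input_dims=False example. input_dims_last and transpose_axes are hypothetical helper names.

```python
def input_dims_last(dims, input_dims):
    """Hypothetical helper: reorder dims so the input dimensions come last,
    in input_dims order, while other dimensions keep their relative order."""
    others = [d for d in dims if d not in input_dims]
    trailing = [d for d in input_dims if d in dims]
    return tuple(others + trailing)


def transpose_axes(dims, input_dims):
    """Axis permutation (as passed to numpy.transpose) realizing input_dims_last."""
    target = input_dims_last(dims, input_dims)
    return tuple(dims.index(d) for d in target)


dims = ("time", "y", "x")
input_dims = {"x": 5, "y": 10}
print(input_dims_last(dims, input_dims))  # ('time', 'x', 'y')
print(transpose_axes(dims, input_dims))   # (0, 2, 1)
```

Applied to the (10, 10, 5) batch from the second example, this permutation would yield shape (10, 5, 10), matching the x-then-y order used when concat_input_dims=True.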
