Issues when running flox with chunked pint arrays #163

Open · riley-brady opened this issue Oct 6, 2022 · 3 comments · May be fixed by #207
Labels: array-types, bug, enhancement

riley-brady commented Oct 6, 2022

After upgrading to the latest xarray version and installing flox, I find that chunked pint arrays break with the .resample() method. I'm posting this here rather than on pint_xarray, since the traceback suggests the error originates in flox.

I imagine this has to do with the complexity of working with duck arrays like pint's Quantity.


import xarray as xr
import pint_xarray
import flox

xr.__version__
>>> '2022.9.0'

pint_xarray.__version__
>>> '0.3'

flox.__version__
>>> '0.5.9'

time_ax = xr.cftime_range('2020-06-01 01:00:00', freq='H', periods=3)
ds = xr.DataArray(range(3), dims='time', coords={'time': time_ax})

# Simple case, no dask or pint
ds.resample(time="D").mean()
>>> <xarray.DataArray (time: 1)>
>>> array([1.])
>>> Coordinates:
>>>   * time     (time) object 2020-06-01 00:00:00

# Dask case
ds_chunked = ds.chunk({'time': 1})
ds_chunked.resample(time="D").mean().compute()
>>> <xarray.DataArray (time: 1)>
>>> array([1.])
>>> Coordinates:
>>>   * time     (time) object 2020-06-01 00:00:00

# Pint case
ds_pint = ds.pint.quantify('kelvin')
ds_pint.resample(time="D").mean()
>>> <xarray.DataArray (time: 1)>
>>> <Quantity([1.], 'kelvin')>
>>> Coordinates:
>>>  * time     (time) object 2020-06-01 00:00:00

# Pint with xarray chunk
ds_pint_chunk = ds_pint.chunk({'time': 1})
ds_pint_chunk.resample(time="D").mean().compute()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [25], in <cell line: 3>()
      1 # Pint with xarray chunk
      2 ds_pint_chunk = ds_pint.chunk({'time': 1})
----> 3 ds_pint_chunk.resample(time="D").mean().compute()

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataarray.py:1083, in DataArray.compute(self, **kwargs)
   1064 """Manually trigger loading of this array's data from disk or a
   1065 remote source into memory and return a new array. The original is
   1066 left unaltered.
   (...)
   1080 dask.compute
   1081 """
   1082 new = self.copy(deep=False)
-> 1083 return new.load(**kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataarray.py:1057, in DataArray.load(self, **kwargs)
   1039 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1040     """Manually trigger loading of this array's data from disk or a
   1041     remote source into memory and return this array.
   1042 
   (...)
   1055     dask.compute
   1056     """
-> 1057     ds = self._to_temp_dataset().load(**kwargs)
   1058     new = self._from_temp_dataset(ds)
   1059     self._variable = new._variable

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataset.py:734, in Dataset.load(self, **kwargs)
    731 import dask.array as da
    733 # evaluate all the dask arrays simultaneously
--> 734 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    736 for k, data in zip(lazy_data, evaluated_data):
    737     self.variables[k].data = data

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/base.py:600, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    597     keys.append(x.__dask_keys__())
    598     postcomputes.append(x.__dask_postcompute__())
--> 600 results = schedule(dsk, keys, **kwargs)
    601 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state["cache"][key] = res

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
    988 if not len(args) == len(self.inkeys):
    989     raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/toolz/functoolz.py:487, in Compose.__call__(self, *args, **kwargs)
    486 def __call__(self, *args, **kwargs):
--> 487     ret = self.first(*args, **kwargs)
    488     for f in self.funcs:
    489         ret = f(ret)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/core.py:689, in chunk_reduce(array, by, func, expected_groups, axis, fill_value, dtype, reindex, engine, kwargs, sort)
    687     result = reduction(group_idx, array, **kwargs)
    688 else:
--> 689     result = generic_aggregate(
    690         group_idx, array, axis=-1, engine=engine, func=reduction, **kwargs
    691     ).astype(dt, copy=False)
    692 if np.any(props.nanmask):
    693     # remove NaN group label which should be last
    694     result = result[..., :-1]

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/aggregations.py:49, in generic_aggregate(group_idx, array, engine, func, axis, size, fill_value, dtype, **kwargs)
     44 else:
     45     raise ValueError(
     46         f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
     47     )
---> 49 return method(
     50     group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
     51 )

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/aggregate_flox.py:33, in _np_grouped_op(group_idx, array, op, axis, size, fill_value, dtype, out)
     26     out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
     28 if (len(uniques) == size) and (uniques == np.arange(size)).all():
     29     # The previous version of this if condition
     30     #     ((uniques[1:] - uniques[:-1]) == 1).all():
     31     # does not work when group_idx is [1, 2] for e.g.
     32     # This happens  during binning
---> 33     op.reduceat(array, inv_idx, axis=axis, dtype=dtype, out=out)
     34 else:
     35     out[..., uniques] = op.reduceat(array, inv_idx, axis=axis, dtype=dtype)

TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(<ufunc 'add'>, 'reduceat', <Quantity([0], 'kelvin')>, array([0]), axis=-1, dtype=dtype('int64'), out=(array([0]),)): 'Quantity', 'ndarray', 'ndarray'
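
For reference, the failing call can be reproduced without flox or xarray (a minimal sketch; the registry and values here are made up):

import numpy as np
import pint

ureg = pint.UnitRegistry()
q = ureg.Quantity(np.array([0.0, 1.0, 2.0]), "kelvin")

# pint's Quantity.__array_ufunc__ does not implement the "reduceat"
# ufunc method, so every operand returns NotImplemented and numpy
# raises the same TypeError as in the traceback above.
np.add.reduceat(q, np.array([0]), axis=-1)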
dcherian (Collaborator) commented Oct 6, 2022

OK, that won't work, but xarray should not have gone down this code path at all. It looks like I only tested pure pint arrays, not pint + dask:
https://github.com/pydata/xarray/blob/50ea159bfd0872635ebf4281e741f3c87f0bef6b/xarray/core/utils.py#L980

It'd be nice to add full pint support here, but it'll take a bit of effort. Are you interested in working on it?
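
To illustrate the gap (a sketch only, not the actual xarray utility linked above): pint wraps dask rather than the other way around, so a check on the outer object alone misses chunked pint arrays and has to look through .magnitude as well.

import dask.array as da
import pint

ureg = pint.UnitRegistry()

def is_chunked_array(arr) -> bool:
    # Look through a unit wrapper (anything exposing .magnitude),
    # then test whether the underlying data is a dask array.
    inner = getattr(arr, "magnitude", arr)
    return isinstance(inner, da.Array)

chunked_quantity = ureg.Quantity(da.zeros(3, chunks=1), "kelvin")
assert is_chunked_array(chunked_quantity)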

riley-brady (Author) commented

@dcherian sorry for the delay here. I could work on this, but unfortunately only on weekends, so it might be a slow process. I'd appreciate some guidance, if you have time (either in chat here or on a Zoom call), on which parts of the code to target, since I haven't worked closely with the package.

My current solution is to dequantify, run resample() (or whichever other method triggers this), and then quantify again, which isn't ideal but works. The error message is not very clear, though, so I'm not sure that's a sustainable solution for the community as a whole.
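
A sketch of that workaround, using the objects from the example above: dequantify() stashes the units in the "units" attribute and quantify() restores them; keep_attrs=True keeps the stashed attribute alive through the reduction. This assumes the reduction preserves the input units, which holds for mean().

result = (
    ds_pint_chunk.pint.dequantify()
    .resample(time="D")
    .mean(keep_attrs=True)
    .compute()
    .pint.quantify()
)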

dcherian (Collaborator) commented Nov 5, 2022

Thanks for offering to help, @riley-brady!

My current solution is to dequantify, run resample() (or whichever other method triggers this), and then quantify again, which isn't ideal but works.

I think this is what we'll have to do, since pint's support for ufunc methods apparently isn't great.

Let's strip array units, if any, right at the beginning, and reapply them at the end:

(flox/core.py, line 1641 at commit e3ea0e7)

  1. We won't be handling units on "by", but I think that's OK for now? Alternatively, you could again dequantify and then quantify after compute.
  2. To figure out the output units, we'll have to run getattr(numpy, agg.name)(Quantity([1, 1], dtype=array.dtype, units=array.units)), i.e. run the aggregation on a small problem to determine the output units (necessary for any, all, var, arg*, for example) and apply them at the end. This approach won't work for "custom aggregations", but we can deal with that later when we need to. See the sketch after this list.
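
A minimal sketch of the unit-probing idea in point 2 (probe_output_units is a hypothetical helper, not flox API):

import numpy as np
import pint

ureg = pint.UnitRegistry()

def probe_output_units(func_name, dtype, units):
    # Run the aggregation on a tiny two-element Quantity and read the
    # units off the result. arg*-style reductions return plain integer
    # arrays, hence the dimensionless fallback.
    probe = ureg.Quantity(np.array([1, 1], dtype=dtype), units)
    result = getattr(np, func_name)(probe)
    return getattr(result, "units", ureg.dimensionless)

print(probe_output_units("var", "float64", "kelvin"))     # kelvin ** 2
print(probe_output_units("argmax", "float64", "kelvin"))  # dimensionless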

dcherian linked a pull request (#207) that will close this issue on Jan 24, 2023.
dcherian added the bug and enhancement labels on Jan 25, 2023.