where dtype upcast with numpy 2 #8402

Open
5 tasks done
djhoese opened this issue Nov 2, 2023 · 10 comments · May be fixed by #8946

@djhoese
Contributor

djhoese commented Nov 2, 2023

What happened?

I'm testing my code with numpy 2.0 and current main xarray and dask, and I ran into a change that I guess is expected given the way xarray does things. I want to make sure, though, since it could be unexpected for many users.

Calling DataArray.where on an integer array narrower than 64 bits, with a Python int as the replacement value, upcasts the result to int64 (the default integer dtype numpy gives a converted Python int). With older versions of numpy the array's dtype was preserved. As far as I can tell the relevant xarray code hasn't changed, so this seems to come from numpy making things more consistent.

The main problem seems to come down to:

arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]

This converts my scalar int input to a numpy array. If that conversion were skipped, numpy would behave as expected. See the MCVE for the xarray-specific example; here is the numpy equivalent:

import numpy as np

a = np.zeros((2, 2), dtype=np.uint16)

# what I'm intending to do with my xarray `data_arr.where(cond, 2)`
np.where(a != 0, a, 2).dtype
# dtype('uint16')

# equivalent to what xarray does:
np.where(a != 0, a, np.asarray(2)).dtype
# dtype('int64')

# workaround, cast my scalar to a specific numpy type
np.where(a != 0, a, np.asarray(np.uint16(2))).dtype
# dtype('uint16')

From numpy's point of view, the second where call makes sense: two arrays have to be upcast to a common dtype before they can be combined. But from an xarray user's point of view, I'm passing a scalar, so I expect the same result as the first where call above.
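The same difference shows up without np.where if you ask numpy directly which dtype it would promote to. This is a minimal sketch assuming only numpy: under NEP 50 (numpy 2) a bare Python int is a "weak" scalar that adapts to the array's dtype, while the 0-D array produced by np.asarray carries its own default integer dtype and is promoted like any other array.

```python
import numpy as np

# A bare Python int is a "weak" scalar: it adapts to the other dtype,
# so uint16 wins under both numpy 1.x and numpy 2.
weak = np.result_type(np.uint16, 2)

# np.asarray(2) is a 0-D array with numpy's default integer dtype.
# numpy 1.x applied value-based casting to 0-D arrays (result: uint16);
# numpy 2 no longer inspects the value, so the default int dtype wins.
strong = np.result_type(np.uint16, np.asarray(2))

print(weak, strong)
```

On the reporter's 64-bit Linux setup this corresponds to uint16 versus int64 under numpy 2.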

What did you expect to happen?

See above.

Minimal Complete Verifiable Example

import xarray as xr
import numpy as np

data_arr = xr.DataArray(np.array([1, 2], dtype=np.uint16))
print(data_arr.where(data_arr == 2, 3).dtype)
# int64

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

Numpy 1.x preserves the dtype.

In [1]: import numpy as np

In [2]: np.asarray(2).dtype
Out[2]: dtype('int64')

In [3]: a = np.zeros((2, 2), dtype=np.uint16)

In [4]: np.where(a != 0, a, np.asarray(2)).dtype
Out[4]: dtype('uint16')

In [5]: np.where(a != 0, a, np.asarray(np.uint16(2))).dtype
Out[5]: dtype('uint16')

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:17) [GCC 12.2.0]
python-bits: 64
OS: Linux
OS-release: 6.4.6-76060406-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.2
libnetcdf: 4.9.2

xarray: 2023.10.2.dev21+gfcdc8102
pandas: 2.2.0.dev0+495.gecf449b503
numpy: 2.0.0.dev0+git20231031.42c33f3
scipy: 1.12.0.dev0+1903.18d0a2f
netCDF4: 1.6.5
pydap: None
h5netcdf: 1.2.0
h5py: 3.10.0
Nio: None
zarr: 2.16.1
cftime: 1.6.3
nc_time_axis: None
PseudoNetCDF: None
iris: None
bottleneck: 1.3.7.post0.dev7
dask: 2023.10.1+4.g91098a63
distributed: 2023.10.1+5.g76dd8003
matplotlib: 3.9.0.dev0
cartopy: None
seaborn: None
numbagg: None
fsspec: 2023.6.0
cupy: None
pint: 0.22
sparse: None
flox: None
numpy_groupies: None
setuptools: 68.0.0
pip: 23.2.1
conda: None
pytest: 7.4.0
mypy: None
IPython: 8.14.0
sphinx: 7.1.2
@djhoese djhoese added the bug and needs triage labels Nov 2, 2023
@dcherian
Contributor

dcherian commented Nov 2, 2023

Thanks for the well written issue.

I'm not sure what to do here: asarray(x, xp=xp) for x in scalars_or_arrays is quite intentionally converting scalars to 0D arrays.

Perhaps @shoyer has thoughts.

@dcherian dcherian removed the needs triage label Nov 2, 2023
@djhoese
Contributor Author

djhoese commented Nov 2, 2023

FYI, it looks like this affects data_arr.where(cond, np.nan) too. This used to retain the floating-point dtype of data_arr, but now it upcasts 32-bit floats to 64-bit floats. I know a lot of work went into NA handling in xarray, but this automatic handling of np.nan for 32- and 64-bit data kept working after that. Now it is back to upcasting.
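The float case follows the same pattern. A minimal numpy-only sketch, assuming a float32 array like the one in the report: np.nan on its own is promoted as a weak Python float, but once wrapped by np.asarray it becomes a 0-D float64 array. The last lines apply the workaround from the original report (casting the fill value to the array's own dtype first).

```python
import numpy as np

data = np.array([1.0, 2.0], dtype=np.float32)
cond = data == 2.0

# np.nan is a plain Python float, so it is promoted as a weak scalar.
weak = np.result_type(data, np.nan)

# Wrapping it first yields a 0-D float64 array; numpy 2 no longer
# inspects its value, so float32 is upcast to float64.
strong = np.result_type(data, np.asarray(np.nan))

# Workaround: give the fill value the array's own dtype before it is
# converted to an array, so no upcast is needed.
fill = np.asarray(data.dtype.type(np.nan))
result = np.where(cond, data, fill)
print(weak, strong, result.dtype)
```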

@shoyer
Member

shoyer commented Nov 3, 2023

I think we should probably make an exception for Python built-in numbers when casting to arrays. Something closer to:

scalar_types = (int, float, complex)
arrays = [x if isinstance(x, scalar_types) else asarray(x, xp=xp) for x in scalars_or_arrays] 
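A standalone sketch of this suggestion might look like the following. It is a hypothetical, numpy-only stand-in for xarray's as_shared_dtype (the name as_shared_dtype_keep_scalars is made up here), and it uses np.result_type in place of xarray's own dtypes.result_type. Because Python scalars have no astype, only array-like entries are cast at the end.

```python
import numpy as np

# Python built-in numbers that NEP 50 treats as "weak" scalars.
# Note: bool subclasses int, so True/False also pass this check.
SCALAR_TYPES = (int, float, complex)


def as_shared_dtype_keep_scalars(scalars_or_arrays):
    """Hypothetical numpy-only variant of xarray's as_shared_dtype.

    Python scalars are left unconverted so result_type can treat them
    as weak; everything else is converted to an array.
    """
    items = [
        x if isinstance(x, SCALAR_TYPES) else np.asarray(x)
        for x in scalars_or_arrays
    ]
    out_type = np.result_type(*items)
    # Built-in scalars have no .astype, so only cast entries with a dtype.
    return [
        x.astype(out_type, copy=False) if hasattr(x, "dtype") else x
        for x in items
    ]
```

For example, passing a uint16 array and the Python int 2 returns the array still as uint16 and the 2 untouched, leaving numpy's weak-scalar handling to apply when the values are later combined.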

@djhoese
Contributor Author

djhoese commented Nov 3, 2023

@shoyer do you remember why the scalar types are converted to arrays? I don't think your suggestion catches the case where np.nan is passed. If that has to stay that way, I'll work around it, but if a similar change could handle it, that would make things easier for me. Before I can suggest anything, though, I need to figure out what the scalar -> array conversion is meant to catch or fix.

@keewis
Collaborator

keewis commented Nov 3, 2023

If I remember correctly, the array API does not support python scalar types (or numpy scalars), so we have to convert them to 0D arrays (see also the discussion in #7721).
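One way to reconcile this with the array API requirement is to resolve the dtype while Python scalars are still weak, and only then wrap them in 0-D arrays of that resolved dtype. The sketch below is a hypothetical illustration of that ordering, not xarray's implementation: it is numpy-only (an arbitrary xp namespace may not accept Python scalars in result_type), and the function name is made up.

```python
import numpy as np

SCALAR_TYPES = (int, float, complex)


def as_shared_dtype_then_array(scalars_or_arrays):
    """Hypothetical sketch: resolve the dtype first, then convert."""
    # Step 1: let numpy resolve the common dtype while Python scalars
    # are still weak (not yet wrapped in 0-D arrays).
    items = [
        x if isinstance(x, SCALAR_TYPES) else np.asarray(x)
        for x in scalars_or_arrays
    ]
    out_type = np.result_type(*items)
    # Step 2: convert everything, scalars included, to arrays of that
    # dtype, so downstream code only ever sees arrays.
    return [np.asarray(x, dtype=out_type) for x in items]
```

A scalar that does not fit the resolved dtype (for example -1 with a uint16 array) may warn or raise in the final conversion, depending on the numpy version.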

@shoyer
Member

shoyer commented Nov 4, 2023 via email

@djhoese
Contributor Author

djhoese commented Nov 6, 2023

I started hacking around with your idea, @shoyer, and realized that np.nan is a Python built-in float, so it should work fine. This is also still in line with NEP 50's handling of scalars. I got my simple np.nan case with DataArray.where working with the change below:

diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 4f245e59..8fbe4731 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -202,13 +202,13 @@ def as_shared_dtype(scalars_or_arrays, xp=np):

         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
     else:
-        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
+        arrays = [x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp) for x in scalars_or_arrays]
     # Pass arrays directly instead of dtypes to result_type so scalars
     # get handled properly.
     # Note that result_type() safely gets the dtype from dask arrays without
     # evaluating them.
     out_type = dtypes.result_type(*arrays)
-    return [astype(x, out_type, copy=False) for x in arrays]
+    return [astype(x, out_type, copy=False) if hasattr(x, "dtype") else x for x in arrays]

Without the second modification astype fails on scalars. I was surprised that xarray's result_type was OK with the scalars being passed to it. Now for the tests...

65 failed, 15045 passed, 1637 skipped, 154 xfailed, 59 xpassed, 6494 warnings

I'll try investigating more tomorrow if I have the time, but I get 63 failures with no changes so... 🤷‍♂️
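The observation that np.nan is a built-in float can be sanity-checked directly, along with a few other values the isinstance((int, float, complex)) check will see. This is a quick illustration assuming CPython 3 and a current numpy.

```python
import numpy as np

checks = {
    # np.nan is a plain Python float, so isinstance(x, float) catches it.
    "nan_is_builtin_float": type(np.nan) is float,
    # numpy's float64 and complex128 scalars subclass the built-in types...
    "float64_is_float": isinstance(np.float64(1.0), float),
    "complex128_is_complex": isinstance(np.complex128(1.0), complex),
    # ...but numpy integer scalars do not subclass int.
    "int64_is_int": isinstance(np.int64(1), int),
    # bool subclasses int, so True/False also pass an int check.
    "bool_is_int": isinstance(True, int),
}
for name, value in checks.items():
    print(f"{name}: {value}")
```

np.float64 scalars therefore skip the array conversion too, but numpy still knows their exact dtype, so under NEP 50 they are promoted as float64 rather than as weak scalars.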

@djhoese
Contributor Author

djhoese commented Nov 6, 2023

Ok, very interesting. I think this issue will need to be fixed to preserve behavior when numpy 2.0 comes out. So this test fails right now with numpy 2 (not including my changes):

FAILED xarray/tests/test_rolling.py::TestDataArrayRolling::test_rolling_dask_dtype[float32] - AssertionError: assert dtype('float64') == dtype('float32')

As you can see it is specifically checking the result dtype and they are not equal (the result's type was upcast to 64-bit float). This is WITHOUT my changes. This test is fixed when I rerun the tests with my hacky solution mentioned above.

As for tests that fail with my changes that didn't before:

FAILED xarray/tests/test_duck_array_ops.py::test_argmin_max[x-True-min-True-True-str-1] - numpy.exceptions.DTypePromotionError: The DType <class 'numpy.dtypes.StrDType'> could not be promoted by <class 'numpy._FloatAbstractDType'>. This means that no common DType exists for the given inputs. For example they cannot be s...
FAILED xarray/tests/test_duck_array_ops.py::test_argmin_max[x-True-min-True-False-str-1] - numpy.exceptions.DTypePromotionError: The DType <class 'numpy.dtypes.StrDType'> could not be promoted by <class 'numpy._FloatAbstractDType'>. This means that no common DType exists for the given inputs. For example they cannot be s...
FAILED xarray/tests/test_duck_array_ops.py::test_argmin_max[x-True-max-True-True-str-1] - numpy.exceptions.DTypePromotionError: The DType <class 'numpy.dtypes.StrDType'> could not be promoted by <class 'numpy._FloatAbstractDType'>. This means that no common DType exists for the given inputs. For example they cannot be s...
FAILED xarray/tests/test_duck_array_ops.py::test_argmin_max[x-True-max-True-False-str-1] - numpy.exceptions.DTypePromotionError: The DType <class 'numpy.dtypes.StrDType'> could not be promoted by <class 'numpy._FloatAbstractDType'>. This means that no common DType exists for the given inputs. For example they cannot be s...

This is kind of expected given how inelegant my update was.

@djhoese
Contributor Author

djhoese commented Apr 15, 2024

Are there any issues or milestones to track for general numpy 2 compatibility work? I'm running into this particular issue again in another one of my projects (with xarray main) and my little hack above suggested by shoyer seems to fix things still, but I don't know enough about the xarray internals to know if any of those changes are pull request worthy or bound to conflict with other things in xarray.

@dcherian
Contributor

Good timing on the bump! We have some convo here: #8844

AFAIK no one has looked at this though I nominally volunteered to do so.

It sounds like you've made some progress. Can you open a PR please, it might be easier to help then?

@djhoese djhoese linked a pull request Apr 15, 2024 that will close this issue