Fix upcasting with python builtin numbers and numpy 2 #8946
base: main
Conversation
Ugh, my local clone was so old it was pointing to master. One sec...
force-pushed from 88e778a to f3c2c93 (includes a pre-commit.ci autofix; for more information, see https://pre-commit.ci)
Ok, so the failing test is the array-api version (https://github.com/data-apis/array-api-compat), where it expects both the `x` and `y` inputs of `where` to be arrays.

Looks like the array API strictly wants arrays: https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html

Related, but I don't fully understand it: data-apis/array-api-compat#85
I guess it depends how you interpret the array API standard, then. I can file an issue if needed. To me, it means either:
The other point is that maybe numpy compatibility is more important until numpy more formally conforms to the array API standard (see the first note on https://data-apis.org/array-api/latest/API_specification/array_object.html#api-specification-array-object--page-root). Type promotion in the standard also seems wishy-washy and not super strict: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars

I propose, because it works best for me and matches numpy compatibility, that I update the test to have a numpy case only, but add a new test function with numpy and array API cases with array inputs to `where`.
This reverts commit 3f7670b.
* main:
  (feat): Support for `pandas` `ExtensionArray` (pydata#8723)
  Migrate datatree mapping.py (pydata#8948)
  Add mypy to dev dependencies (pydata#8947)
  Convert 360_day calendars by choosing random dates to drop or add (pydata#8603)
I lean towards (1). I looked at this for a while, and we'll need major changes around handling array API dtype objects to do this properly. cc @keewis
I think the change could be limited to `as_shared_dtype`. However, what we currently do is cast all scalars to arrays using `asarray`. As an algorithm, maybe this could work:
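A minimal standalone sketch of such a split-by-dtype approach (the helper name and details are illustrative, not the PR's actual code): compute the common dtype only from inputs that carry an explicit `dtype`, so weakly typed Python scalars cannot upcast the result.

```python
import numpy as np


def shared_dtype(values):
    # Determine the common dtype from explicitly dtyped inputs only
    # (arrays and numpy scalars); Python scalars are ignored unless
    # there is nothing else to go on.
    with_dtype = [v for v in values if hasattr(v, "dtype")]
    candidates = with_dtype if with_dtype else values
    return np.result_type(*candidates)


arr = np.array([1, 2], dtype="float32")
shared_dtype([arr, 1.5])  # float32: the Python float does not upcast
shared_dtype([3, 1.5])    # float64: only weak scalars, so defaults apply
```

This matches the behavior both numpy versions give for `array + python_scalar`, while sidestepping the upcast that casting scalars to 0-d arrays first would cause.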
Do you know if this is in line with numpy 2 dtype casting behavior?
The main …
How do we check this?
Here's what I have locally which seems to pass:

```diff
Subject: [PATCH] Cast scalars as arrays with result type of only arrays
---
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
--- a/xarray/core/duck_array_ops.py	(revision e27f572585a6386729a5523c1f9082c72fa8d178)
+++ b/xarray/core/duck_array_ops.py	(date 1713816523554)
@@ -239,20 +239,30 @@
         import cupy as cp

         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
+        # Pass arrays directly instead of dtypes to result_type so scalars
+        # get handled properly.
+        # Note that result_type() safely gets the dtype from dask arrays without
+        # evaluating them.
+        out_type = dtypes.result_type(*arrays)
     else:
-        arrays = [
-            # https://github.com/pydata/xarray/issues/8402
-            # https://github.com/pydata/xarray/issues/7721
-            x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
-            for x in scalars_or_arrays
-        ]
-        # Pass arrays directly instead of dtypes to result_type so scalars
-        # get handled properly.
-        # Note that result_type() safely gets the dtype from dask arrays without
-        # evaluating them.
-        out_type = dtypes.result_type(*arrays)
+        # arrays = [
+        #     # https://github.com/pydata/xarray/issues/8402
+        #     # https://github.com/pydata/xarray/issues/7721
+        #     x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
+        #     for x in scalars_or_arrays
+        # ]
+        objs_with_dtype = [obj for obj in scalars_or_arrays if hasattr(obj, "dtype")]
+        if objs_with_dtype:
+            # Pass arrays directly instead of dtypes to result_type so scalars
+            # get handled properly.
+            # Note that result_type() safely gets the dtype from dask arrays without
+            # evaluating them.
+            out_type = dtypes.result_type(*objs_with_dtype)
+        else:
+            out_type = dtypes.result_type(*scalars_or_arrays)
+        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
     return [
-        astype(x, out_type, copy=False) if hasattr(x, "dtype") else x for x in arrays
+        astype(x, out_type, copy=False) for x in arrays
     ]
```

I just threw it together to see if it would work. I'm not sure it is accurate, but the fact that it is almost exactly like the existing solution, with the only difference being what gets passed to `result_type`, is encouraging. Note I had to …
Not sure... but there are only so many builtin types that can be involved without requiring …
@keewis Do you have a test that I can add to verify any fix I attempt for this? What do you mean by a python scalar being compatible with the result?
well, for example, what should happen for this:

```python
a = xr.DataArray(np.array([1, 2, 3], dtype="int8"), dims="x")
xr.where(a % 2 == 1, a, 1.2)
```

according to the algorithm above, we have one array of dtype `int8` …

Something similar:

```python
a = xr.DataArray(
    np.array(["2019-01-01", "2020-01-01"], dtype="datetime64[ns]"), dims="x"
)
xr.where(a.x % 2 == 1, a, datetime.datetime(2019, 6, 30))
```

in that case, the check should succeed, because we can convert a builtin `datetime.datetime` …
I committed my (what I consider ugly) implementation of your original approach @keewis. I'm still not sure I understand how to approach the scalar compatibility, so if someone has some ideas then please make some suggestion comments, or commit directly if you have the permissions.
this might be cleaner:

```python
def asarray(data, xp=np, dtype=None):
    return data if is_duck_array(data) else xp.asarray(data, dtype=dtype)


def as_shared_dtype(scalars_or_arrays, xp=np):
    """Cast arrays to a shared dtype using xarray's type promotion rules."""
    if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
        # as soon as extension arrays are involved we only use this:
        extension_array_types = [
            x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
        ]
        if len(extension_array_types) == len(scalars_or_arrays) and all(
            isinstance(x, type(extension_array_types[0]))
            for x in extension_array_types
        ):
            return scalars_or_arrays
        raise ValueError(
            "Cannot cast arrays to shared type, found"
            f" array types {[x.dtype for x in scalars_or_arrays]}"
        )

    if array_type_cupy := array_type("cupy") and any(  # noqa: F841
        isinstance(x, array_type_cupy) for x in scalars_or_arrays  # noqa: F821
    ):
        import cupy as cp

        xp_ = cp
    else:
        xp_ = xp

    # split into python scalars and arrays / numpy scalars
    # (i.e. into weakly and strongly dtyped)
    with_dtype = {}
    python_scalars = {}
    for index, elem in enumerate(scalars_or_arrays):
        append_to = with_dtype if hasattr(elem, "dtype") else python_scalars
        append_to[index] = elem

    if with_dtype:
        to_convert = with_dtype
    else:
        # can't avoid using the default dtypes if we only get weak dtypes
        to_convert = python_scalars
        python_scalars = {}

    arrays = {index: asarray(x, xp=xp_) for index, x in to_convert.items()}
    common_dtype = dtypes.result_type(*arrays.values())
    # TODO(keewis): check that all python scalars are compatible.
    # If not, change the dtype or raise.

    # cast arrays
    cast = {
        index: astype(x, dtype=common_dtype, copy=False)
        for index, x in arrays.items()
    }
    # convert python scalars to arrays with a specific dtype
    converted = {
        index: asarray(x, xp=xp_, dtype=common_dtype)
        for index, x in python_scalars.items()
    }
    # merge both
    combined = cast | converted

    return [x for _, x in sorted(combined.items(), key=lambda x: x[0])]
```

This is still missing the dtype fallbacks, though.
I see now why the dtype fallbacks for scalars are tricky... we basically need to enumerate the casting rules, and decide when to return a different dtype (like …).

To start, here are the rules from the Array API: …

From …

Edit: it appears NEP 50 describes the changes in detail. I didn't see that before writing both the list above and implementing the changes, so I might have to change both.
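As a concrete illustration, a fallback adjustment along NEP 50's lines might look like the following sketch (the function name and the exact mapping are assumptions for illustration, not the PR's final rules):

```python
import numpy as np


def adapt_dtype(common_dtype, python_scalars):
    # Widen the common dtype (derived from strongly dtyped inputs) only
    # when a weakly dtyped Python scalar cannot be represented by it:
    # - any real dtype meeting a python complex -> default complex
    # - bool/int dtypes meeting a python float -> default float
    dtype = np.dtype(common_dtype)
    for scalar in python_scalars:
        if isinstance(scalar, complex) and dtype.kind in "biuf":
            dtype = np.dtype("complex128")
        elif isinstance(scalar, float) and dtype.kind in "biu":
            dtype = np.dtype("float64")
    return dtype


adapt_dtype("int8", [1.2])     # float64: an int dtype can't hold a float
adapt_dtype("float32", [1.2])  # float32: the scalar fits, dtype is kept
```

This deliberately never widens within a kind (e.g. `float32` stays `float32` for a python float), which is the NEP 50 behavior the PR is trying to match.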
here's my shot at the scalar dtype verification. I'm pretty sure it can be cleaned up further (and we need more tests), but it does fix all the casting issues.

Edit: note that this depends on the Array API fixes for …

What I don't like is that we're essentially hard-coding the dtype casting hierarchy, but I couldn't figure out a way to make it work without that.
FYI to everyone watching this: I'm going to be switching to a heavier paternity leave than the one I was already on starting this week. I think someone else should take this PR over, as I don't think I'll have time to finish it before the numpy 2 final release.
In an ideal world, I think this would be written something like:

```python
def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace_or_numpy(scalars_or_arrays)
    dtype = xp.result_type(*scalars_or_arrays)
    return tuple(xp.asarray(x, dtype) for x in scalars_or_arrays)
```

The main issues stopping this: …
The first issue can be solved with compatibility code. I will raise the second issue on the array API tracker.
while 2 should eventually be resolved by an addition to the Array API, it won't help us right now to resolve the dtype casting before the release of ….

As far as I understand it, to put the above code sample (with slight code modifications due to 1) in …. To do so, we'd have to find a way to split the input of …. However, while with …. If there truly is no way to find dtypes in a general way, we'll have to do the split in:

```python
def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace_or_numpy(scalars_or_arrays)
    explicitly_dtyped, weakly_dtyped = dtypes.extract_explicitly_dtyped(scalars_or_arrays)
    common_dtype = dtypes.result_type(*explicitly_dtyped)
    dtype = dtypes.adapt_common_dtype(common_dtype, weakly_dtyped)
    return tuple(xp.asarray(x, dtype) for x in scalars_or_arrays)
```

Another option would be to pass the weakly dtyped data as a keyword argument to ….

Edit: actually, I'd probably go for the latter.
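The keyword-argument variant could be sketched like this (the signature and the float fallback are assumptions, shown with plain numpy for illustration):

```python
import numpy as np


def result_type(*arrays_and_dtypes, weakly_dtyped=()):
    # Strongly dtyped inputs alone determine the common dtype; the
    # weakly dtyped Python scalars may only widen it afterwards.
    if arrays_and_dtypes:
        dtype = np.result_type(*arrays_and_dtypes)
    else:
        # only weak scalars: fall back to the default dtypes
        dtype = np.result_type(*weakly_dtyped)
    if any(isinstance(x, float) for x in weakly_dtyped) and dtype.kind in "biu":
        dtype = np.dtype("float64")
    return dtype


result_type(np.array([1], dtype="int8"), weakly_dtyped=[2])    # int8
result_type(np.array([1], dtype="int8"), weakly_dtyped=[2.5])  # float64
```

The upside of this shape is that callers never have to pre-convert scalars, so the weak/strong distinction survives all the way into the dtype decision.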
In the most recent commits I've added a way to check if a …. This appears to resolve the issues we had with ….

I'll try to look into the above and into adding additional tests over the weekend. In the meantime, I'd appreciate reviews of the general idea (cc in particular @shoyer, but also @dcherian).
looks like the most recent commits fixed the remaining failing tests (turns out I forgot to apply our custom dtype casting rules when adjusting the dtype to fit the scalars), so all that's left is to fix …
Using the common dtype instead would change the result.

```python
        return xp.dtype(object)

    if not arrays_and_dtypes:
        # no explicit dtypes, so we simply convert to 0-d arrays using default dtypes
        arrays_and_dtypes = [asarray(x, xp=xp) for x in weakly_dtyped]  # type: ignore
```
with my limited knowledge on typing, ignoring is the only way I could find to get the code to pass. `arrays_and_dtypes` is of type `list[np.typing.ArrayLike | np.typing.DTypeLike]`, while I'd like to assign something that is of type `list[np.typing.ArrayLike]` (i.e. no dtypes). How difficult would it be to get this to work without the ignore? Would we need to type hint `duck_array_ops.asarray` first?
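For what it's worth, one alternative to the ignore is an explicit `typing.cast`, which documents the narrowing instead of silencing the checker (the surrounding names here are hypothetical stand-ins for the real ones):

```python
from typing import cast

import numpy as np

weakly_dtyped = [1, 2.5, True]  # hypothetical weak inputs

# cast() is a no-op at runtime; it only tells the type checker to treat
# the converted list as the wider union type the variable was declared with.
arrays_and_dtypes = cast(
    "list[np.typing.ArrayLike | np.typing.DTypeLike]",
    [np.asarray(x) for x in weakly_dtyped],
)
```

Type-hinting `duck_array_ops.asarray` to return `np.typing.ArrayLike` would make even the cast unnecessary, since the list comprehension would then already infer the narrower type.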
Hey @keewis, thanks for continuing to dive into this! Given that there seems to be a consensus to add support for weak/strong dtypes to the Array API's `result_type`, I think this could look something like the following:

```python
def _future_array_api_result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    xp,
) -> np.dtype:
    ...


def result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    xp=None,
) -> np.dtype:
    from xarray.core.duck_array_ops import get_array_namespace

    if xp is None:
        xp = get_array_namespace(arrays_and_dtypes)

    types = {xp.result_type(t) for t in arrays_and_dtypes}
    # NOTE: slightly more conservative than the existing code
    if all(isinstance(t, np.dtype) for t in types):
        # only check if there's numpy dtypes – the array API does not
        # define the types we're checking for
        for left, right in PROMOTE_TO_OBJECT:
            if any(np.issubdtype(t, left) for t in types) and any(
                np.issubdtype(t, right) for t in types
            ):
                return np.dtype(object)

    if xp is np:
        return np.result_type(*arrays_and_dtypes)  # fast path

    # TODO: replace with xp.result_type when the array API always supports
    # weak dtypes: https://github.com/data-apis/array-api/issues/805
    return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
```
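For context, the `PROMOTE_TO_OBJECT` loop above encodes xarray's rule that mixing numbers (or bools) with strings should fall back to `object` rather than numpy's default string promotion. A self-contained sketch with a simplified version of that table:

```python
import numpy as np

# simplified promotion pairs: if one side matches `left` and another
# matches `right`, promote to object instead of numpy's default
PROMOTE_TO_OBJECT = [
    (np.number, np.character),
    (np.bool_, np.character),
    (np.bytes_, np.str_),
]


def result_type(*arrays_and_dtypes):
    types = {np.result_type(t) for t in arrays_and_dtypes}
    for left, right in PROMOTE_TO_OBJECT:
        if any(np.issubdtype(t, left) for t in types) and any(
            np.issubdtype(t, right) for t in types
        ):
            return np.dtype(object)
    return np.result_type(*arrays_and_dtypes)


result_type(np.array([1]), np.array(["a"]))  # object, not a string dtype
```

Without the table, `np.result_type` would promote the integer to a unicode dtype, which loses the original values' types.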
See #8402 for more discussion. Bottom line is that numpy 2 changes the rules for casting between two inputs. Due to this, and xarray's preference for promoting python scalars to 0-d arrays (scalar arrays), xarray objects are being upcast to higher precision data types when they previously weren't.
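To illustrate the weak/strong distinction driving this (a standalone numpy demo, not xarray code): Python builtins carry no dtype, while numpy scalars do, and under NEP 50 (numpy 2) only dtyped operands participate in promotion on equal footing.

```python
import numpy as np

# Python builtins are "weakly" dtyped: they carry no dtype of their own
print(hasattr(1.0, "dtype"))              # False
# numpy scalars are "strongly" dtyped, just like arrays
print(hasattr(np.float64(1.0), "dtype"))  # True

arr = np.array([1, 2], dtype="float32")
# A weak Python float leaves the float32 array dtype alone under both
# numpy 1 (value-based casting) and numpy 2 (NEP 50):
print((arr + 1.5).dtype)  # float32
# A strong np.float64 scalar, however, upcasts the result to float64 on
# numpy 2, whereas numpy 1's value-based casting kept float32 — which is
# why promoting python scalars to 0-d arrays now causes upcasting.
```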
I'm mainly opening this PR for further and more detailed discussion.
CC @dcherian
Linked issue: `where` dtype upcast with numpy 2 (#8402)

Files changed: whats-new.rst, api.rst