
Fix upcasting with python builtin numbers and numpy 2 #8946

Open · wants to merge 24 commits into main
Conversation

@djhoese (Contributor) commented Apr 15, 2024

See #8402 for more discussion. The bottom line is that numpy 2 changes the rules for casting between two inputs. Due to this, and due to xarray's preference for promoting python scalars to 0-d arrays (scalar arrays), xarray objects are being upcast to higher data types when they previously weren't.
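
A minimal illustration of the regression (a hedged sketch; the exact operations affected are discussed in #8402):

import numpy as np
import xarray as xr

da = xr.DataArray(np.array([1.0, 2.0], dtype="float32"))
result = xr.where(da > 1, da, 0.0)
# numpy 1 (value-based casting): result stays float32
# numpy 2 + scalar->0-d-array conversion: float64, because the float64
# 0-d array created from the python scalar now wins promotion
print(result.dtype)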

I'm mainly opening this PR for further and more detailed discussion.

CC @dcherian

@dcherian added the run-upstream (Run upstream CI) label on Apr 15, 2024
@djhoese (Contributor, Author) commented Apr 15, 2024

Ugh my local clone was so old it was pointing to master. One sec...

@djhoese (Contributor, Author) commented Apr 15, 2024

Ok, so the failing test is the array-api one (https://github.com/data-apis/array-api-compat), where it expects both the x and y inputs of the where function to have a .dtype. Since we're skipping the scalar->array conversion in this PR, those objects won't have a .dtype. I'm not sure what the rules are for the strict array API having scalar inputs.

@dcherian (Contributor):
Looks like the array api strictly wants arrays: https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html

@djhoese (Contributor, Author) commented Apr 15, 2024

Related but I don't fully understand it: data-apis/array-api-compat#85

@djhoese (Contributor, Author) commented Apr 16, 2024

I guess it depends on how you interpret the array API standard, then. I can file an issue if needed. To me, it means either:

  1. This test is flawed as it tests scalar inputs when the array API specifically defines Array inputs.
  2. The Array API package is flawed because it assumes and requires Array inputs when the standard allows for scalar inputs (I don't think this is true if I'm understanding the description).

The other point is that maybe numpy compatibility is more important until numpy more formally conforms to the array API standard (see the first note on https://data-apis.org/array-api/latest/API_specification/array_object.html#api-specification-array-object--page-root). But also type promotion seems wishy-washy and not super strict: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars

I propose, because it works best for me and matches numpy compatibility, that I update the existing test to a numpy-only case and add a new test function, with both numpy and array API cases, that passes array inputs to .where instead of scalars.

* main:
  (feat): Support for `pandas` `ExtensionArray` (pydata#8723)
  Migrate datatree mapping.py (pydata#8948)
  Add mypy to dev dependencies (pydata#8947)
  Convert 360_day calendars by choosing random dates to drop or add (pydata#8603)
@dcherian (Contributor):
I lean towards (1).

I looked at this for a while, and we'll need major changes around handling array API dtype objects to do this properly.

cc @keewis

@dcherian requested a review from @keewis on April 18, 2024
@keewis (Collaborator) commented Apr 22, 2024

we'll need major changes around handling array API dtype objects to do this properly.

I think the change could be limited to xarray.core.duck_array_ops.as_shared_dtype. According to the Array API section on mixing scalars and arrays, we should use the dtype of the array (though the spec only covers scalar + one array, so we'd need to extend that).

However, what we currently do is cast all scalars to arrays using asarray, which means python scalars use the OS default dtype (e.g. float64 on most 64-bit systems).
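
For example (numpy shown; the default dtypes are platform-dependent):

import numpy as np

np.asarray(1.5).dtype  # dtype('float64'): the default float dtype
np.asarray(2).dtype    # dtype('int64') on most 64-bit systems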

As an algorithm, maybe this could work (a rough sketch follows the list):

  • separate the input into python scalars and arrays / scalars with dtype
  • determine result_type using just the arrays / scalars with dtype
  • check that all python scalars are compatible with the result (otherwise might have to return object?)
  • cast all input to arrays with the dtype
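
A minimal numpy-only sketch of those steps (the helper name is hypothetical; this is not the PR's implementation):

import numpy as np

def as_shared_dtype_sketch(scalars_or_arrays):
    # step 1: separate python scalars from objects that carry a dtype
    with_dtype = [x for x in scalars_or_arrays if hasattr(x, "dtype")]
    python_scalars = [x for x in scalars_or_arrays if not hasattr(x, "dtype")]
    # step 2: determine the result type from the dtyped objects only
    # (fall back to the scalars when there are no dtyped objects)
    out_type = np.result_type(*(with_dtype or python_scalars))
    # step 3: the python-scalar compatibility check would go here
    # step 4: cast everything to arrays of the common dtype
    return [np.asarray(x, dtype=out_type) for x in scalars_or_arrays]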

@djhoese (Contributor, Author) commented Apr 22, 2024

According to the Array API section on mixing scalars and arrays, we should use the dtype of the array (though the spec only covers scalar + one array, so we'd need to extend that).

Do you know if this is in line with numpy 2 dtype casting behavior?

@keewis (Collaborator) commented Apr 22, 2024

The main numpy namespace is supposed to be Array API compatible, so it should? I don't know for certain, though.

@djhoese (Contributor, Author) commented Apr 22, 2024

check that all python scalars are compatible with the result (otherwise might have to return object?)

How do we check this?

@djhoese (Contributor, Author) commented Apr 22, 2024

Here's what I have locally which seems to pass:

Subject: [PATCH] Cast scalars as arrays with result type of only arrays
---
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
--- a/xarray/core/duck_array_ops.py	(revision e27f572585a6386729a5523c1f9082c72fa8d178)
+++ b/xarray/core/duck_array_ops.py	(date 1713816523554)
@@ -239,20 +239,30 @@
         import cupy as cp
 
         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
+        # Pass arrays directly instead of dtypes to result_type so scalars
+        # get handled properly.
+        # Note that result_type() safely gets the dtype from dask arrays without
+        # evaluating them.
+        out_type = dtypes.result_type(*arrays)
     else:
-        arrays = [
-            # https://github.com/pydata/xarray/issues/8402
-            # https://github.com/pydata/xarray/issues/7721
-            x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
-            for x in scalars_or_arrays
-        ]
-    # Pass arrays directly instead of dtypes to result_type so scalars
-    # get handled properly.
-    # Note that result_type() safely gets the dtype from dask arrays without
-    # evaluating them.
-    out_type = dtypes.result_type(*arrays)
+        # arrays = [
+        #     # https://github.com/pydata/xarray/issues/8402
+        #     # https://github.com/pydata/xarray/issues/7721
+        #     x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
+        #     for x in scalars_or_arrays
+        # ]
+        objs_with_dtype = [obj for obj in scalars_or_arrays if hasattr(obj, "dtype")]
+        if objs_with_dtype:
+            # Pass arrays directly instead of dtypes to result_type so scalars
+            # get handled properly.
+            # Note that result_type() safely gets the dtype from dask arrays without
+            # evaluating them.
+            out_type = dtypes.result_type(*objs_with_dtype)
+        else:
+            out_type = dtypes.result_type(*scalars_or_arrays)
+        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
     return [
-        astype(x, out_type, copy=False) if hasattr(x, "dtype") else x for x in arrays
+        astype(x, out_type, copy=False) for x in arrays
     ]
 
 

I just threw it together to see if it would work. I'm not sure it's accurate, but the fact that it's almost exactly like the existing solution, with the only difference being the out_type = changes, makes me feel this is going in a good direction.

Note I had to add the if objs_with_dtype: check because one test passes two python scalars, so there are no arrays to determine the result type from.

@keewis (Collaborator) commented Apr 22, 2024

How do we check this?

Not sure... but there are only so many builtin types that can be involved without requiring object dtype, so we could just enumerate all of them? As far as I can tell, that would be: bool, int, float, str, datetime/date, and timedelta

@djhoese (Contributor, Author) commented Apr 26, 2024

check that all python scalars are compatible with the result (otherwise might have to return object?)

How do we check this?

@keewis Do you have a test that I can add to verify any fix I attempt for this? What do you mean by a python scalar being compatible with the result?

@keewis (Collaborator) commented Apr 26, 2024

well, for example, what should happen for this:

a = xr.DataArray(np.array([1, 2, 3], dtype="int8"), dims="x")
xr.where(a % 2 == 1, a, 1.2)

according to the algorithm above, we have one array of dtype int8, so we'd have to check whether 1.2 (a float) is compatible with int8. It is not, so we should promote everything to float (the default would be float64, which might be a bit weird).

Something similar:

a = xr.DataArray(np.array(["2019-01-01", "2020-01-01"], dtype="datetime64[ns]"), dim="x")
xr.where(a.x % 2 == 1, a, datetime.datetime(2019, 6, 30))

in that case, the check should succeed, because we can convert a builtin datetime object to datetime64[ns].
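
Both cases can be checked with plain numpy under NEP 50 semantics (numpy>=2; a hedged illustration, not the PR's code):

import datetime
import numpy as np

np.result_type(np.dtype("int8"), 1.2)
# -> float64: the weak python float is incompatible with int8, so everything promotes
np.array(datetime.datetime(2019, 6, 30), dtype="datetime64[ns]")
# -> succeeds: builtin datetimes convert cleanly to datetime64[ns]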

@djhoese (Contributor, Author) commented Apr 28, 2024

I committed my (what I consider ugly) implementation of your original approach, @keewis. I'm still not sure I understand how to approach the scalar compatibility, so if someone has ideas, please leave suggestion comments, or commit directly if you have the permissions.

@keewis (Collaborator) commented Apr 28, 2024

this might be cleaner:

def asarray(data, xp=np, dtype=None):
    return data if is_duck_array(data) else xp.asarray(data, dtype=dtype)


def as_shared_dtype(scalars_or_arrays, xp=np):
    """Cast a arrays to a shared dtype using xarray's type promotion rules."""
    if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
        # as soon as extension arrays are involved we only use this:
        extension_array_types = [
            x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
        ]
        if len(extension_array_types) == len(scalars_or_arrays) and all(
            isinstance(x, type(extension_array_types[0])) for x in extension_array_types
        ):
            return scalars_or_arrays
        raise ValueError(
            f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
        )

    if (array_type_cupy := array_type("cupy")) and any(
        isinstance(x, array_type_cupy) for x in scalars_or_arrays
    ):
        import cupy as cp

        xp_ = cp
    else:
        xp_ = xp

    # split into python scalars and arrays / numpy scalars (i.e. into weakly and strongly dtyped)
    with_dtype = {}
    python_scalars = {}
    for index, elem in enumerate(scalars_or_arrays):
        append_to = with_dtype if hasattr(elem, "dtype") else python_scalars
        append_to[index] = elem

    if with_dtype:
        to_convert = with_dtype
    else:
        # can't avoid using the default dtypes if we only get weak dtypes
        to_convert = python_scalars
        python_scalars = {}

    arrays = {index: asarray(x, xp=xp_) for index, x in to_convert.items()}

    common_dtype = dtypes.result_type(*arrays.values())
    # TODO(keewis): check that all python scalars are compatible. If not, change the dtype or raise.

    # cast arrays
    cast = {index: astype(x, dtype=common_dtype, copy=False) for index, x in arrays.items()}
    # convert python scalars to arrays with a specific dtype
    converted = {index: asarray(x, xp=xp_, dtype=common_dtype) for index, x in python_scalars.items()}

    # merge both
    combined = cast | converted
    return [x for _, x in sorted(combined.items(), key=lambda x: x[0])]

This is still missing the dtype fallbacks, though.
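
Hypothetical behavior of that sketch with numpy inputs (assumed, not taken from the PR's tests):

import numpy as np

as_shared_dtype([np.array([1, 2], dtype="int8"), 1])
# -> [array([1, 2], dtype=int8), array(1, dtype=int8)]
# the python int adopts the arrays' dtype instead of forcing an upcast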

@keewis (Collaborator) commented Apr 28, 2024

I see now why the dtype fallbacks for scalars are tricky... we basically need to enumerate the casting rules and decide when to return a different dtype (like object). numpy has can_cast with the option to choose the strictness (so we could use "same_kind"), and it accepts python scalar types; the Array API does not allow that choice, and we also can't pass in python scalar types.

To start, here are the rules from the Array API:

  • complex dtypes are compatible with int, float, or complex
  • float dtypes are compatible with any int or float
  • int dtypes are compatible with int (but beware: python uses BigInt, so the value might exceed the maximum of the dtype)
  • the bool dtype is only compatible with bool

From numpy, we also have these (numpy casting is even more relaxed than this, but that behavior may also cause some confusing issues):

  • bool can be cast to int, so it is compatible with anything int is compatible with
  • str dtypes are only compatible with str. Anything else, like formatting and casting to other types, has to be done explicitly before calling as_shared_dtype.
  • datetime dtypes (precisions) are compatible with datetime.datetime, datetime.date, and pd.Timestamp
  • timedelta dtypes (precisions) are compatible with datetime.timedelta and pd.Timedelta. Casting to int is possible, but has to be done explicitly (i.e. we can ignore it here)
  • anything else results in a object dtype

Edit: it appears NEP 50 describes the changes in detail. I didn't see that before writing both the list above and implementing the changes, so I might have to change both.
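
One hedged way to encode these compatibility rules (a hypothetical helper, not the PR's implementation; pandas types are omitted for brevity):

import datetime
import numpy as np

# dtype.kind -> python scalar types that can be cast safely, per the lists above
COMPATIBLE_SCALARS = {
    "b": (bool,),
    "i": (bool, int),  # beware: python ints are unbounded, so values may still overflow
    "u": (bool, int),
    "f": (bool, int, float),
    "c": (bool, int, float, complex),
    "U": (str,),
    "M": (datetime.date,),       # datetime.datetime is a subclass of date
    "m": (datetime.timedelta,),
}

def scalar_compatible(dtype: np.dtype, scalar) -> bool:
    # anything not listed falls through to object dtype
    return isinstance(scalar, COMPATIBLE_SCALARS.get(dtype.kind, ()))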

@keewis (Collaborator) commented Apr 28, 2024

here's my shot at the scalar dtype verification. I'm pretty sure it can be cleaned up further (and we need more tests), but it does fix all the casting issues. Edit: note that this depends on the Array API fixes for numpy>=2.

What I don't like is that we're essentially hard-coding the dtype casting hierarchy, but I couldn't figure out a way to make it work without that.

@djhoese (Contributor, Author) commented May 12, 2024

FYI to everyone watching this: I'm switching to heavier paternity leave than I was already on starting this week. I think someone else should take this PR over, as I don't think I'll have time to finish it before the numpy 2 final release.

@shoyer (Member) commented May 15, 2024

In an ideal world, I think this would be written something like:

def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace_or_numpy(scalars_or_arrays)
    dtype = xp.result_type(*scalars_or_arrays)
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)

The main issues stopping this:

  1. cupy, pandas and old versions of numpy don't support the array API
  2. xp.result_type only supports arrays, not Python scalars

The first issue can be solved with compatibility code. I will raise the second issue on the array API tracker.
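
A minimal compatibility shim for issue 1 might look like this (hedged sketch; get_array_namespace_or_numpy is the hypothetical helper named above):

import numpy as np

def get_array_namespace_or_numpy(values):
    # prefer a real Array API namespace if any input advertises one
    for x in values:
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
    # inputs without the protocol (python scalars, pandas, older numpy)
    # fall back to numpy semantics; cupy would need its own branch
    return np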

@keewis (Collaborator) commented May 22, 2024

while (2) should eventually be resolved by an addition to the Array API, that won't help us resolve the dtype casting before the numpy 2.0 release.

As far as I understand it, to put the above code sample (with slight modifications due to (1)) into as_shared_dtype, we'd have to add compatibility code to dtypes.result_type, at least until we can require a version of the Array API that allows us to forward most of it to xp.result_type.

To do so, we'd have to find a way to split the input of dtypes.result_type into weakly dtyped and explicitly dtyped / dtypes (since dtypes.result_type and xp.result_type accept arrays / explicitly dtyped scalars or dtype objects). Then we can forward the latter to xp.result_type, and figure out what to do with the weakly dtyped data in an additional step.

However, while with numpy we can simply use isinstance(x, np.dtype) to find dtypes, this won't help us with other Array API-implementing libraries as the dtypes are generally opaque objects, and we also don't want to lose the ability to use where on dtype=object arrays. In other words, I can't find a way to separate weakly dtyped data from dtypes and explicitly dtyped data.

If there truly is no way to find dtypes in a general way, we'll have to do the split in as_shared_dtype, where we can guarantee that we don't get dtype objects:

def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace_or_numpy(scalars_or_arrays)

    explicitly_dtyped, weakly_dtyped = dtypes.extract_explicitly_dtyped(scalars_or_arrays)
    common_dtype = dtypes.result_type(*explicitly_dtyped)
    dtype = dtypes.adapt_common_dtype(common_dtype, weakly_dtyped)

    return tuple(xp.asarray(x, dtype) for x in scalars_or_arrays)

Another option would be to pass the weakly dtyped data as a keyword argument to dtypes.result_type, which in the future would allow us to pass both to xp.result_type if we know that a specific library supports python scalars.

Edit: actually, I'd probably go for the latter.
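
A sketch of that keyword-argument option (hypothetical signature, numpy-flavored, showing only the int -> float widening; complex, datetime, and object fallbacks are omitted):

import numpy as np

def result_type(*arrays_and_dtypes, weakly_dtyped=(), xp=np):
    # strongly dtyped inputs go straight to the namespace's result_type
    dtype = xp.result_type(*arrays_and_dtypes)
    # weakly dtyped python scalars adjust the result afterwards; once the
    # Array API accepts scalars, both could be forwarded to xp.result_type
    for scalar in weakly_dtyped:
        if isinstance(scalar, float) and not np.issubdtype(dtype, np.floating):
            dtype = xp.result_type(dtype, xp.asarray(scalar).dtype)
    return dtype

result_type(np.dtype("int8"), weakly_dtyped=[1.2])  # -> float64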

@keewis (Collaborator) commented May 23, 2024

In the most recent commits I've added a way to check if a dtype is compatible with the scalars, using the second option from #8946 (comment): split into weakly / strongly dtyped in as_shared_dtype, then pass both to result_type (but as separate arguments).

This appears to resolve the issues we had with dtype casting and numpy>=2, but there are a few other issues that pop up. For example, cupy doesn't have cupy.astype, and the failing pint tests seem to hint at an issue in my algorithm and in the existing non-pint test coverage (not sure, though).

I'll try to look into the above and into adding additional tests over the weekend.

In the meantime, I'd appreciate reviews for the general idea (cc in particular @shoyer, but also @dcherian).

@keewis (Collaborator) commented May 25, 2024

looks like the most recent commits fixed the remaining failing tests (turns out I forgot to apply our custom dtype casting rules when adjusting the dtype to fit the scalars), so all that's left is to fix mypy and write tests.

        return xp.dtype(object)
    if not arrays_and_dtypes:
        # no explicit dtypes, so we simply convert to 0-d arrays using default dtypes
        arrays_and_dtypes = [asarray(x, xp=xp) for x in weakly_dtyped]  # type: ignore
@keewis (Collaborator) commented May 29, 2024

with my limited knowledge on typing, ignoring is the only way I could find to get the code to pass. arrays_and_dtypes is of type list[np.typing.ArrayLike | np.typing.DTypeLike], while I'd like to assign something that is of type list[np.typing.ArrayLike] (i.e. no dtypes). How difficult would it be to get this to work without the ignore? Would we need to type hint duck_array_ops.asarray first?
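
One possible way to drop the ignore (a sketch assuming duck_array_ops.asarray gets type hints; since list is invariant, mypy rejects reassigning a list[ArrayLike] into a variable typed list[ArrayLike | DTypeLike], so binding to a fresh name avoids the error):

from __future__ import annotations

import numpy as np

def asarray(data, xp=np, dtype=None) -> np.ndarray:
    # hinted to always return an array, so comprehensions over it
    # produce list[np.ndarray] rather than the wider union
    return data if isinstance(data, np.ndarray) else xp.asarray(data, dtype=dtype)

def result_type_fragment(arrays_and_dtypes, weakly_dtyped, xp=np):
    if not arrays_and_dtypes:
        # a fresh, narrower-typed name instead of reassigning arrays_and_dtypes
        converted: list[np.ndarray] = [asarray(x, xp=xp) for x in weakly_dtyped]
        return xp.result_type(*converted)
    return xp.result_type(*arrays_and_dtypes)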

@shoyer (Member) commented May 31, 2024

Hey @keewis, thanks for continuing to dive into this!

Given that there seems to be a consensus to add support for weak/strong dtypes to the Array API's result_type in the future, let's try to build this code around the assumption that that will work in the future.

I think this could look something like the following:

def _future_array_api_result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    xp,
) -> np.dtype:
    ...

def result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    xp=None,
) -> np.dtype:
    from xarray.core.duck_array_ops import get_array_namespace

    if xp is None:
        xp = get_array_namespace(arrays_and_dtypes)

    types = {xp.result_type(t) for t in arrays_and_dtypes}
    if all(isinstance(t, np.dtype) for t in types):  # NOTE: slightly more conservative than the existing code
        # only check if there's numpy dtypes – the array API does not
        # define the types we're checking for
        for left, right in PROMOTE_TO_OBJECT:
            if any(np.issubdtype(t, left) for t in types) and any(
                np.issubdtype(t, right) for t in types
            ):
                return np.dtype(object)

    if xp is np:
        return np.result_type(*arrays_and_dtypes)  # fast path

    # TODO: replace with xp.result_type when the array API always supports weak dtypes:
    # https://github.com/data-apis/array-api/issues/805
    return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
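
For reference, a hedged guess at what _future_array_api_result_type might do until the spec change lands (assumes the namespace implements isdtype from the 2022.12 Array API; only float/complex widening is shown, and exact type checks keep numpy scalars, which subclass python float, on the strong side):

def _future_array_api_result_type(*arrays_and_dtypes, xp):
    # split weak python scalars from strongly dtyped arrays / dtypes
    weak = [x for x in arrays_and_dtypes if type(x) in (bool, int, float, complex)]
    strong = [x for x in arrays_and_dtypes if type(x) not in (bool, int, float, complex)]
    if not strong:
        # only python scalars: fall back to the default dtypes
        strong, weak = [xp.asarray(x) for x in weak], []
    dtype = xp.result_type(*strong)
    # widen the result where a python scalar cannot be represented
    for scalar in weak:
        if isinstance(scalar, complex) and not xp.isdtype(dtype, "complex floating"):
            dtype = xp.result_type(dtype, xp.asarray(scalar).dtype)
        elif isinstance(scalar, float) and not xp.isdtype(dtype, ("real floating", "complex floating")):
            dtype = xp.result_type(dtype, xp.asarray(scalar).dtype)
    return dtype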

Labels
run-upstream Run upstream CI
Development

Successfully merging this pull request may close these issues.

where dtype upcast with numpy 2
4 participants