ENH: array types: add JAX support #20085
Conversation
(The reason I decided to comment over on the DLPack issue is that I recall a conversation about how portability could be increased if we replace occurrences of …)
Thanks for working on this, Lucas. JAX support will be very nice. And a third library with CPU support (after NumPy and PyTorch) will also be good for testing how generic our array API standard support actually is.

Okay, related to the read-only question, it looks like this is the problem you were seeing: Cython doesn't accept read-only arrays when the signature is a regular memoryview. There's a long discussion about this topic in scikit-learn/scikit-learn#10624. Now that we have Cython 3 though, the fix is simple:

```diff
diff --git a/scipy/cluster/_hierarchy.pyx b/scipy/cluster/_hierarchy.pyx
index 814051df2..c59b3de6a 100644
--- a/scipy/cluster/_hierarchy.pyx
+++ b/scipy/cluster/_hierarchy.pyx
@@ -1012,7 +1012,7 @@ def nn_chain(double[:] dists, int n, int method):
     return Z_arr
 
-def mst_single_linkage(double[:] dists, int n):
+def mst_single_linkage(const double[:] dists, int n):
     """Perform hierarchy clustering using MST algorithm for single linkage.
 
     Parameters
```

This makes the tests pass (at least for this issue, I tried with the …)
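For context, the read-only failure mode is easy to reproduce from plain Python; a minimal sketch (my illustration — the `setflags` call stands in for the read-only buffer a JAX array exports via DLPack):

```python
import numpy as np

x = np.arange(5.0)
x.setflags(write=False)          # simulate the read-only buffer a JAX array exports

print(memoryview(x).readonly)    # True
# A Cython signature `double[:] dists` rejects such an array with
# "ValueError: buffer source array is read-only", while
# `const double[:] dists` accepts it without a copy.
```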
thanks! I've removed the copies and added some …
One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under `jax.jit`, because it requires converting array objects to host-side buffers, and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or `pure_callback`) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?
Do you mean for testing purposes, or for library code? For the latter: we should never do device transfers like GPU->host memory under the hood. The array API standard design was careful to not include that. It wasn't even possible at all until very recently, when a way was added to do it with DLPack (for testing purposes). If you mean "convert to …"

JIT compilers were explicitly considered, and nothing in the standard should be JIT-unfriendly, except for the few functions clearly marked as having data-dependent output shapes and the few dunder methods that are also problematic for lazy arrays.
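For a concrete instance of a data-dependent output shape (my sketch; `jnp.unique` is JAX's analogue of the standard's `unique_values`):

```python
import jax
import jax.numpy as jnp

x = jnp.asarray([1, 1, 2, 3, 3])
print(jnp.unique(x))         # fine eagerly: [1 2 3]

try:
    jax.jit(jnp.unique)(x)   # result size depends on the data, so tracing fails
except Exception as e:
    # ConcretizationTypeError: JAX asks for a static `size=` argument instead
    print(type(e).__name__)
```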
If this is what you meant, x-ref the 'Dispatching Mechanism' section of gh-18286.
I mean for actual user-level code: most of the work here will be more-or-less useless for JAX users, because array conversions via DLPack cannot be done under JIT without some sort of callback mechanism.
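To illustrate the constraint (a sketch of my own): any host-side conversion of a traced JAX array fails, because under `jax.jit` the function sees an abstract tracer rather than a concrete buffer:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f(x):
    return np.asarray(x)    # the host-side conversion SciPy would do internally

try:
    f(jnp.ones(3))
except jax.errors.TracerArrayConversionError as e:
    print(type(e).__name__)  # tracing fails: no concrete data exists yet
```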
Okay, I had a look at https://jax.readthedocs.io/en/latest/tutorials/external-callbacks.html and understand what you mean now. It looks fairly straightforward to support (disclaimer: I haven't tried it yet). It'd be taking this current code pattern:

```python
# inside some Python-level scipy function with array API standard support:
x = np.asarray(x)
result = call_some_compiled_code(x)
result = xp.asarray(result)  # back to original array type
```

and replacing it with something like (untested):

```python
def call_compiled_code_helper(x, xp):  # needs *args, **kwargs too
    if is_jax(x):
        result_shape_dtypes = ...  # TODO: figure out how to construct the needed PyTree here
        result = jax.pure_callback(call_some_compiled_code, result_shape_dtypes, x)
    else:
        x = np.asarray(x)
        result = call_some_compiled_code(x)
        result = xp.asarray(result)
```

Use of a utility function like … It's interesting that …
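For reference, a self-contained version of that pattern which does run under `jit` — a sketch of my own, with `scipy.special.erf` standing in for "some compiled code":

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import erf  # compiled host-side code, as a stand-in

def erf_via_callback(x):
    # pure_callback needs the output shape/dtype up front; here it matches x.
    result_shape_dtypes = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(lambda a: erf(np.asarray(a)), result_shape_dtypes, x)

x = jnp.linspace(-1.0, 1.0, 5)
print(jax.jit(erf_via_callback)(x))  # works: the host call happens at runtime
```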
Yeah, something like that is what I had in mind, though …
It is (depending on your definition of "issue"), because there's no magic bullet that will do something like take a native function implemented in C/Fortran/Cython inside SciPy and make it run on GPU. The basic state of things is: …

In a generic library like SciPy it's almost impossible to support custom kernels on device. Our choices for arrays that don't live in host memory are: …
I gave adding Dask another shot just now, but unfortunately things are missing from …
I'd suggest keeping this PR focused on JAX and getting that merged first. That makes it easier to see (also in the future) what had to be done only for JAX. And if we're going to experiment a bit with …
Looks like we're basically there. I'll do some testing with PyTorch and CuPy tomorrow to check that we didn't silently break anything there - and then it's probably good to merge.
> I'll do some testing with PyTorch and CuPy tomorrow to check that we didn't silently break anything there

Short summary seems to be that most of the failures are already present on `main` rather than introduced here, from my initial checking on GPU.
For `SCIPY_DEVICE=cuda python dev.py test -j 32 -b all` with an NVIDIA GPU:

- latest `main`: 66 failed, 51999 passed, 11312 skipped, 157 xfailed, 13 xpassed in 55.91s
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_fft[torch] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-33 (worker)
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case15-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case16-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case17-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case18-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case19-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_basic[p1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case20-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case21-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_ifft[cupy] - AssertionError:
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case22-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case23-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape0-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case24-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case25-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case26-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape0-cupy] - AssertionError:
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape1-cupy] - AssertionError:
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case27-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape2-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case28-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape2-cupy] - AssertionError:
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case29-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape3-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case30-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape3-cupy] - AssertionError:
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case31-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_convergence[torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case32-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case33-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case34-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case35-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case0-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case36-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case2-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case37-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case3-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case4-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case38-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_rfft[cupy] - AssertionError:
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case5-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case6-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case39-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case7-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case40-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case8-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case41-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case42-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case9-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case43-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case44-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case10-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_dtype[float16-0.622-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Half for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case11-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case12-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case13-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case14-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_dtype[float16-root1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Half for the destination and Float for the source.
FAILED scipy/special/tests/test_support_alternative_backends.py::test_support_alternative_backends[f_name_n_args15-array_api_strict] - AssertionError:
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_dtype[float64-0.622-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_dtype[float64-root1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/special/tests/test_support_alternative_backends.py::test_support_alternative_backends[f_name_n_args15-torch] - AssertionError: Scalars are not close!
FAILED scipy/stats/tests/test_stats.py::TestDescribe::test_describe_numbers[torch] - AssertionError: Tensor-likes are not equal!
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_ihfft[cupy] - AssertionError:
- here: 67 failed, 51998 passed, 11316 skipped, 157 xfailed, 13 xpassed in 58.12s
Mostly looks like the `chandrupatla` stuff discussed above on torch + GPU, which is already on `main`.
With the patch below:

```diff
--- a/scipy/optimize/_tstutils.py
+++ b/scipy/optimize/_tstutils.py
@@ -44,6 +44,7 @@ from random import random
 
 import numpy as np
 from scipy.optimize import _zeros_py as cc
+from scipy._lib._array_api import array_namespace
 
 # "description" refers to the original functions
 description = """
@@ -887,18 +888,21 @@ fun6.root = 0
 
 def fun7(x):
-    return 0 if abs(x) < 3.8e-4 else x*np.exp(-x**(-2))
+    xp = array_namespace(x)
+    return 0 if abs(x) < 3.8e-4 else x*xp.exp(-x**(-2))
 fun7.root = 0
 
 def fun8(x):
+    xp = array_namespace(x)
     xi = 0.61489
-    return -(3062*(1-xi)*np.exp(-x))/(xi + (1-xi)*np.exp(-x)) - 1013 + 1628/x
+    return -(3062*(1-xi)*xp.exp(-x))/(xi + (1-xi)*xp.exp(-x)) - 1013 + 1628/x
 fun8.root = 1.0375360332870405
 
 def fun9(x):
-    return np.exp(x) - 2 - 0.01/x**2 + .000002/x**3
+    xp = array_namespace(x)
+    return xp.exp(x) - 2 - 0.01/x**2 + .000002/x**3
 fun9.root = 0.7032048403631358
 
 # Each "chandropatla" test case has
```
most of the torch GPU failures become `RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source`.
And with this additional patch:

```diff
--- a/scipy/optimize/_chandrupatla.py
+++ b/scipy/optimize/_chandrupatla.py
@@ -187,7 +187,7 @@ def _chandrupatla(func, a, b, *, args=(), xatol=None, xrtol=None,
     # If the bracket is no longer valid, report failure (unless a function
     # tolerance is met, as detected above).
     i = (xp_sign(work.f1) == xp_sign(work.f2)) & ~stop
-    NaN = xp.asarray(xp.nan)
+    NaN = xp.asarray(xp.nan, dtype=work.xmin.dtype)
     work.xmin[i], work.fmin[i], work.status[i] = NaN, NaN, eim._ESIGNERR
     stop[i] = True
```

it drops to 15 failures. There are some typos in some of the array API-converted tests, it seems. More np->xp conversion fixes another GPU failure:
```diff
--- a/scipy/optimize/tests/test_chandrupatla.py
+++ b/scipy/optimize/tests/test_chandrupatla.py
@@ -656,11 +656,11 @@ class TestChandrupatla(TestScalarRootFinders):
         x1, x2 = bracket
         f0 = xp_minimum(xp.abs(self.f(x1, *args)), xp.abs(self.f(x2, *args)))
         res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
-        xp_assert_less(np.abs(res1.fun), 1e-3*f0)
+        xp_assert_less(xp.abs(res1.fun), 1e-3*f0)
         kwargs['frtol'] = 1e-6
         res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
-        xp_assert_less(np.abs(res2.fun), 1e-6*f0)
-        xp_assert_less(np.abs(res2.fun), np.abs(res1.fun))
+        xp_assert_less(xp.abs(res2.fun), 1e-6*f0)
+        xp_assert_less(xp.abs(res2.fun), xp.abs(res1.fun))
```
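For context on all the `Index put requires the source and destination dtypes match` failures above, a minimal reproduction of the failure mode and the fix pattern from the `_chandrupatla.py` patch — my sketch, and behavior may vary across PyTorch versions/devices:

```python
import torch

x = torch.zeros(3, dtype=torch.float64)
mask = torch.tensor([True, False, True])
try:
    x[mask] = torch.tensor(float("nan"))   # float32 source, float64 destination
except RuntimeError as e:
    print(e)  # Index put requires the source and destination dtypes match ...

x[mask] = torch.tensor(float("nan"), dtype=x.dtype)  # matching dtype: works
```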
I ran out of steam there, but basically this branch doesn't seem to introduce much that isn't already broken on `main`, in my hands.
@tylerjereddy would you be willing to open a PR with these patches that I can merge?
ok
I can make any remaining fixes in that PR if need be. It turned out that we forgot to add …
Addresses some of my points at: #20085 (review) and seems to fix about 55 GPU-based array API test failures

Co-authored-by: Matt Haberland <mhaberla@calpoly.edu>
```
@@ -207,7 +205,6 @@ def test_mlab_linkage_conversion_empty(self, xp):
        xp_assert_equal(from_mlab_linkage(X), X)
        xp_assert_equal(to_mlab_linkage(X), X)

    @skip_xp_backends(cpu_only=True)
```
`from_mlab_linkage` converts with `np.asarray`, so I'll put these back.
@lucascolley FYI some of the … I found that having an environment with both JAX and CuPy and testing with …
Nice. Moving forward I will probably have one env with JAX + CuPy and another with PyTorch + CuPy + array-api-strict, and test with both. Things will be easier once I'm back with a GPU.
The one Windows failure is unrelated: …
Okay, time to give it a go. Thanks a lot @lucascolley and all reviewers!

Follow-up steps include:

- deal with item/slice assignment, reducing/removing skips related to that
- look at a callback mechanism to make `jax.jit` work
- check how things look on TPU (e.g. in a Kaggle notebook, see discussion higher up)

I plan to have a look at 1 and 2 tonight.

Unrelated to JAX follow-ups:

- a few more CuPy test failures to deal with in `main`
- deal with other test failures related to `nan_policy`
thanks Ralf and all reviewers for all of the help here! I plan to have a look at Dask in a few months' time, but anyone else, feel free to tackle it if you get to it before me. A reminder that gh-19900 looks ready to me and should help eliminate some of the GPU failures Tyler was seeing. But no rush if it looks like more work is needed. FYI @izaid, I'm ~1 week out from finals now, so I'll not be working on any PRs for a while. See you on the other side!
What are the CuPy and `nan_policy` failures?
Good luck with your finals, Lucas!

CuPy failures are taken care of in gh-19900. They were: …
This makes things work with JAX, at a slight readability cost. Follow-up to scipy#20085.
I worked some more on this, adding a helper for in-place updates (code below). Using it, some good news and some less good: I could get the basic update patterns working.

The less good news is that the boolean-mask case doesn't work under `jit`:

```python
>>> import jax.numpy as jnp
>>> import jax
>>>
>>> def func(x, idx, value):
...     return x.at[idx].set(value)
...
>>> func_jit = jax.jit(func)
>>>
>>> x = jnp.arange(5)
>>> idx = x < 3
>>>
>>> func(x, idx, 99)
Array([99, 99, 99, 3, 4], dtype=int32)
>>> func_jit(x, idx, 99)
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
```

The explanation under "Boolean indexing into JAX arrays" at https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError isn't quite satisfactory. There are no dynamic shapes here, so it could work just fine. If the answer is to always use `where` for boolean masks, the helper looks like:

```python
def at_set(
    x: Array,
    idx: Array | int | slice,
    val: Array | int | float | complex,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """In-place update. Use only if no views are involved."""
    xp = array_namespace(x) if xp is None else xp
    if is_jax(xp):
        if xp.isdtype(idx.dtype, 'bool'):
            x = xp.where(idx, val, x)
        else:
            x = x.at[idx].set(val)
    else:
        x[idx] = val
    return x
```

which is slower. The in-place multiply variant needs the same treatment:

```python
if is_jax(xp):
    if hasattr(idx, 'dtype') and xp.isdtype(idx.dtype, 'bool'):
        x = xp.where(idx, x * val, x)
    else:
        x = x.at[idx].multiply(val)
else:
    x[idx] *= val
```

I'll look at it some more - this may be better suited for data-apis/array-api#609.
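As a standalone check of the `where` workaround used in the helper above (my sketch):

```python
import jax
import jax.numpy as jnp

@jax.jit
def set_where(x, mask, val):
    # x.at[mask].set(val) would raise NonConcreteBooleanIndexError here;
    # where() avoids boolean indexing entirely, at the cost of always
    # materializing a full-size result.
    return jnp.where(mask, val, x)

x = jnp.arange(5)
print(set_where(x, x < 3, 99))  # [99 99 99  3  4]
```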
I think it would be worth adding something that works for now, even if it's not great. It would avoid all the test skips and make it more obvious what capabilities we need. Once something better comes along, it will be easy to replace. It's probably better than using …
Yeah maybe - I don't want to go too fast though, and add a bunch of code we may regret. Looks like the new version (I edited my comment and pushed a new commit) works though, and is still very fast with JAX.

Let's make sure not to do things like that. Using …
@rgommers FYI I managed to implement the scalar boolean scatter in JAX, and it will be available in the next release. Turns out we had all the necessary logic there already – I just needed to put it together! google/jax#21305
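If I read that change right, the earlier failing example should then work under `jit` as long as the scattered value is a scalar — untested against the new release, so treat this as an assumption:

```python
import jax
import jax.numpy as jnp

@jax.jit
def func(x, idx, value):
    return x.at[idx].set(value)

x = jnp.arange(5)
# With google/jax#21305: Array([99, 99, 99, 3, 4], dtype=int32);
# on JAX <= 0.4.28 this still raises NonConcreteBooleanIndexError.
print(func(x, x < 3, 99))
```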
Great! Thanks @jakevdp. Looks like a small patch that I can try out pretty easily on top of JAX 0.4.28 - will give it a go later this week. (Note to self, since comments are hard to find in this PR: the relevant comment here is #20085 (comment).)
Reference issue

Towards gh-18867

What does this implement/fix?

First steps on JAX support. To-do: …

Additional information

Can do the same for `dask.array` once the problems are fixed over at data-apis/array-api-compat#89.
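For completeness, a sketch of what this enables end to end — assuming the `SCIPY_ARRAY_API` opt-in and an array-API-converted function such as `scipy.fft.fft`:

```python
import os
os.environ["SCIPY_ARRAY_API"] = "1"  # opt-in; must be set before importing scipy

import jax.numpy as jnp
import scipy.fft

x = jnp.ones(8)
y = scipy.fft.fft(x)   # dispatches through the array API; returns a JAX array
print(type(y))
```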