ENH: linalg: support array API for standard extension functions #19260
I think the regular function modifications are OK to me. But I'm not sure if we really need all the test changes. In certain places they don't reflect the goals of the tests, and I'm not sure I understand why we test non-NumPy xp namespaces in the tests. It should be pretty safe to test SciPy with default NumPy arrays without any xp_assert_close calls or modified tol parameters.
This looks pretty good to me overall. I'd like to see the diff shrink as much as possible though, both in the implementations and tests. That's in general a sign that the code is in good shape, and it makes things easier to review and understand later on.
> I think the regular function modifications are OK to me. But I'm not sure if we really need all the test changes. In certain places they don't reflect the goals of the tests, and I'm not sure I understand why we test non-NumPy xp namespaces in the tests. It should be pretty safe to test SciPy with default NumPy arrays without any xp_assert_close calls or modified tol parameters.
I think it is quite useful to test non-numpy arrays; without testing it's almost certainly going to be broken. I think of these as testing optional dependencies - just like we have tests with `mpmath`, `scikit-umfpack` and a whole bunch of other optional runtime dependencies.
Sometimes this requires extra code; however, it also tends to uncover bugs and non-standard code constructs that (when refactored) improve the code itself. Two examples here:
- The changes here from integer to floating point arrays for testing make sense. Functions like `solve` are inherently floating point-only. We also need to test that integers (and lists, and other array-likes) are still converted correctly and we don't break backwards compat. However, that can be a single small test. That many tests use integers is a matter of previous authors taking a shortcut because it didn't matter much, rather than all those tests using integers by design.
- The stricter dtype checks led me to spot a bug quickly when Lucas asked me about `result_type`:
>>> import numpy as np
>>> from scipy import linalg
>>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
>>> b = np.array([2, 4, -1])
>>> x = linalg.solve(a, b)
>>> a.dtype, b.dtype
(dtype('int64'), dtype('int64'))
>>> x.dtype
dtype('float64')
>>> # so output dtype should be float64 for integer input, but:
>>> linalg.solve(a, np.empty((3,0), dtype=np.int64)).dtype
dtype('int64')
We've seen in many places in `cluster` and `fft` as well that our tests don't check for expected dtypes, and we often have inconsistent return dtypes as a result (arguably all bugs).
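To illustrate the point about unchecked dtypes, here is a minimal sketch using only NumPy. The helper name `assert_close_strict` is hypothetical (it is not SciPy's `xp_assert_close`); it just shows how a value-only comparison lets a dtype mismatch slip through while a stricter check catches it.

```python
import numpy as np


def assert_close_strict(actual, desired, rtol=1e-7, atol=0.0):
    """Hypothetical helper: compare dtypes first, then values."""
    actual = np.asarray(actual)
    desired = np.asarray(desired)
    assert actual.dtype == desired.dtype, (
        f"dtype mismatch: {actual.dtype} != {desired.dtype}"
    )
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)


x64 = np.array([1.0, 2.0, 3.0], dtype=np.float64)
x32 = x64.astype(np.float32)

# A value-only check passes even though the dtypes differ.
np.testing.assert_allclose(x32, x64)

# The strict check flags the mismatch instead.
try:
    assert_close_strict(x32, x64)
except AssertionError as exc:
    dtype_mismatch_caught = "dtype mismatch" in str(exc)
else:
    dtype_mismatch_caught = False
```

A test written with the value-only assertion would not have noticed an `int64` result where `float64` was expected, as long as the values happened to agree.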
So there is extra value from testing with other libraries: finding bugs and improving test coverage.
Testing is always nice indeed, but the question is what to do when it is broken. I think none of us want to go chasing around the PyTorch or CuPy repos fixing things that aren't really ours to fix, just to get our tests back to green.
I think the same as when something breaks in NumPy, Cython, pytest, Sphinx or wherever else: we file an issue and skip the test or put a temporary upper bound. We're using pretty core/standard functions here, so I am not too worried about seeing too many regressions once things work. That would be really surprising. And in terms of debugging or even contributing upstream, I'd much rather work with CuPy or PyTorch than with things like pytest/sphinx/mpmath. Also, the CuPy and PyTorch teams (and Dask and JAX) have invested large amounts of effort in NumPy and SciPy compatibility, so I'm pretty sure they'd appreciate bug reports and be willing to address them.
The tests still need a good bit of work. But dare I say this is getting pretty close.
xp_assert_close(u.T @ u, xp.eye(3), atol=1e-6)
xp_assert_close(vh.T @ vh, xp.eye(3), atol=1e-6)
sigma = xp.zeros((u.shape[0], vh.shape[0]), dtype=s.dtype)
for i in range(s.shape[0]):
    sigma[i, i] = s[i]
assert_array_almost_equal(u @ sigma @ vh, a)
xp_assert_close(u @ sigma @ vh, a, rtol=1e-6)
The tolerances around here could do with a look; I'm not sure why I wrote a mixture of `atol` and `rtol`.
Ah, I remember: it's because the default for `np.testing.assert_array_almost_equal` is roughly equivalent to `rtol=0, atol=1.5e-6`.
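As a rough check of that equivalence, the sketch below compares the two assertions on the same inputs using plain NumPy. The wrapper names (`passes`, `almost_equal`, `allclose_abs`) are made up for illustration; the only assumption is the documented default `decimal=6` of `assert_array_almost_equal`, which corresponds to an absolute threshold of `1.5e-6`.

```python
import numpy as np

desired = np.array([1.0, 1000.0])
within = desired + 1.0e-6   # absolute error below 1.5e-6
outside = desired + 2.0e-6  # absolute error above 1.5e-6


def passes(check, actual):
    """Return True if the given assertion accepts `actual` vs `desired`."""
    try:
        check(actual, desired)
    except AssertionError:
        return False
    return True


def almost_equal(actual, expected):
    # Default decimal=6: requires abs(expected - actual) < 1.5e-6
    np.testing.assert_array_almost_equal(actual, expected)


def allclose_abs(actual, expected):
    # Purely absolute tolerance, no relative component
    np.testing.assert_allclose(actual, expected, rtol=0, atol=1.5e-6)


results = {
    "almost_equal_within": passes(almost_equal, within),
    "allclose_within": passes(allclose_abs, within),
    "almost_equal_outside": passes(almost_equal, outside),
    "allclose_outside": passes(allclose_abs, outside),
}
```

Both checks accept the `within` array and reject the `outside` one, regardless of the magnitude of the entries, which is why a direct port of `assert_array_almost_equal` ends up with `atol` rather than `rtol`.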
CI should be green. I think this is almost ready. A few questions remain:
EDIT: spoke too soon on CI but looks like just a
EDIT 2: finally green :)
For context, this work was started as part of my internship at Quansight Labs, which ran until the end of September 2023.
Reference issue
Towards gh-19068 and gh-18867.
Please see gh-19068 for context.
What does this implement/fix?
Support is added for the functions in the array API standard `linalg` extension. This allows users to input arrays from any compatible array library.
Tests are modified to allow testing with `numpy.array_api`, `cupy` and `torch`. Some new tests are added for exceptions for unsupported parameters.
Additional information
I was not able to convert every relevant test since lots of NumPy-specific things are used. We may want to find ways to convert some of these, or write new tests to serve the same purpose.
`TestSVD` is a little strange due to the way the `lapack_driver` parameter is tested. I have tried to apply a minimal refactor here, but a more substantial refactor may result in something more readable. It is a bit of a misnomer that all of our array API compatible tests are under `TestSVD_GESDD`, just because `gesdd` is the default value.
Lots of tests are failing for PyTorch CUDA, but hopefully we just need to wait for pytorch/pytorch#106773 to come through.