ENH: linalg: support array API for standard extension functions #19260

Open
wants to merge 42 commits into
base: main

lucascolley
Member

@lucascolley lucascolley commented Sep 19, 2023

For context, this work was started as part of my internship at Quansight Labs, which ran until the end of September 2023.

Reference issue

Towards gh-19068 and gh-18867.
Please see gh-19068 for context.

What does this implement/fix?

Support is added for the functions in the array API standard linalg extension. This allows users to input arrays from any compatible array library.

Tests are modified to allow testing with numpy.array_api, cupy and torch. Some new tests are added for exceptions for unsupported parameters.
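As a rough illustration of what "any compatible array library" means in practice, here is a minimal sketch of namespace dispatch. The helper `get_namespace` and the function `frobenius_norm` are hypothetical names made up for this example; they are not SciPy's internal helpers, and the fallback to NumPy for objects without `__array_namespace__` is an assumption of the sketch.

```python
import numpy as np


def get_namespace(x):
    """Return the array API namespace of x, falling back to NumPy.

    Hypothetical helper for illustration only.
    """
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def frobenius_norm(x):
    """Frobenius norm written only with functions from x's own namespace."""
    xp = get_namespace(x)
    return xp.sqrt(xp.sum(x * x))


a = np.asarray([[3.0, 4.0], [0.0, 0.0]])
result = frobenius_norm(a)
```

Because the function only calls `xp.sqrt` and `xp.sum`, the same body should work for any array type whose namespace provides those two functions.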

Additional information

I was not able to convert every relevant test, since many of them rely on NumPy-specific features. We may want to find ways to convert some of these, or write new tests to serve the same purpose.

TestSVD is a little strange due to the way the lapack_driver parameter is tested. I have tried to apply a minimal refactor here, but a more substantial refactor may result in something more readable. It is a bit of a misnomer that all of our array API compatible tests are under TestSVD_GESDD, just because gesdd is the default value.

Lots of tests are failing for PyTorch CUDA, but hopefully we just need to wait for pytorch/pytorch#106773 to come through.

@lucascolley lucascolley marked this pull request as ready for review September 19, 2023 13:17
@lucascolley lucascolley changed the title WIP, ENH: linalg: support array API for standard extension functions ENH: linalg: support array API for standard extension functions Sep 19, 2023
@j-bowhay j-bowhay added enhancement A new feature or improvement scipy.linalg array types Items related to array API support and input array validation (see gh-18286) labels Sep 19, 2023
Member

@ilayn ilayn left a comment

The regular function modifications look OK to me, but I'm not sure we really need all the test changes. In places they don't reflect the goals of the tests, and I don't understand why we test non-NumPy xp namespaces in the tests. It should be pretty safe to test SciPy with default NumPy arrays, without any xp_assert_close calls or modified tol parameters.

Member

@rgommers rgommers left a comment

This looks pretty good to me overall. I'd like to see the diff shrink as much as possible though, both in the implementations and tests. That's in general a sign that the code is in good shape, and it makes things easier to review and understand later on.

The regular function modifications look OK to me, but I'm not sure we really need all the test changes. In places they don't reflect the goals of the tests, and I don't understand why we test non-NumPy xp namespaces in the tests. It should be pretty safe to test SciPy with default NumPy arrays, without any xp_assert_close calls or modified tol parameters.

I think it is quite useful to test non-numpy arrays; without testing it's almost certainly going to be broken. I think of these as testing optional dependencies - just like we have tests with mpmath, scikit-umfpack and a whole bunch of other optional runtime dependencies.

Sometimes this requires extra code, however it also tends to uncover bugs and non-standard code constructs that (when refactored) improve the code itself. Two examples here:

  1. The changes here from integer to floating point arrays for testing make sense. Functions like solve are inherently floating point-only. We also need to test that integers (and lists, and other array-likes) are still converted correctly and that we don't break backwards compatibility. However, that can be a single small test. That many tests use integers is a matter of previous authors taking a shortcut because it didn't matter much, rather than all those tests using integers by design.
  2. The stricter dtype checks led me to spot a bug quickly when Lucas asked me about result_type:
>>> import numpy as np
>>> from scipy import linalg
>>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
>>> b = np.array([2, 4, -1])
>>> x = linalg.solve(a, b)
>>> a.dtype, b.dtype
(dtype('int64'), dtype('int64'))
>>> x.dtype
dtype('float64')

>>> # so output dtype should be float64 for integer input, but:
>>> linalg.solve(a, np.empty((3,0), dtype=np.int64)).dtype
dtype('int64')

We've seen in many places in cluster and fft as well that our tests don't check for expected dtypes, and we often have inconsistent return dtypes as a result (arguably all bugs).
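To make the point about missing dtype checks concrete, here is a small sketch of a stricter comparison. `assert_close_strict` is a hypothetical helper written for this comment, not SciPy's `xp_assert_close`; it only shows the idea of checking dtype and shape before values.

```python
import numpy as np


def assert_close_strict(actual, desired, rtol=1e-7, atol=0.0):
    """Compare dtype and shape before comparing values (hypothetical helper)."""
    actual = np.asarray(actual)
    desired = np.asarray(desired)
    assert actual.dtype == desired.dtype, (
        f"dtype mismatch: {actual.dtype} != {desired.dtype}"
    )
    assert actual.shape == desired.shape, (
        f"shape mismatch: {actual.shape} != {desired.shape}"
    )
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)


expected = np.array([1.0, 2.0, 3.0], dtype=np.float64)
wrong_dtype = np.array([1, 2, 3], dtype=np.int64)

# A value-only comparison passes even though the dtype is wrong ...
np.testing.assert_allclose(wrong_dtype, expected)

# ... while the strict comparison reports the mismatch.
try:
    assert_close_strict(wrong_dtype, expected)
    dtype_bug_caught = False
except AssertionError:
    dtype_bug_caught = True
```

A value-only assertion cannot notice an `int64` result where `float64` was expected, which is how inconsistent return dtypes can go unnoticed.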

So there is extra value from testing with other libraries: finding bugs and improving test coverage.
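The "single small test" for integer and list inputs mentioned in point 1 could look roughly like the sketch below. To keep it dependent on NumPy only, `np.linalg.solve` stands in for `scipy.linalg.solve`; that substitution is an assumption of this example, and the real test would call the SciPy function.

```python
import numpy as np

# np.linalg.solve stands in for scipy.linalg.solve (assumption of this sketch).
solve = np.linalg.solve

a_list = [[3, 2, 0], [1, -1, 0], [0, 5, 1]]
b_list = [2, 4, -1]

inputs = [
    (a_list, b_list),                      # plain Python lists
    (np.array(a_list), np.array(b_list)),  # integer arrays
]

for a, b in inputs:
    x = solve(a, b)
    # Integer and list inputs should be converted to a floating point result.
    assert x.dtype == np.float64
    np.testing.assert_allclose(x, [2.0, -2.0, 9.0], rtol=1e-10)
```

With conversion covered by one test like this, the remaining tests can use floating point arrays throughout.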

@ilayn
Member

ilayn commented Sep 27, 2023

I think it is quite useful to test non-numpy arrays; without testing it's almost certainly going to be broken. I think of these as testing optional dependencies - just like we have tests with mpmath, scikit-umfpack and a whole bunch of other optional runtime dependencies.

Testing is always nice indeed, but the question is what to do when it is broken. I think none of us want to go chasing around the PyTorch or CuPy repos to fix things that are not really ours to fix, just to get our tests back to green.

@rgommers
Member

Testing is always nice indeed, but the question is what to do when it is broken. I think none of us want to go chasing around the PyTorch or CuPy repos to fix things that are not really ours to fix, just to get our tests back to green.

I think it's the same as when something breaks in NumPy, Cython, pytest, Sphinx or wherever else: we file an issue and skip the test or put a temporary upper bound. We're using pretty core/standard functions here, so I am not too worried about seeing many regressions once things work. That would be really surprising. And in terms of debugging or even contributing upstream, I'd much rather work with CuPy or PyTorch than with things like pytest/sphinx/mpmath.

Also, the CuPy and PyTorch teams (and Dask and JAX) have invested large amounts of effort in NumPy and SciPy compatibility, so I'm pretty sure they'd appreciate bug reports and be willing to address them.

@lucascolley lucascolley marked this pull request as draft March 26, 2024 18:30
@lucascolley
Member Author

The tests still need a good bit of work. But dare I say this is getting pretty close.

Comment on lines +1066 to +1071
xp_assert_close(u.T @ u, xp.eye(3), atol=1e-6)
xp_assert_close(vh.T @ vh, xp.eye(3), atol=1e-6)
sigma = xp.zeros((u.shape[0], vh.shape[0]), dtype=s.dtype)
for i in range(s.shape[0]):
    sigma[i, i] = s[i]
assert_array_almost_equal(u @ sigma @ vh, a)
xp_assert_close(u @ sigma @ vh, a, rtol=1e-6)
Member Author

the tolerances around here could do with a look, I'm not sure why I wrote a mixture of atol and rtol.

Member Author

ah I remember, it's because the default for np.testing.assert_array_almost_equal is roughly equivalent to rtol=0, atol=1.5e-6.
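A quick sketch to check that rough equivalence, using `np.testing.assert_allclose` as a stand-in for `xp_assert_close` (an assumption of this example). The wrapper names `old_style`, `new_style` and `passes` are made up for illustration.

```python
import numpy as np

desired = np.array([1.0, 0.0])


def old_style(actual, desired):
    # Default decimal=6, i.e. abs(desired - actual) < 1.5e-6.
    np.testing.assert_array_almost_equal(actual, desired)


def new_style(actual, desired):
    # Purely absolute tolerance, no relative component.
    np.testing.assert_allclose(actual, desired, rtol=0, atol=1.5e-6)


def passes(check, actual):
    """Return True if the given assertion accepts `actual`."""
    try:
        check(actual, desired)
        return True
    except AssertionError:
        return False


within = desired + 1e-6   # error below 1.5e-6
outside = desired + 2e-6  # error above 1.5e-6

results = {
    "within": (passes(old_style, within), passes(new_style, within)),
    "outside": (passes(old_style, outside), passes(new_style, outside)),
}
```

Both assertions accept the first perturbation and reject the second, so swapping one for the other with `rtol=0, atol=1.5e-6` should keep roughly the same strictness.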

@lucascolley
Member Author

lucascolley commented Apr 1, 2024

CI should be green. I think this is almost ready. A few questions remain:

  • this PR involves a lot of general improvements to the tests (checking more dtypes, checking shapes, stricter tolerances), but clearly these improvements haven't been pushed to be optimal. I think doing so would be a huge effort given the size of this diff, but happy to work a little more if there are particular areas that could do with some TLC.
  • I don't know if we want some kind of policy about what to do with tolerances. I've basically just gone with the defaults of the assertions, but where identity matrices are used we get atol failures because entries that should be exactly 0 are not, so I have introduced atol bumps where needed (the default atol is 0).
  • A lot of the tests which are currently skipped could be split up into parts which are compatible and parts which aren't (at least for this PR). I think it would be too much effort to split up all of them, but some are maybe worth it. I've marked a few with TODOs.
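To illustrate the identity-matrix point from the second bullet, here is a minimal sketch. It uses `np.testing.assert_allclose` (default `rtol=1e-7`, `atol=0`) as a stand-in for `xp_assert_close`, and a hand-built matrix in place of a real `u.T @ u` product; both are assumptions of the example.

```python
import numpy as np

# Simulate u.T @ u for a nearly orthogonal u: exact ones on the diagonal,
# rounding-level noise where the identity has zeros.
noise = 1e-16
product = np.eye(3) + noise * (np.ones((3, 3)) - np.eye(3))

# With atol=0 the allowed error is rtol * |desired|, which is 0 wherever the
# desired entry is 0, so any non-zero noise fails the check.
try:
    np.testing.assert_allclose(product, np.eye(3))
    default_passes = True
except AssertionError:
    default_passes = False

# A small absolute tolerance accepts the rounding noise.
np.testing.assert_allclose(product, np.eye(3), atol=1e-12)
```

A relative tolerance alone can never accept a non-zero value where exactly 0 is expected, which is why comparisons against `eye` need an explicit `atol`.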

EDIT: spoke too soon on CI, but it looks like just an atol=0 thing so far.

EDIT 2: finally green :)

@lucascolley lucascolley requested a review from ilayn April 1, 2024 23:09
@lucascolley lucascolley marked this pull request as ready for review April 1, 2024 23:11
[skip cirrus] [skip circle]
Labels
array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy.linalg

5 participants