ENH: stats: add array-API support to kstat/kstatvar #20634

Merged
mdhaber merged 10 commits into scipy:main on May 8, 2024

Conversation

j-bowhay
Member

@j-bowhay j-bowhay commented May 3, 2024

towards #20544

@github-actions github-actions bot added the scipy.stats, CI (items related to the CI tools such as CircleCI, GitHub Actions or Azure) and enhancement (a new feature or improvement) labels May 3, 2024
scipy/stats/_morestats.py (outdated, resolved)
@j-bowhay j-bowhay mentioned this pull request May 3, 2024
68 tasks
@j-bowhay j-bowhay marked this pull request as ready for review May 3, 2024 14:58
Contributor

@mdhaber mdhaber left a comment

Oops, I'm realizing that I probably shouldn't have put these on the initial list - they really should get native support for an axis argument before we add array API support. Would you add keyword argument axis with default None (for backward compatibility) to the signature and use it as appropriate?
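
A minimal sketch of what an `axis` keyword with a backward-compatible `None` default could look like. This is not the SciPy implementation (native axis support was added separately in gh-20651); `kstat2_sketch` is a hypothetical name, NumPy is used purely for illustration, and only the second k-statistic (the unbiased sample variance) is covered.

```python
import numpy as np


def kstat2_sketch(data, axis=None):
    """Second k-statistic (unbiased variance estimate) along `axis`.

    Hypothetical helper for illustration only; not the SciPy function.
    """
    data = np.asarray(data, dtype=float)
    if axis is None:
        # Backward-compatible default: treat the input as one flat sample.
        data = data.reshape(-1)
        axis = 0
    n = data.shape[axis]
    s1 = np.sum(data, axis=axis)       # first power sum
    s2 = np.sum(data**2, axis=axis)    # second power sum
    return (n * s2 - s1**2) / (n * (n - 1))


x = np.arange(12.0).reshape(3, 4)
flat_result = kstat2_sketch(x)           # one number for the whole array
row_results = kstat2_sketch(x, axis=1)   # one number per row
```

With `axis=None` the result matches what a caller passing a flat sample would have got before the keyword existed, which is the point of keeping `None` as the default.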

@j-bowhay j-bowhay marked this pull request as draft May 3, 2024 16:57
@mdhaber
Contributor

mdhaber commented May 5, 2024

Do you plan to add native axis support to these or should I do that?

@j-bowhay
Member Author

j-bowhay commented May 5, 2024

> Do you plan to add native axis support to these or should I do that?

I was going to go for some of the lower-hanging fruit first, so feel free to go for it in the meantime.

@mdhaber
Contributor

mdhaber commented May 6, 2024

Added in gh-20651. Sorry for the merge conflicts!

@mdhaber
Contributor

mdhaber commented May 6, 2024

The conflicts are actually not as bad as I expected. It probably is worth it to merge and fix them rather than starting over. Please also convert the new tests added in gh-20651.

@j-bowhay
Member Author

j-bowhay commented May 7, 2024

OK, almost there. Just down to the final failures of this style:

scipy/stats/tests/test_morestats.py:1703: in test_nan_input
    xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan))
        data       = Array([ 0.,  1.,  2.,  3.,  4.,  5., nan,
        7.,  8.,  9.], dtype=array_api_strict.float64)
        self       = <scipy.stats.tests.test_morestats.TestKstat object at 0x7f85c7677850>
        xp         = <module 'array_api_strict' from '/home/jakeb/miniconda3/envs/scipy-dev-pytorch/lib/python3.11/site-packages/array_api_strict/__init__.py'>
scipy/stats/_axis_nan_policy.py:405: in axis_nan_policy_wrapper
    return hypotest_fun_in(*args, **kwds)
        _no_deco   = False
        args       = (Array([ 0.,  1.,  2.,  3.,  4.,  5., nan,
        7.,  8.,  9.], dtype=array_api_strict.float64),)
        default_axis = None
        hypotest_fun_in = <function kstat at 0x7f8671774400>
        is_too_small = <function _axis_nan_policy_factory.<locals>.is_too_small at 0x7f86717742c0>
        kwd_samples = []
        kwds       = {}
        msg        = 'Use of `nan_policy` and `keepdims` is incompatible with non-NumPy arrays.'
        n_outputs  = 1
        n_samples  = 1
        override   = {'nan_propagation': True, 'vectorization': False}
        paired     = False
        result_to_tuple = <function <lambda> at 0x7f8671774220>
        temp       = Array([ 0.,  1.,  2.,  3.,  4.,  5., nan,
        7.,  8.,  9.], dtype=array_api_strict.float64)
        tuple_to_result = <function <lambda> at 0x7f8671774180>
scipy/stats/_morestats.py:307: in kstat
    S = [None] + [xp.sum(data**k, axis=axis) for k in range(1, n + 1)]
        N          = 10
        axis       = 0
        data       = array([ 0.,  1.,  2.,  3.,  4.,  5., nan,  7.,  8.,  9.])
        n          = 2
        xp         = <module 'array_api_strict' from '/home/jakeb/miniconda3/envs/scipy-dev-pytorch/lib/python3.11/site-packages/array_api_strict/__init__.py'>
scipy/stats/_morestats.py:307: in <listcomp>
    S = [None] + [xp.sum(data**k, axis=axis) for k in range(1, n + 1)]
        .0         = <range_iterator object at 0x7f85c75be580>
        axis       = 0
        data       = array([ 0.,  1.,  2.,  3.,  4.,  5., nan,  7.,  8.,  9.])
        k          = 1
        xp         = <module 'array_api_strict' from '/home/jakeb/miniconda3/envs/scipy-dev-pytorch/lib/python3.11/site-packages/array_api_strict/__init__.py'>
/home/jakeb/miniconda3/envs/scipy-dev-pytorch/lib/python3.11/site-packages/array_api_strict/_statistical_functions.py:100: in sum
    if x.dtype not in _numeric_dtypes:
        axis       = 0
        dtype      = None
        keepdims   = False
        x          = array([ 0.,  1.,  2.,  3.,  4.,  5., nan,  7.,  8.,  9.])
/home/jakeb/miniconda3/envs/scipy-dev-pytorch/lib/python3.11/site-packages/array_api_strict/_dtypes.py:24: in __eq__
    warnings.warn(
E   UserWarning: You are comparing a array_api_strict dtype against a NumPy native dtype object, but you probably don't want to do this. array_api_strict dtype objects compare unequal to their NumPy equivalents. Such cross-library comparison is not supported by the standard.
        other      = dtype('float64')
        self       = array_api_strict.float32
============================================================================================== short test summary info ==============================================================================================
FAILED scipy/stats/tests/test_morestats.py::TestKstat::test_moments_normal_distribution[array_api_strict] - UserWarning: You are comparing a array_api_strict dtype against a NumPy native dtype object, but you probably don't want to do this. array_api_strict dtype objects compare unequal to their NumPy equivalents. ...
FAILED scipy/stats/tests/test_morestats.py::TestKstat::test_nan_input[array_api_strict] - UserWarning: You are comparing a array_api_strict dtype against a NumPy native dtype object, but you probably don't want to do this. array_api_strict dtype objects compare unequal to their NumPy equivalents. ...

Hopefully it will be obvious when it isn't so late!

@mdhaber
Contributor

mdhaber commented May 7, 2024

Line 296 and the next few don't look like they've been converted.
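
In the traceback above, `data` is printed as a NumPy `array(...)` inside `kstat` while `xp` is `array_api_strict`, which appears consistent with an unconverted NumPy call earlier in the function. The sketch below illustrates the general pattern of staying in the input's own namespace. `get_namespace` and `power_sums` are hypothetical names, not SciPy helpers, and the example only exercises NumPy.

```python
import numpy as np


def get_namespace(x):
    """Return the array namespace of `x`, falling back to NumPy.

    Hypothetical helper: arrays that implement the array API standard
    expose `__array_namespace__`; anything else (e.g. a list) uses NumPy.
    """
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def power_sums(data, n, axis=0):
    """Power sums S_1..S_n of `data`, computed only with namespace calls."""
    xp = get_namespace(data)
    # Convert with xp.asarray, not np.asarray, so a non-NumPy input is
    # never silently turned into a NumPy array.
    data = xp.asarray(data)
    return [xp.sum(data**k, axis=axis) for k in range(1, n + 1)]


sums = power_sums(np.arange(4.0), 2)
```

Mixing a hard-coded `np.` call into such a function is easy to miss with NumPy inputs, because everything still works; it only surfaces once a stricter backend is passed in.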

@j-bowhay
Member Author

j-bowhay commented May 7, 2024

> Line 296 and the next few don't look like they've been converted.

Ha, yes, thanks. Temporary post-coursework-submission blindness...

@j-bowhay j-bowhay marked this pull request as ready for review May 7, 2024 08:54
Contributor

@mdhaber mdhaber left a comment

The statistic conversion looks pretty good, but the tests still need to become @array_api_compatible. We can add tests with the intent of checking the behavior for NumPy lists, but we also need to test behavior with backends other than NumPy.
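
As a rough illustration of testing against more than one backend: the tests in this PR receive the namespace as an `xp` argument, and the sketch below imitates that with a plain loop instead of SciPy's test decorator. All function names here are hypothetical, and `array_api_strict` is only used if it happens to be installed.

```python
import importlib


def available_namespaces(names=("numpy", "array_api_strict")):
    """Import each candidate backend, skipping any that are not installed."""
    found = []
    for name in names:
        try:
            found.append(importlib.import_module(name))
        except ImportError:
            pass
    return found


def first_kstat(data, xp):
    """First k-statistic (the sample mean), written only with namespace calls."""
    return xp.sum(data) / data.shape[0]


def check_first_kstat(xp):
    """The same check body runs unchanged for every backend."""
    data = xp.asarray([1.0, 2.0, 3.0, 6.0])
    result = first_kstat(data, xp)
    assert abs(float(result) - 3.0) < 1e-12


for xp in available_namespaces():
    check_first_kstat(xp)
```

The design point is that the test body never names a specific library, so adding a backend is a matter of extending the list rather than rewriting the test.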

scipy/stats/tests/test_morestats.py (outdated, resolved)
scipy/stats/tests/test_morestats.py (outdated, resolved)
Contributor

@mdhaber mdhaber left a comment

After this, would you be willing to submit another PR that cleans up the notes of these? For example:

image

Worse, these correspond with:

image

but the implementation is:

image

np.random.seed(32149)
data = np.random.randn(12345)
moments = [stats.kstat(data, n) for n in [1, 2, 3, 4]]
data = xp.asarray(np.random.randn(12345), dtype=xp.float64)
Contributor

Even torch will produce a float64 Tensor if it is generated from a NumPy float64 array, right?

Member Author

Yes, I can't quite remember why I did this. I'll have a look in the follow-up.

data[6] = np.nan
def test_nan_input(self, xp):
    data = xp.arange(10.)
    data[6] = xp.nan
Contributor

Fine for now, but it would be nice to save @lucascolley some trouble by generating these without mutation as we go forward. Is there a better way than this?

Suggested change
- data[6] = xp.nan
+ data = xp.where(data == 6, xp.nan, data)
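
One possible alternative, sketched here under assumptions: the suggestion above selects the element by value, which works because `xp.arange(10.)` holds 6 at index 6. Selecting by position avoids relying on the values. `with_nan_at` is a hypothetical helper, and the example only exercises NumPy; the NaN is wrapped in `xp.asarray` because some namespaces may not accept a bare Python scalar in `where`.

```python
import numpy as np


def with_nan_at(data, index, xp=np):
    """Return a copy of 1-D `data` with NaN at `index`, without item assignment.

    Hypothetical helper for illustration; `xp` is the array namespace to use.
    """
    positions = xp.arange(data.shape[0])
    nan = xp.asarray(xp.nan, dtype=data.dtype)
    # where() builds a new array, so the input is never mutated.
    return xp.where(positions == index, nan, data)


data = np.arange(10.0)
data_with_nan = with_nan_at(data, 6)
```

Because nothing is assigned in place, the same construction has a chance of working with immutable-array backends such as JAX, which is the motivation given in this thread.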

Member Author

Obviously this is fine here, but I do wonder whether, if we encountered mutations in the inner loop of a solver, for example, all the new allocations might be a bit painful performance-wise.

Contributor

@mdhaber mdhaber May 8, 2024

Oh, I wouldn't change existing code like that all the time. Just easy cases like these tests, where it could make the difference between the test running with JAX or not. Cases do need to be considered individually.

@mdhaber mdhaber merged commit 5b7e5a0 into scipy:main May 8, 2024
30 checks passed
@dschmitz89 dschmitz89 added this to the 1.14.0 milestone May 8, 2024