Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: stats: add array-API support to kstat/kstatvar #20634

Merged
merged 10 commits into from
May 8, 2024
Merged
1 change: 1 addition & 0 deletions .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,4 @@ jobs:
python dev.py --no-build test -b all -t scipy._lib.tests.test_array_api
python dev.py --no-build test -b all -t scipy._lib.tests.test__util -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_stats -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_morestats -- --durations 3 --timeout=60
23 changes: 14 additions & 9 deletions scipy/stats/_morestats.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from scipy import optimize, special, interpolate, stats
from scipy._lib._bunch import _make_tuple_bunch
from scipy._lib._util import _rename_parameter, _contains_nan, _get_nan
from scipy._lib._array_api import array_namespace

from ._ansari_swilk_statistics import gscale, swilk
from . import _stats_py, _wilcoxon
Expand Down Expand Up @@ -283,23 +284,25 @@ def kstat(data, n=2):
0.00166 0.00166 -4.99e-09
-2.88e-06 -2.88e-06 8.63e-13
"""
xp = array_namespace(data)
data = xp.asarray(data)
if n > 4 or n < 1:
raise ValueError("k-statistics only supported for 1<=n<=4")
n = int(n)
S = np.zeros(n + 1, np.float64)
data = ravel(data)
N = data.size
S = xp.zeros(n + 1, dtype=xp.float64)
data = xp.reshape(data, (-1,))
N = data.shape[0]

# raise ValueError on empty input
if N == 0:
raise ValueError("Data input must not be empty")

# on nan input, return nan without warning
if np.isnan(np.sum(data)):
return np.nan
if xp.isnan(xp.sum(data)):
return xp.nan
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved

for k in range(1, n + 1):
S[k] = np.sum(data**k, axis=0)
S[k] = xp.sum(data**k, axis=0)
if n == 1:
return S[1] * 1.0/N
elif n == 2:
Expand Down Expand Up @@ -356,9 +359,11 @@ def kstatvar(data, n=2):
\frac{72 n \kappa^2_{2} \kappa_4}{(n - 1) (n - 2)} +
\frac{144 n \kappa_{2} \kappa^2_{3}}{(n - 1) (n - 2)} +
\frac{24 (n + 1) n \kappa^4_{2}}{(n - 1) (n - 2) (n - 3)}
""" # noqa: E501
data = ravel(data)
N = len(data)
"""
xp = array_namespace(data)
xp.asarray(data)
data = xp.reshape(data, (-1,))
N = data.shape[0]
if n == 1:
return kstat(data, n=2) * 1.0/N
elif n == 2:
Expand Down
54 changes: 32 additions & 22 deletions scipy/stats/tests/test_morestats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from .._hypotests import _get_wilcoxon_distr, _get_wilcoxon_distr2
from scipy.stats._binomtest import _binary_search_for_binom_tst
from scipy.stats._distr_params import distcont
from scipy._lib._array_api import SCIPY_ARRAY_API
from scipy._lib._array_api import SCIPY_ARRAY_API, xp_assert_equal, xp_assert_close
from scipy.conftest import array_api_compatible

skip_xp_backends = pytest.mark.skip_xp_backends

distcont = dict(distcont) # type: ignore

Expand Down Expand Up @@ -1659,48 +1662,55 @@ def test_permutation_method(self, size):
assert_equal(res.pvalue, ref.pvalue) # random_state used


@array_api_compatible
class TestKstat:
def test_moments_normal_distribution(self):
def test_moments_normal_distribution(self, xp):
np.random.seed(32149)
data = np.random.randn(12345)
moments = [stats.kstat(data, n) for n in [1, 2, 3, 4]]
data = xp.asarray(np.random.randn(12345))
moments = xp.asarray([stats.kstat(data, n) for n in [1, 2, 3, 4]])

expected = [0.011315, 1.017931, 0.05811052, 0.0754134]
assert_allclose(moments, expected, rtol=1e-4)
expected = xp.asarray([0.011315, 1.017931, 0.05811052, 0.0754134],
dtype=data.dtype)
xp_assert_close(moments, expected, rtol=1e-4)

# test equivalence with `stats.moment`
m1 = stats.moment(data, order=1)
m2 = stats.moment(data, order=2)
m3 = stats.moment(data, order=3)
assert_allclose((m1, m2, m3), expected[:-1], atol=0.02, rtol=1e-2)

def test_empty_input(self):
assert_raises(ValueError, stats.kstat, [])
xp_assert_close(xp.asarray((m1, m2, m3)), expected[:-1], atol=0.02, rtol=1e-2)
def test_empty_input(self, xp):
assert_raises(ValueError, stats.kstat, xp.asarray([]))

def test_nan_input(self):
data = np.arange(10.)
data[6] = np.nan
def test_nan_input(self, xp):
data = xp.arange(10.)
data[6] = xp.nan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for now, but nice to save @lucascolley some trouble by generating these without mutation as we go forward. Is there a better way than?

Suggested change
data[6] = xp.nan
data = xp.where(data==6, xp.nan, data)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obvously here this is fine but I do wonder if we encountered mutations in the in loop of a solver, for example, if all the new allocations might be a bit painful performance wise

Copy link
Contributor

@mdhaber mdhaber May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I wouldn't change existing code like that all the time. Just easy cases like these tests, where it could make the difference between the test running with JAX or not. Cases do need to be considered individually.


assert_equal(stats.kstat(data), np.nan)
xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan))

def test_kstat_bad_arg(self):
@skip_xp_backends(np_only=True,
reasons=['input validatio of `n` does not depend on backend'])
def test_kstat_bad_arg(self, xp):
# Raise ValueError if n > 4 or n < 1.
data = np.arange(10)
for n in [0, 4.001]:
assert_raises(ValueError, stats.kstat, data, n=n)


@array_api_compatible
class TestKstatVar:
def test_empty_input(self):
assert_raises(ValueError, stats.kstatvar, [])
def test_empty_input(self, xp):
assert_raises(ValueError, stats.kstatvar, xp.asarray([]))

def test_nan_input(self):
data = np.arange(10.)
data[6] = np.nan
def test_nan_input(self, xp):
data = xp.arange(10.)
data[6] = xp.nan

assert_equal(stats.kstat(data), np.nan)
xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan))

def test_bad_arg(self):
@skip_xp_backends(np_only=True,
reasons=['input validatio of `n` does not depend on backend'])
def test_bad_arg(self, xp):
# Raise ValueError is n is not 1 or 2.
data = [1]
n = 10
Expand Down