Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: stats: add array-API support to kstat/kstatvar #20634

Merged
merged 10 commits into from
May 8, 2024
22 changes: 13 additions & 9 deletions scipy/stats/_morestats.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,23 +283,25 @@ def kstat(data, n=2):
0.00166 0.00166 -4.99e-09
-2.88e-06 -2.88e-06 8.63e-13
"""
xp = array_namespace(data)
data = xp.asarray(data)
if n > 4 or n < 1:
raise ValueError("k-statistics only supported for 1<=n<=4")
n = int(n)
S = np.zeros(n + 1, np.float64)
data = ravel(data)
N = data.size
S = xp.zeros(n + 1, dtype=xp.float64)
data = xp.reshape(data, (-1,))
N = data.shape[0]

# raise ValueError on empty input
if N == 0:
raise ValueError("Data input must not be empty")

# on nan input, return nan without warning
if np.isnan(np.sum(data)):
return np.nan
if xp.isnan(xp.sum(data)):
return _get_nan(data, xp=xp)

for k in range(1, n + 1):
S[k] = np.sum(data**k, axis=0)
S[k] = xp.sum(data**k, axis=0)
if n == 1:
return S[1] * 1.0/N
elif n == 2:
Expand Down Expand Up @@ -356,9 +358,11 @@ def kstatvar(data, n=2):
\frac{72 n \kappa^2_{2} \kappa_4}{(n - 1) (n - 2)} +
\frac{144 n \kappa_{2} \kappa^2_{3}}{(n - 1) (n - 2)} +
\frac{24 (n + 1) n \kappa^4_{2}}{(n - 1) (n - 2) (n - 3)}
""" # noqa: E501
data = ravel(data)
N = len(data)
"""
xp = array_namespace(data)
xp.asarray(data)
data = xp.reshape(data, (-1,))
N = data.shape[0]
if n == 1:
return kstat(data, n=2) * 1.0/N
elif n == 2:
Expand Down
51 changes: 31 additions & 20 deletions scipy/stats/tests/test_morestats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@
from .._hypotests import _get_wilcoxon_distr, _get_wilcoxon_distr2
from scipy.stats._binomtest import _binary_search_for_binom_tst
from scipy.stats._distr_params import distcont

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import (array_namespace, xp_assert_close, xp_assert_less,
SCIPY_ARRAY_API, xp_assert_equal)


skip_xp_backends = pytest.mark.skip_xp_backends

distcont = dict(distcont) # type: ignore

# Matplotlib is not a scipy dependency but is optionally used in probplot, so
Expand Down Expand Up @@ -1661,48 +1665,55 @@ def test_permutation_method(self, size):
assert_equal(res.pvalue, ref.pvalue) # random_state used


@array_api_compatible
class TestKstat:
def test_moments_normal_distribution(self):
def test_moments_normal_distribution(self, xp):
np.random.seed(32149)
data = np.random.randn(12345)
moments = [stats.kstat(data, n) for n in [1, 2, 3, 4]]
data = xp.asarray(np.random.randn(12345))
moments = xp.asarray([stats.kstat(data, n) for n in [1, 2, 3, 4]])

expected = [0.011315, 1.017931, 0.05811052, 0.0754134]
assert_allclose(moments, expected, rtol=1e-4)
expected = xp.asarray([0.011315, 1.017931, 0.05811052, 0.0754134],
dtype=data.dtype)
xp_assert_close(moments, expected, rtol=1e-4)

# test equivalence with `stats.moment`
m1 = stats.moment(data, order=1)
m2 = stats.moment(data, order=2)
m3 = stats.moment(data, order=3)
assert_allclose((m1, m2, m3), expected[:-1], atol=0.02, rtol=1e-2)
xp_assert_close(xp.asarray((m1, m2, m3)), expected[:-1], atol=0.02, rtol=1e-2)

def test_empty_input(self):
assert_raises(ValueError, stats.kstat, [])
def test_empty_input(self, xp):
assert_raises(ValueError, stats.kstat, xp.asarray([]))

def test_nan_input(self):
data = np.arange(10.)
data[6] = np.nan
def test_nan_input(self, xp):
data = xp.arange(10.)
data[6] = xp.nan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for now, but nice to save @lucascolley some trouble by generating these without mutation as we go forward. Is there a better way than?

Suggested change
data[6] = xp.nan
data = xp.where(data==6, xp.nan, data)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obvously here this is fine but I do wonder if we encountered mutations in the in loop of a solver, for example, if all the new allocations might be a bit painful performance wise

Copy link
Contributor

@mdhaber mdhaber May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I wouldn't change existing code like that all the time. Just easy cases like these tests, where it could make the difference between the test running with JAX or not. Cases do need to be considered individually.


assert_equal(stats.kstat(data), np.nan)
xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan))

def test_kstat_bad_arg(self):
@skip_xp_backends(np_only=True,
reasons=['input validation of `n` does not depend on backend'])
def test_kstat_bad_arg(self, xp):
# Raise ValueError if n > 4 or n < 1.
data = np.arange(10)
for n in [0, 4.001]:
assert_raises(ValueError, stats.kstat, data, n=n)


@array_api_compatible
class TestKstatVar:
def test_empty_input(self):
assert_raises(ValueError, stats.kstatvar, [])
def test_empty_input(self, xp):
assert_raises(ValueError, stats.kstatvar, xp.asarray([]))

def test_nan_input(self):
data = np.arange(10.)
data[6] = np.nan
def test_nan_input(self, xp):
data = xp.arange(10.)
data[6] = xp.nan

assert_equal(stats.kstat(data), np.nan)
xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan))

def test_bad_arg(self):
@skip_xp_backends(np_only=True,
reasons=['input validation of `n` does not depend on backend'])
def test_bad_arg(self, xp):
# Raise ValueError is n is not 1 or 2.
data = [1]
n = 10
Expand Down