ENH: stats: add array-API support to kstat/kstatvar #20634

Merged: 10 commits, May 8, 2024
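For orientation, the behaviour this PR enables can be sketched as follows (a minimal, hypothetical example: it assumes SciPy is built from this branch, the `SCIPY_ARRAY_API=1` environment variable is set before Python starts, and PyTorch is installed):

```python
# Sketch of array-API dispatch for kstat/kstatvar: with SCIPY_ARRAY_API=1,
# a torch tensor input should be processed via the torch namespace and a
# torch result returned, while NumPy input behaves as before.
import numpy as np
import torch
from scipy import stats

x_np = np.random.default_rng(0).standard_normal(100)
x_torch = torch.asarray(x_np)

print(type(stats.kstat(x_np, n=2)))      # NumPy path, unchanged
print(type(stats.kstat(x_torch, n=2)))   # torch path added by this PR
print(type(stats.kstatvar(x_torch, n=2)))
```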
scipy/stats/_morestats.py (12 changes: 7 additions & 5 deletions)

@@ -288,12 +288,13 @@ def kstat(data, n=2, *, axis=None):
     0.00166  0.00166  -4.99e-09
    -2.88e-06 -2.88e-06  8.63e-13
     """
+    xp = array_namespace(data)
+    data = xp.asarray(data)
     if n > 4 or n < 1:
         raise ValueError("k-statistics only supported for 1<=n<=4")
     n = int(n)
-    data = np.asarray(data)
     if axis is None:
-        data = ravel(data)
+        data = xp.reshape(data, (-1,))
         axis = 0
 
     N = data.shape[axis]
@@ -302,7 +303,7 @@ def kstat(data, n=2, *, axis=None):
     if N == 0:
         raise ValueError("Data input must not be empty")
 
-    S = [None] + [np.sum(data**k, axis=axis) for k in range(1, n + 1)]
+    S = [None] + [xp.sum(data**k, axis=axis) for k in range(1, n + 1)]
     if n == 1:
         return S[1] * 1.0/N
     elif n == 2:
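As a quick sanity check on the refactored `xp.sum`-based formulas (a sketch using the NumPy backend only; it relies on the documented facts that the first k-statistic is the sample mean and the second is the unbiased sample variance):

```python
# k1 is the sample mean and k2 is the unbiased sample variance, so the
# rewritten formulas can be cross-checked against plain NumPy.
import numpy as np
from scipy import stats

rng = np.random.default_rng(12345)
x = rng.standard_normal(500)

assert np.isclose(stats.kstat(x, n=1), np.mean(x))
assert np.isclose(stats.kstat(x, n=2), np.var(x, ddof=1))
```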
@@ -365,9 +366,10 @@ def kstatvar(data, n=2, *, axis=None):
            \frac{144 n \kappa_{2} \kappa^2_{3}}{(n - 1) (n - 2)} +
            \frac{24 (n + 1) n \kappa^4_{2}}{(n - 1) (n - 2) (n - 3)}
     """  # noqa: E501
-    data = np.asarray(data)
+    xp = array_namespace(data)
+    data = xp.asarray(data)
     if axis is None:
-        data = ravel(data)
+        data = xp.reshape(data, (-1,))
         axis = 0
     N = data.shape[axis]
 
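The recurring pattern in both functions is `array_namespace` plus `xp.reshape(data, (-1,))` in place of `ravel`, since `ravel` is not part of the array API standard. A minimal sketch of that idiom with the NumPy backend (`scipy._lib._array_api` is a private SciPy helper, imported here only for illustration):

```python
# Portable flattening: resolve the input's array namespace, then flatten
# with reshape (the array API standard has reshape but no ravel).
import numpy as np
from scipy._lib._array_api import array_namespace  # private helper

data = np.arange(6.0).reshape(2, 3)
xp = array_namespace(data)          # resolves to the (wrapped) NumPy namespace
flat = xp.reshape(data, (-1,))      # same values as data.ravel()
print(flat)
```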
scipy/stats/_stats_py.py (3 changes: 3 additions & 0 deletions)

@@ -2829,6 +2829,9 @@ def sem(a, axis=0, ddof=1, nan_policy='propagate'):

"""
xp = array_namespace(a)
if axis is None:
a = xp.reshape(a, (-1,))
axis = 0
a = atleast_nd(a, ndim=1, xp=xp)
n = a.shape[axis]
s = xp.std(a, axis=axis, correction=ddof) / n**0.5
Expand Down
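The added `axis=None` branch makes `sem` reduce over the flattened input, which is exactly the equivalence the `TestCommonAxis::test_axis` case below asserts for `axis=None`. A small NumPy-only sketch of that equivalence:

```python
# With axis=None, sem should match sem of the flattened array.
import numpy as np
from scipy import stats

x = np.random.default_rng(1).random((6, 7))
assert np.isclose(stats.sem(x, axis=None), stats.sem(x.ravel()))
```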
scipy/stats/tests/test_morestats.py (89 changes: 50 additions & 39 deletions)

@@ -21,10 +21,14 @@
 from .._hypotests import _get_wilcoxon_distr, _get_wilcoxon_distr2
 from scipy.stats._binomtest import _binary_search_for_binom_tst
 from scipy.stats._distr_params import distcont
 
+from scipy.conftest import array_api_compatible
+from scipy._lib._array_api import (array_namespace, xp_assert_close, xp_assert_less,
+                                   SCIPY_ARRAY_API, xp_assert_equal)
 
 
+skip_xp_backends = pytest.mark.skip_xp_backends
 
 distcont = dict(distcont)  # type: ignore
 
 # Matplotlib is not a scipy dependency but is optionally used in probplot, so
@@ -1669,36 +1673,38 @@ def test_permutation_method(self, size):
            12.10, 15.02, 16.83, 16.98, 19.92, 9.47, 11.68, 13.41, 15.35, 19.11]
 
 
+@array_api_compatible
 class TestKstat:
-    def test_moments_normal_distribution(self):
+    def test_moments_normal_distribution(self, xp):
         np.random.seed(32149)
-        data = np.random.randn(12345)
-        moments = [stats.kstat(data, n) for n in [1, 2, 3, 4]]
+        data = xp.asarray(np.random.randn(12345), dtype=xp.float64)
Review comment (Contributor): Even torch will produce a float64 Tensor if it is generated from a NumPy float64 array, right?

Reply (Member, author): Yes, I can't quite remember why I did this, will have a look in the follow-up.
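To illustrate the reviewer's point (a sketch assuming PyTorch is installed): converting a default-dtype NumPy array already yields a float64 tensor, so the explicit `dtype=xp.float64` above may be redundant for torch.

```python
# torch.asarray preserves the NumPy dtype, so a float64 ndarray becomes a
# float64 tensor without passing dtype explicitly.
import numpy as np
import torch

a = np.random.randn(5)   # NumPy defaults to float64
t = torch.asarray(a)
print(t.dtype)           # torch.float64
```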

+        moments = xp.asarray([stats.kstat(data, n) for n in [1, 2, 3, 4]])
 
-        expected = [0.011315, 1.017931, 0.05811052, 0.0754134]
-        assert_allclose(moments, expected, rtol=1e-4)
+        expected = xp.asarray([0.011315, 1.017931, 0.05811052, 0.0754134],
+                              dtype=data.dtype)
+        xp_assert_close(moments, expected, rtol=1e-4)
 
         # test equivalence with `stats.moment`
         m1 = stats.moment(data, order=1)
         m2 = stats.moment(data, order=2)
         m3 = stats.moment(data, order=3)
-        assert_allclose((m1, m2, m3), expected[:-1], atol=0.02, rtol=1e-2)
+        xp_assert_close(xp.asarray((m1, m2, m3)), expected[:-1], atol=0.02, rtol=1e-2)
 
-    def test_empty_input(self):
+    def test_empty_input(self, xp):
         message = 'Data input must not be empty'
         with pytest.raises(ValueError, match=message):
-            stats.kstat([])
+            stats.kstat(xp.asarray([]))
 
-    def test_nan_input(self):
-        data = np.arange(10.)
-        data[6] = np.nan
+    def test_nan_input(self, xp):
+        data = xp.arange(10.)
+        data[6] = xp.nan
Review comment (Contributor): Fine for now, but it would be nice to save @lucascolley some trouble by generating these without mutation as we go forward. Is there a better way than this?

    Suggested change:
    -        data[6] = xp.nan
    +        data = xp.where(data==6, xp.nan, data)

Reply (Member, author): Obviously this is fine here, but I do wonder, if we encountered mutations in the inner loop of a solver, for example, whether all the new allocations might be a bit painful performance-wise.

Reply (mdhaber, Contributor, May 8, 2024): Oh, I wouldn't change existing code like that all the time. Just easy cases like these tests, where it could make the difference between the test running with JAX or not. Cases do need to be considered individually.
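For background on why mutation matters here (a sketch assuming JAX is installed): JAX arrays are immutable, so in-place item assignment fails, while the functional `where` form suggested above works unchanged.

```python
# JAX arrays do not support item assignment; a functional update is needed.
import jax.numpy as jnp

data = jnp.arange(10.0)
try:
    data[6] = jnp.nan                        # raises: JAX arrays are immutable
except TypeError as err:
    print("in-place assignment failed:", err)

data = jnp.where(data == 6, jnp.nan, data)   # the suggested mutation-free form
# (equivalently: data = data.at[6].set(jnp.nan))
print(data)
```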


-        assert_equal(stats.kstat(data), np.nan)
+        xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan))
 
     @pytest.mark.parametrize('n', [0, 4.001])
-    def test_kstat_bad_arg(self, n):
+    def test_kstat_bad_arg(self, n, xp):
         # Raise ValueError if n > 4 or n < 1.
-        data = np.arange(10)
+        data = xp.arange(10)
         message = 'k-statistics only supported for 1<=n<=4'
         with pytest.raises(ValueError, match=message):
             stats.kstat(data, n=n)
@@ -1707,7 +1713,7 @@ def test_kstat_bad_arg(self, n):
                                       (2, 12.65006954022974),
                                       (3, -1.447059503280798),
                                       (4, -141.6682291883626)])
-    def test_against_R(self, case):
+    def test_against_R(self, case, xp):
         # Test against reference values computed with R kStatistics, e.g.
         # options(digits=16)
         # library(kStatistics)
@@ -1717,32 +1723,35 @@ def test_against_R(self, case):
         # 19.92, 9.47, 11.68, 13.41, 15.35, 19.11)
         # nKS(4, data)
         n, ref = case
-        res = stats.kstat(x_kstat, n)
-        assert_allclose(res, ref)
+        res = stats.kstat(xp.asarray(x_kstat), n)
+        xp_assert_close(res, xp.asarray(ref))
 
 
 
+@array_api_compatible
 class TestKstatVar:
-    def test_empty_input(self):
+    def test_empty_input(self, xp):
         message = 'Data input must not be empty'
         with pytest.raises(ValueError, match=message):
-            stats.kstatvar([])
+            stats.kstatvar(xp.asarray([]))
 
-    def test_nan_input(self):
-        data = np.arange(10.)
-        data[6] = np.nan
+    def test_nan_input(self, xp):
+        data = xp.arange(10.)
+        data[6] = xp.nan
 
-        assert_equal(stats.kstat(data), np.nan)
+        xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan))
 
-    def test_bad_arg(self):
+    @skip_xp_backends(np_only=True,
+                      reasons=['input validation of `n` does not depend on backend'])
+    def test_bad_arg(self, xp):
         # Raise ValueError is n is not 1 or 2.
         data = [1]
         n = 10
         message = 'Only n=1 or n=2 supported.'
         with pytest.raises(ValueError, match=message):
             stats.kstatvar(data, n=n)
 
-    def test_against_R_mathworld(self):
+    def test_against_R_mathworld(self, xp):
         # Test against reference values computed using formulas exactly as
         # they appear at https://mathworld.wolfram.com/k-Statistic.html
         # This is *really* similar to how they appear in the implementation,
@@ -1751,14 +1760,14 @@ def test_against_R_mathworld(self):
         k2 = 12.65006954022974  # see source code in TestKstat
         k4 = -141.6682291883626
 
-        res = stats.kstatvar(x_kstat, 1)
+        res = stats.kstatvar(xp.asarray(x_kstat), 1)
         ref = k2 / n
-        assert_allclose(res, ref)
+        xp_assert_close(res, xp.asarray(ref))
 
-        res = stats.kstatvar(x_kstat, 2)
+        res = stats.kstatvar(xp.asarray(x_kstat), 2)
         # *unbiased estimator* for var(k2)
         ref = (2*k2**2*n + (n-1)*k4) / (n * (n+1))
-        assert_allclose(res, ref)
+        xp_assert_close(res, xp.asarray(ref))
 
 
 class TestPpccPlot:
@@ -3011,29 +3020,31 @@ def test_edge_cases(self):
         assert_array_equal(stats.false_discovery_control([]), [])
 
 
+@array_api_compatible
 class TestCommonAxis:
     # More thorough testing of `axis` in `test_axis_nan_policy`,
-    # but those testa aren't run with array API yet. This class
+    # but those tests aren't run with array API yet. This class
     # is in `test_morestats` instead of `test_axis_nan_policy`
     # because there is no reason to run `test_axis_nan_policy`
     # with the array API CI job right now.
 
     @pytest.mark.parametrize('case', [(stats.sem, {}),
                                       (stats.kstat, {'n': 4}),
-                                      (stats.kstat, {'n': 2})])
-    def test_axis(self, case):
+                                      (stats.kstat, {'n': 2}),
+                                      (stats.variation, {})])
+    def test_axis(self, case, xp):
         fun, kwargs = case
         rng = np.random.default_rng(24598245982345)
-        x = rng.random((6, 7))
+        x = xp.asarray(rng.random((6, 7)))
 
         res = fun(x, **kwargs, axis=0)
-        ref = [fun(x[:, i], **kwargs) for i in range(x.shape[1])]
-        assert_allclose(res, ref)
+        ref = xp.asarray([fun(x[:, i], **kwargs) for i in range(x.shape[1])])
+        xp_assert_close(res, ref)
 
         res = fun(x, **kwargs, axis=1)
-        ref = [fun(x[i, :], **kwargs) for i in range(x.shape[0])]
-        assert_allclose(res, ref)
+        ref = xp.asarray([fun(x[i, :], **kwargs) for i in range(x.shape[0])])
+        xp_assert_close(res, ref)
 
         res = fun(x, **kwargs, axis=None)
-        ref = fun(x.ravel(), **kwargs)
-        assert_allclose(res, ref)
+        ref = fun(xp.reshape(x, (-1,)), **kwargs)
+        xp_assert_close(res, ref)