-
-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: stats: add array-API support to kstat/kstatvar #20634
Changes from all commits
d7f7f03
3dc78dd
318f998
47fbe5a
812f5e6
5737033
8180862
a3848ac
747e654
1166d1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -21,10 +21,14 @@ | |||||
from .._hypotests import _get_wilcoxon_distr, _get_wilcoxon_distr2 | ||||||
from scipy.stats._binomtest import _binary_search_for_binom_tst | ||||||
from scipy.stats._distr_params import distcont | ||||||
|
||||||
from scipy.conftest import array_api_compatible | ||||||
from scipy._lib._array_api import (array_namespace, xp_assert_close, xp_assert_less, | ||||||
SCIPY_ARRAY_API, xp_assert_equal) | ||||||
|
||||||
|
||||||
skip_xp_backends = pytest.mark.skip_xp_backends | ||||||
|
||||||
distcont = dict(distcont) # type: ignore | ||||||
|
||||||
# Matplotlib is not a scipy dependency but is optionally used in probplot, so | ||||||
|
@@ -1669,36 +1673,38 @@ def test_permutation_method(self, size): | |||||
12.10, 15.02, 16.83, 16.98, 19.92, 9.47, 11.68, 13.41, 15.35, 19.11] | ||||||
|
||||||
|
||||||
@array_api_compatible | ||||||
class TestKstat: | ||||||
def test_moments_normal_distribution(self): | ||||||
def test_moments_normal_distribution(self, xp): | ||||||
np.random.seed(32149) | ||||||
data = np.random.randn(12345) | ||||||
moments = [stats.kstat(data, n) for n in [1, 2, 3, 4]] | ||||||
data = xp.asarray(np.random.randn(12345), dtype=xp.float64) | ||||||
moments = xp.asarray([stats.kstat(data, n) for n in [1, 2, 3, 4]]) | ||||||
|
||||||
expected = [0.011315, 1.017931, 0.05811052, 0.0754134] | ||||||
assert_allclose(moments, expected, rtol=1e-4) | ||||||
expected = xp.asarray([0.011315, 1.017931, 0.05811052, 0.0754134], | ||||||
dtype=data.dtype) | ||||||
xp_assert_close(moments, expected, rtol=1e-4) | ||||||
|
||||||
# test equivalence with `stats.moment` | ||||||
m1 = stats.moment(data, order=1) | ||||||
m2 = stats.moment(data, order=2) | ||||||
m3 = stats.moment(data, order=3) | ||||||
assert_allclose((m1, m2, m3), expected[:-1], atol=0.02, rtol=1e-2) | ||||||
xp_assert_close(xp.asarray((m1, m2, m3)), expected[:-1], atol=0.02, rtol=1e-2) | ||||||
|
||||||
def test_empty_input(self): | ||||||
def test_empty_input(self, xp): | ||||||
message = 'Data input must not be empty' | ||||||
with pytest.raises(ValueError, match=message): | ||||||
stats.kstat([]) | ||||||
stats.kstat(xp.asarray([])) | ||||||
|
||||||
def test_nan_input(self): | ||||||
data = np.arange(10.) | ||||||
data[6] = np.nan | ||||||
def test_nan_input(self, xp): | ||||||
data = xp.arange(10.) | ||||||
data[6] = xp.nan | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fine for now, but nice to save @lucascolley some trouble by generating these without mutation as we go forward. Is there a better way than?
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Obvously here this is fine but I do wonder if we encountered mutations in the in loop of a solver, for example, if all the new allocations might be a bit painful performance wise There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I wouldn't change existing code like that all the time. Just easy cases like these tests, where it could make the difference between the test running with JAX or not. Cases do need to be considered individually. |
||||||
|
||||||
assert_equal(stats.kstat(data), np.nan) | ||||||
xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan)) | ||||||
|
||||||
@pytest.mark.parametrize('n', [0, 4.001]) | ||||||
def test_kstat_bad_arg(self, n): | ||||||
def test_kstat_bad_arg(self, n, xp): | ||||||
# Raise ValueError if n > 4 or n < 1. | ||||||
data = np.arange(10) | ||||||
data = xp.arange(10) | ||||||
message = 'k-statistics only supported for 1<=n<=4' | ||||||
with pytest.raises(ValueError, match=message): | ||||||
stats.kstat(data, n=n) | ||||||
|
@@ -1707,7 +1713,7 @@ def test_kstat_bad_arg(self, n): | |||||
(2, 12.65006954022974), | ||||||
(3, -1.447059503280798), | ||||||
(4, -141.6682291883626)]) | ||||||
def test_against_R(self, case): | ||||||
def test_against_R(self, case, xp): | ||||||
# Test against reference values computed with R kStatistics, e.g. | ||||||
# options(digits=16) | ||||||
# library(kStatistics) | ||||||
|
@@ -1717,32 +1723,35 @@ def test_against_R(self, case): | |||||
# 19.92, 9.47, 11.68, 13.41, 15.35, 19.11) | ||||||
# nKS(4, data) | ||||||
n, ref = case | ||||||
res = stats.kstat(x_kstat, n) | ||||||
assert_allclose(res, ref) | ||||||
res = stats.kstat(xp.asarray(x_kstat), n) | ||||||
xp_assert_close(res, xp.asarray(ref)) | ||||||
|
||||||
|
||||||
|
||||||
@array_api_compatible | ||||||
class TestKstatVar: | ||||||
def test_empty_input(self): | ||||||
def test_empty_input(self, xp): | ||||||
message = 'Data input must not be empty' | ||||||
with pytest.raises(ValueError, match=message): | ||||||
stats.kstatvar([]) | ||||||
stats.kstatvar(xp.asarray([])) | ||||||
|
||||||
def test_nan_input(self): | ||||||
data = np.arange(10.) | ||||||
data[6] = np.nan | ||||||
def test_nan_input(self, xp): | ||||||
data = xp.arange(10.) | ||||||
data[6] = xp.nan | ||||||
|
||||||
assert_equal(stats.kstat(data), np.nan) | ||||||
xp_assert_equal(stats.kstat(data), xp.asarray(xp.nan)) | ||||||
|
||||||
def test_bad_arg(self): | ||||||
@skip_xp_backends(np_only=True, | ||||||
reasons=['input validation of `n` does not depend on backend']) | ||||||
def test_bad_arg(self, xp): | ||||||
# Raise ValueError is n is not 1 or 2. | ||||||
data = [1] | ||||||
n = 10 | ||||||
message = 'Only n=1 or n=2 supported.' | ||||||
with pytest.raises(ValueError, match=message): | ||||||
stats.kstatvar(data, n=n) | ||||||
|
||||||
def test_against_R_mathworld(self): | ||||||
def test_against_R_mathworld(self, xp): | ||||||
# Test against reference values computed using formulas exactly as | ||||||
# they appear at https://mathworld.wolfram.com/k-Statistic.html | ||||||
# This is *really* similar to how they appear in the implementation, | ||||||
|
@@ -1751,14 +1760,14 @@ def test_against_R_mathworld(self): | |||||
k2 = 12.65006954022974 # see source code in TestKstat | ||||||
k4 = -141.6682291883626 | ||||||
|
||||||
res = stats.kstatvar(x_kstat, 1) | ||||||
res = stats.kstatvar(xp.asarray(x_kstat), 1) | ||||||
ref = k2 / n | ||||||
assert_allclose(res, ref) | ||||||
xp_assert_close(res, xp.asarray(ref)) | ||||||
|
||||||
res = stats.kstatvar(x_kstat, 2) | ||||||
res = stats.kstatvar(xp.asarray(x_kstat), 2) | ||||||
# *unbiased estimator* for var(k2) | ||||||
ref = (2*k2**2*n + (n-1)*k4) / (n * (n+1)) | ||||||
assert_allclose(res, ref) | ||||||
xp_assert_close(res, xp.asarray(ref)) | ||||||
|
||||||
|
||||||
class TestPpccPlot: | ||||||
|
@@ -3011,29 +3020,31 @@ def test_edge_cases(self): | |||||
assert_array_equal(stats.false_discovery_control([]), []) | ||||||
|
||||||
|
||||||
@array_api_compatible | ||||||
class TestCommonAxis: | ||||||
# More thorough testing of `axis` in `test_axis_nan_policy`, | ||||||
# but those testa aren't run with array API yet. This class | ||||||
# but those tests aren't run with array API yet. This class | ||||||
# is in `test_morestats` instead of `test_axis_nan_policy` | ||||||
# because there is no reason to run `test_axis_nan_policy` | ||||||
# with the array API CI job right now. | ||||||
|
||||||
@pytest.mark.parametrize('case', [(stats.sem, {}), | ||||||
(stats.kstat, {'n': 4}), | ||||||
(stats.kstat, {'n': 2})]) | ||||||
def test_axis(self, case): | ||||||
(stats.kstat, {'n': 2}), | ||||||
(stats.variation, {})]) | ||||||
def test_axis(self, case, xp): | ||||||
fun, kwargs = case | ||||||
rng = np.random.default_rng(24598245982345) | ||||||
x = rng.random((6, 7)) | ||||||
x = xp.asarray(rng.random((6, 7))) | ||||||
|
||||||
res = fun(x, **kwargs, axis=0) | ||||||
ref = [fun(x[:, i], **kwargs) for i in range(x.shape[1])] | ||||||
assert_allclose(res, ref) | ||||||
ref = xp.asarray([fun(x[:, i], **kwargs) for i in range(x.shape[1])]) | ||||||
xp_assert_close(res, ref) | ||||||
|
||||||
res = fun(x, **kwargs, axis=1) | ||||||
ref = [fun(x[i, :], **kwargs) for i in range(x.shape[0])] | ||||||
assert_allclose(res, ref) | ||||||
ref = xp.asarray([fun(x[i, :], **kwargs) for i in range(x.shape[0])]) | ||||||
xp_assert_close(res, ref) | ||||||
|
||||||
res = fun(x, **kwargs, axis=None) | ||||||
ref = fun(x.ravel(), **kwargs) | ||||||
assert_allclose(res, ref) | ||||||
ref = fun(xp.reshape(x, (-1,)), **kwargs) | ||||||
xp_assert_close(res, ref) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even
torch
will produce afloat64
Tensor if it is generated from a NumPyfloat64
array, right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I can't quite remember why I did this, will have a look in the follow up