ENH: stats.monte_carlo_test: add array API support (#20604)
* MAINT: stats._broadcast_arrays: add array API support

* ENH: stats.monte_carlo_test: add array-API support to input validation

* ENH: stats.monte_carlo_test: add array-API support

* TST: stats.monte_carlo_test: convert input validation test to array API

* TST: stats.monte_carlo_test: make test_axis array-API compatible

* TST: stats.monte_carlo_test: array-API tests for batch and alternative

* MAINT: stats.monte_carlo_test: preserve dtype

* Update scipy/stats/tests/test_resampling.py

* TST: stats.monte_carlo_test: comment on why xp_test is needed
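
Taken together, the commits above let `monte_carlo_test` run end to end on non-NumPy arrays. A minimal usage sketch, not part of the diff: it assumes SciPy is built with these changes, run with the `SCIPY_ARRAY_API=1` environment variable, and that `array_api_strict` is installed; for non-NumPy backends the statistic must be vectorized (accept an `axis` argument).

import numpy as np
import array_api_strict as xp
from scipy import stats

rng = np.random.default_rng(2549824598234528)
x = xp.asarray(rng.standard_normal(size=100))

def statistic(x, axis):
    # vectorized one-sample t-like statistic against a zero mean
    n = x.shape[axis]
    return xp.mean(x, axis=axis) / (xp.std(x, axis=axis) / n**0.5)

def rvs(size):
    # draw null samples in the same backend as `x`
    return xp.asarray(rng.standard_normal(size=size))

res = stats.monte_carlo_test(x, rvs, statistic, vectorized=True)
print(res.statistic, res.pvalue)  # both are array_api_strict arrays of float64
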
mdhaber committed May 5, 2024
1 parent 4a537b7 commit 9ee34f9
Showing 6 changed files with 215 additions and 106 deletions.
1 change: 1 addition & 0 deletions .github/workflows/array_api.yml
@@ -98,3 +98,4 @@ jobs:
python dev.py --no-build test -b all -t scipy._lib.tests.test__util -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_stats -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_morestats -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_resampling -- --durations 3 --timeout=60
22 changes: 21 additions & 1 deletion scipy/_lib/_array_api.py
@@ -398,4 +398,24 @@ def xp_minimum(x1, x2):
res = xp.asarray(x1, copy=True, dtype=dtype)
i = (x2 < x1) | xp.isnan(x2)
res[i] = x2[i]
return res
return res[()] if res.ndim == 0 else res


# temporary substitute for xp.clip, which is not yet in all backends
# or covered by array_api_compat.
def xp_clip(x, a, b, xp=None):
xp = array_namespace(x) if xp is None else xp
y = xp.asarray(x, copy=True)
y[y < a] = a
y[y > b] = b
return y[()] if y.ndim == 0 else y


# temporary substitute for xp.moveaxis, which is not yet in all backends
# or covered by array_api_compat.
def _move_axis_to_end(x, source, xp=None):
xp = array_namespace(x) if xp is None else xp
axes = list(range(x.ndim))
temp = axes.pop(source)
axes = axes + [temp]
return xp.permute_dims(x, axes)
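
Both helpers are private and temporary, but a quick sketch of their behaviour may help (illustrative only, not part of the diff; a NumPy input is used so the namespace is inferred automatically):

import numpy as np
from scipy._lib._array_api import xp_clip, _move_axis_to_end

x = np.asarray([[1., 5., -3.],
                [2., 0., 7.]])
print(xp_clip(x, 0., 4.))             # values clipped into [0, 4]
print(_move_axis_to_end(x, 0).shape)  # (3, 2): axis 0 is now last
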
23 changes: 9 additions & 14 deletions scipy/stats/_axis_nan_policy.py
@@ -12,25 +12,20 @@
import inspect


def _broadcast_arrays(arrays, axis=None):
def _broadcast_arrays(arrays, axis=None, xp=None):
"""
Broadcast shapes of arrays, ignoring incompatibility of specified axes
"""
new_shapes = _broadcast_array_shapes(arrays, axis=axis)
xp = array_namespace(*arrays) if xp is None else xp
arrays = [xp.asarray(arr) for arr in arrays]
shapes = [arr.shape for arr in arrays]
new_shapes = _broadcast_shapes(shapes, axis)
if axis is None:
new_shapes = [new_shapes]*len(arrays)
return [np.broadcast_to(array, new_shape)
return [xp.broadcast_to(array, new_shape)
for array, new_shape in zip(arrays, new_shapes)]


def _broadcast_array_shapes(arrays, axis=None):
"""
Broadcast shapes of arrays, ignoring incompatibility of specified axes
"""
shapes = [np.asarray(arr).shape for arr in arrays]
return _broadcast_shapes(shapes, axis)


def _broadcast_shapes(shapes, axis=None):
"""
Broadcast shapes, ignoring incompatibility of specified axes
@@ -103,10 +98,10 @@ def _broadcast_array_shapes_remove_axis(arrays, axis=None):
Examples
--------
>>> import numpy as np
>>> from scipy.stats._axis_nan_policy import _broadcast_array_shapes
>>> from scipy.stats._axis_nan_policy import _broadcast_array_shapes_remove_axis
>>> a = np.zeros((5, 2, 1))
>>> b = np.zeros((9, 3))
>>> _broadcast_array_shapes((a, b), 1)
>>> _broadcast_array_shapes_remove_axis((a, b), 1)
(5, 3)
"""
# Note that here, `axis=None` means do not consume/drop any axes - _not_
@@ -119,7 +114,7 @@ def _broadcast_shapes_remove_axis(shapes, axis=None):
"""
Broadcast shapes, dropping specified axes
Same as _broadcast_array_shapes, but given a sequence
Same as _broadcast_array_shapes_remove_axis, but given a sequence
of array shapes `shapes` instead of the arrays themselves.
"""
shapes = _broadcast_shapes(shapes, axis)
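
`_broadcast_arrays` now takes an optional `xp` and broadcasts in that namespace: every dimension is broadcast except the one named by `axis`, whose length may differ between samples. An illustrative sketch with NumPy (a private helper, shown only to clarify the semantics):

import numpy as np
from scipy.stats._axis_nan_policy import _broadcast_arrays

a = np.zeros((5, 2, 1))
b = np.zeros((9, 3))
a2, b2 = _broadcast_arrays((a, b), axis=1)
print(a2.shape, b2.shape)  # (5, 2, 3) (5, 9, 3): axis 1 kept per-sample
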
73 changes: 50 additions & 23 deletions scipy/stats/_resampling.py
@@ -4,12 +4,14 @@
import numpy as np
from itertools import combinations, permutations, product
from collections.abc import Sequence
from dataclasses import dataclass
import inspect

from scipy._lib._util import check_random_state, _rename_parameter
from scipy._lib._util import check_random_state, _rename_parameter, rng_integers
from scipy._lib._array_api import (array_namespace, is_numpy, xp_minimum,
xp_clip, _move_axis_to_end)
from scipy.special import ndtr, ndtri, comb, factorial
from scipy._lib._util import rng_integers
from dataclasses import dataclass

from ._common import ConfidenceInterval
from ._axis_nan_policy import _broadcast_concatenate, _broadcast_arrays
from ._warnings_errors import DegenerateDataWarning
@@ -662,7 +664,6 @@ def percentile_fun(a, q):
def _monte_carlo_test_iv(data, rvs, statistic, vectorized, n_resamples,
batch, alternative, axis):
"""Input validation for `monte_carlo_test`."""

axis_int = int(axis)
if axis != axis_int:
raise ValueError("`axis` must be an integer.")
@@ -677,26 +678,45 @@ def _monte_carlo_test_iv(data, rvs, statistic, vectorized, n_resamples,
if not callable(rvs_i):
raise TypeError("`rvs` must be callable or sequence of callables.")

# At this point, `data` should be a sequence
# If it isn't, the user passed a sequence for `rvs` but not `data`
message = "If `rvs` is a sequence, `len(rvs)` must equal `len(data)`."
try:
len(data)
except TypeError as e:
raise ValueError(message) from e
if not len(rvs) == len(data):
message = "If `rvs` is a sequence, `len(rvs)` must equal `len(data)`."
raise ValueError(message)

if not callable(statistic):
raise TypeError("`statistic` must be callable.")

if vectorized is None:
vectorized = 'axis' in inspect.signature(statistic).parameters
try:
signature = inspect.signature(statistic).parameters
except ValueError as e:
message = (f"Signature inspection of {statistic=} failed; "
"pass `vectorized` explicitly.")
raise ValueError(message) from e
vectorized = 'axis' in signature

xp = array_namespace(*data)

if not vectorized:
statistic_vectorized = _vectorize_statistic(statistic)
if is_numpy(xp):
statistic_vectorized = _vectorize_statistic(statistic)
else:
message = ("`statistic` must be vectorized (i.e. support an `axis` "
f"argument) when `data` contains {xp.__name__} arrays.")
raise ValueError(message)
else:
statistic_vectorized = statistic

data = _broadcast_arrays(data, axis)
data = _broadcast_arrays(data, axis, xp=xp)
data_iv = []
for sample in data:
sample = np.atleast_1d(sample)
sample = np.moveaxis(sample, axis_int, -1)
sample = xp.broadcast_to(sample, (1,)) if sample.ndim == 0 else sample
sample = _move_axis_to_end(sample, axis_int, xp=xp)
data_iv.append(sample)

n_resamples_int = int(n_resamples)
@@ -715,8 +735,12 @@ def _monte_carlo_test_iv(data, rvs, statistic, vectorized, n_resamples,
if alternative not in alternatives:
raise ValueError(f"`alternative` must be in {alternatives}")

# Infer the desired p-value dtype based on the input types
min_float = getattr(xp, 'float16', xp.float32)
dtype = xp.result_type(*data_iv, min_float)

return (data_iv, rvs, statistic_vectorized, vectorized, n_resamples_int,
batch_iv, alternative, axis_int)
batch_iv, alternative, axis_int, dtype, xp)


@dataclass
@@ -908,11 +932,12 @@ def monte_carlo_test(data, rvs, statistic, *, vectorized=None,
"""
args = _monte_carlo_test_iv(data, rvs, statistic, vectorized,
n_resamples, batch, alternative, axis)
(data, rvs, statistic, vectorized,
n_resamples, batch, alternative, axis) = args
(data, rvs, statistic, vectorized, n_resamples,
batch, alternative, axis, dtype, xp) = args

# Some statistics return plain floats; ensure they're at least arrays of the active backend
observed = np.asarray(statistic(*data, axis=-1))[()]
observed = xp.asarray(statistic(*data, axis=-1))
observed = observed[()] if observed.ndim == 0 else observed

n_observations = [sample.shape[-1] for sample in data]
batch_nominal = batch or n_resamples
@@ -922,37 +947,39 @@ def monte_carlo_test(data, rvs, statistic, *, vectorized=None,
resamples = [rvs_i(size=(batch_actual, n_observations_i))
for rvs_i, n_observations_i in zip(rvs, n_observations)]
null_distribution.append(statistic(*resamples, axis=-1))
null_distribution = np.concatenate(null_distribution)
null_distribution = null_distribution.reshape([-1] + [1]*observed.ndim)
null_distribution = xp.concat(null_distribution)
null_distribution = xp.reshape(null_distribution, [-1] + [1]*observed.ndim)

# relative tolerance for detecting numerically distinct but
# theoretically equal values in the null distribution
eps = (0 if not np.issubdtype(observed.dtype, np.inexact)
else np.finfo(observed.dtype).eps*100)
gamma = np.abs(eps * observed)
eps = (0 if not xp.isdtype(observed.dtype, ('real floating'))
else xp.finfo(observed.dtype).eps*100)
gamma = xp.abs(eps * observed)

def less(null_distribution, observed):
cmps = null_distribution <= observed + gamma
pvalues = (cmps.sum(axis=0) + 1) / (n_resamples + 1) # see [1]
cmps = xp.asarray(cmps, dtype=dtype)
pvalues = (xp.sum(cmps, axis=0, dtype=dtype) + 1.) / (n_resamples + 1.)
return pvalues

def greater(null_distribution, observed):
cmps = null_distribution >= observed - gamma
pvalues = (cmps.sum(axis=0) + 1) / (n_resamples + 1) # see [1]
cmps = xp.asarray(cmps, dtype=dtype)
pvalues = (xp.sum(cmps, axis=0, dtype=dtype) + 1.) / (n_resamples + 1.)
return pvalues

def two_sided(null_distribution, observed):
pvalues_less = less(null_distribution, observed)
pvalues_greater = greater(null_distribution, observed)
pvalues = np.minimum(pvalues_less, pvalues_greater) * 2
pvalues = xp_minimum(pvalues_less, pvalues_greater) * 2
return pvalues

compare = {"less": less,
"greater": greater,
"two-sided": two_sided}

pvalues = compare[alternative](null_distribution, observed)
pvalues = np.clip(pvalues, 0, 1)
pvalues = xp_clip(pvalues, 0., 1., xp=xp)

return MonteCarloTestResult(observed, pvalues, null_distribution)

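
For reference, the `less`/`greater` comparators above use the standard randomization-test estimate p = (count + 1) / (n_resamples + 1), which avoids p-values of exactly zero. A quick numeric check with NumPy (illustrative only, ignoring the small `gamma` tolerance):

import numpy as np

rng = np.random.default_rng(12345)
observed = 2.0
null_distribution = rng.standard_normal(9999)
count = np.sum(null_distribution >= observed)        # 'greater' alternative
pvalue = (count + 1) / (null_distribution.size + 1)
print(pvalue)  # approximately 0.023 for a standard normal null
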
19 changes: 3 additions & 16 deletions scipy/stats/_stats_py.py
@@ -67,7 +67,8 @@
from scipy import stats
from scipy.optimize import root_scalar
from scipy._lib._util import normalize_axis_index
from scipy._lib._array_api import array_namespace, is_numpy, atleast_nd
from scipy._lib._array_api import (array_namespace, is_numpy, atleast_nd,
xp_clip, _move_axis_to_end)
from scipy._lib.array_api_compat import size as xp_size

# In __all__ but deprecated for removal in SciPy 1.13.0
@@ -4535,20 +4536,6 @@ def confidence_interval(self, confidence_level=0.95, method=None):
return ci


def _move_axis_to_end(x, source, xp):
axes = list(range(x.ndim))
temp = axes.pop(source)
axes = axes + [temp]
return xp.permute_dims(x, axes)


def _clip(x, a, b, xp):
y = xp.asarray(x, copy=True)
y[y < a] = a
y[y > b] = b
return y


def pearsonr(x, y, *, alternative='two-sided', method=None, axis=0):
r"""
Pearson correlation coefficient and p-value for testing non-correlation.
@@ -4934,7 +4921,7 @@ def statistic(x, y, axis):
one = xp.asarray(1, dtype=dtype)
# `clip` only recently added to array API, so it's not yet available in
# array_api_strict. Replace with e.g. `xp.clip(r, -one, one)` when available.
r = xp.asarray(_clip(r, -one, one, xp))
r = xp.asarray(xp_clip(r, -one, one, xp))
r[const_xy] = xp.nan

# As explained in the docstring, the distribution of `r` under the null
