ENH: stats.monte_carlo_test: add array API support #20604

Merged
j-bowhay merged 10 commits into scipy:main on May 5, 2024


@mdhaber mdhaber commented Apr 28, 2024

Reference issue

Towards gh-20544

What does this implement/fix?

This PR adds array-API support for scipy.stats.monte_carlo_test.

Additional information

Example (15~50x speedup locally, depending on dtype):

import numpy as np
import cupy as cp
from scipy import stats

rng = np.random.default_rng()
data_np = rng.standard_normal(size=100)
data_cp64 = cp.asarray(data_np, dtype=cp.float64)
data_cp32 = cp.asarray(data_np, dtype=cp.float32)

ref = stats.ttest_1samp(data_np, 0.)

def get_statistic(xp):
    def statistic(x, axis=0):
        m = xp.mean(x, axis=axis)
        # this is just the user-defined statistic;
        # it doesn't need to be array-API compliant
        v = xp.var(x, axis=axis, ddof=1)
        n = x.shape[axis]
        return m / (v / n) ** 0.5
    return statistic

# NumPy 64-bit
rng_np = np.random.default_rng()
res1 = stats.monte_carlo_test(data_np, rng_np.standard_normal, get_statistic(np), n_resamples=99999)
# 123 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# CuPy 64-bit
rng_cp = cp.random.default_rng()
res2 = stats.monte_carlo_test(data_cp64, rng_cp.standard_normal, get_statistic(cp), n_resamples=99999)
# 8.27 ms ± 767 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

# CuPy 32-bit
def rvs(size):
    return rng_cp.standard_normal(size=size, dtype=cp.float32)
res3 = stats.monte_carlo_test(data_cp32, rvs, get_statistic(cp), n_resamples=99999)
# 2.51 ms ± 33.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

print(ref.pvalue, res1.pvalue, res2.pvalue, res3.pvalue)
# 0.7506365948009475 0.754 0.75104 0.74768
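
For readers unfamiliar with what is being timed above, here is a simplified NumPy-only sketch of a two-sided Monte Carlo test. It is not SciPy's implementation: the helper name `monte_carlo_pvalue` is hypothetical, and the add-one estimate `(count + 1) / (n_resamples + 1)` with doubling of the smaller tail is an assumption about how such a p-value is typically formed. All resamples are drawn and reduced in one batched call, which is the kind of work that would be expected to benefit from a GPU backend.

```python
import numpy as np


def t_statistic(x, axis=-1):
    # Same one-sample t statistic as in the example above.
    m = np.mean(x, axis=axis)
    v = np.var(x, axis=axis, ddof=1)
    n = x.shape[axis]
    return m / (v / n) ** 0.5


def monte_carlo_pvalue(data, rvs, statistic, n_resamples=9999):
    """Two-sided Monte Carlo p-value (hypothetical helper, not SciPy's code)."""
    observed = statistic(data, axis=-1)
    # Draw every resample at once: one row per resample.
    resamples = rvs(size=(n_resamples, data.shape[-1]))
    null = statistic(resamples, axis=-1)
    # Add-one estimates of each tail, so the p-value is never exactly zero.
    p_less = (np.sum(null <= observed) + 1) / (n_resamples + 1)
    p_greater = (np.sum(null >= observed) + 1) / (n_resamples + 1)
    # Double the smaller tail and cap the result at 1.
    return float(min(1.0, 2 * min(p_less, p_greater)))


rng = np.random.default_rng(12345)
data = rng.standard_normal(size=100)
p = monte_carlo_pvalue(data, rng.standard_normal, t_statistic)
print(p)
```

Under these assumptions, the smallest attainable two-sided p-value with `n_resamples=99999` (as in the timings above) would be 2 / 100000.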

@mdhaber mdhaber added the scipy.stats and enhancement labels Apr 28, 2024
@github-actions github-actions bot added the scipy._lib, CI, and array types labels Apr 28, 2024
@mdhaber mdhaber changed the title from "Xp monte carlo test" to "ENH: stats.monte_carlo_test: add array API support" Apr 28, 2024

@mdhaber mdhaber left a comment

Some self-review to aid the reviewer.

scipy/_lib/_array_api.py (2 resolved threads)
scipy/stats/_axis_nan_policy.py (2 resolved threads)
scipy/stats/_resampling.py (2 resolved threads)
@@ -1,13 +1,18 @@
import numpy as np

Again, didn't want to add to the mess of imports.

scipy/stats/tests/test_resampling.py (3 resolved threads, 1 outdated)
@mdhaber mdhaber requested a review from tupui April 29, 2024 00:06
scipy/_lib/_array_api.py (1 resolved thread, outdated)

rgommers commented May 1, 2024

> Example (15~50x speedup locally, depending on dtype):

Very nice! This diff should be pretty clean once array-api-compat supports v2023.12 of the standard.


mdhaber commented May 1, 2024

Thanks for taking a look @rgommers; responses are above. Another array API thing you might have thoughts about is the idea of making the default tolerance of xp_assert_close dtype-dependent.
#20597 (review)
The intent is to save us from constantly forcing torch arrays to be float64 to meet tolerances and/or setting tolerances for different dtypes manually.
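
A dtype-dependent default tolerance of this kind could look roughly like the sketch below. The helper names `default_rtol` and `assert_close` are hypothetical, and scaling by the square root of machine epsilon (times a small safety factor) is an assumption made for illustration, not necessarily what `xp_assert_close` does. NumPy stands in for the array namespace here.

```python
import numpy as np


def default_rtol(dtype, factor=4):
    """Relative tolerance scaled to the precision of `dtype` (illustrative)."""
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.inexact):
        # sqrt(eps) is a common rule of thumb for "about half the digits".
        return float(np.finfo(dtype).eps) ** 0.5 * factor
    # Integer and boolean results are expected to match exactly.
    return 0.0


def assert_close(actual, desired, rtol=None):
    # Hypothetical wrapper: pick the tolerance from the dtype of `actual`
    # unless the caller sets one explicitly.
    actual = np.asarray(actual)
    if rtol is None:
        rtol = default_rtol(actual.dtype)
    np.testing.assert_allclose(actual, desired, rtol=rtol)


# float32 gets a looser default than float64, so a float32 result does not
# have to be cast to float64 just to pass the comparison.
x32 = np.float32(1.0) / np.float32(3.0)
assert_close(x32, 1 / 3)
print(default_rtol(np.float32), default_rtol(np.float64))
```

An explicit `rtol` still overrides the default in this sketch, so individual tests could tighten or loosen the comparison where needed.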

@mdhaber mdhaber mentioned this pull request May 2, 2024
68 tasks

rgommers commented May 3, 2024

> Towards #20544

If stats functions are going to be done one by one over many PRs rather than in one go, would it make sense to have a separate tracker for stats to point to? Having an overview of the coverage in stats will be useful (and help contributors pick functions to work on), and it's a bit much to put that overview in the scipy-wide tracker probably.


mdhaber commented May 3, 2024

> would it make sense to have a separate tracker for stats to point to?

? You referred to the separate stats tracker.


rgommers commented May 3, 2024

> ? You referred to the separate stats tracker.

oops, sorry for the noise. temporary blindness there


mdhaber commented May 3, 2024

@rgommers I know you'll be out for a bit - shall I ask others to review/merge this in the meantime, or did you want to continue with it?


rgommers commented May 3, 2024

Thanks for asking. Please don't feel like you have to wait for me.

@mdhaber mdhaber requested a review from j-bowhay May 3, 2024 15:18

@mdhaber mdhaber left a comment

Replied to comments. Thanks @j-bowhay!

scipy/stats/tests/test_resampling.py (4 resolved threads)
[skip cirrus] [skip circle]

@lucascolley lucascolley left a comment

this all looks idiomatic to me, at a glance at least 👍

x = rng.random(10)
x = xp.asarray(rng.standard_normal(size=10))

xp_test = array_namespace(x)

can we add a comment for why this is needed here too? (thanks for adding in the other places)


@j-bowhay j-bowhay left a comment

I think this should be good to go once the merge conflicts and Lucas's comment are resolved.

@j-bowhay j-bowhay merged commit 9ee34f9 into scipy:main May 5, 2024
30 checks passed
@j-bowhay j-bowhay added this to the 1.14.0 milestone May 5, 2024

mdhaber commented May 5, 2024

Thanks, all! I think I'll do power next, because it is quite similar and doesn't require an RNG.

Labels
array types: Items related to array API support and input array validation (see gh-18286)
CI: Items related to the CI tools such as CircleCI, GitHub Actions or Azure
enhancement: A new feature or improvement
scipy._lib
scipy.stats

4 participants