Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: stats._xp_mean, an array API compatible mean with weights and nan_policy #20743

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

mdhaber
Copy link
Contributor

@mdhaber mdhaber commented May 18, 2024

Reference issue

Toward gh-20544

What does this implement/fix?

This function adds _xp_mean, an array-API compatible function which combines the features of np.mean, np.average, and np.nanmean in interface that fits with scipy.stats. This will be needed for making functions like pmean, hmean, and gmean array-API compatible.

Additional information

Potential reviewers: would you be willing to write some unit tests with hypothesis? For such a fundamental function, it's particularly important that it works flawlessly!

If it doesn't sound too crazy, I'd suggest that this and similar var and std functions be added publicly to scipy.stats because they provide functionality that does not exist with the array API (e.g. weights, which has been explicitly rejected, and nan_policy, which has not been standardized and may not follow SciPy's convention). Even considering NumPy alone, it would be useful to have a single function that has all the functionality of mean, average, and nanmean in an interface consistent with the rest of scipy.stats.

Not pursuing these things right now. Let's just get this in so we can finish the other mean functions.

@mdhaber mdhaber added scipy.stats enhancement A new feature or improvement array types Items related to array API support and input array validation (see gh-18286) labels May 18, 2024
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
@mdhaber mdhaber marked this pull request as ready for review May 18, 2024 23:28
Comment on lines +126 to +127
(xp_mean_1samp, tuple(), dict(), 1, 1, False, lambda x: (x,)),
(xp_mean_2samp, tuple(), dict(), 2, 1, True, lambda x: (x,)),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most scipy.stats functions use the _axis_nan_policy decorator to implement nan_policy, keepdims, and tuple axis. I've implemented all these features natively for improved performance (e.g. nan_policy='omit' would otherwise loop over each slice), and the function still passes all the tests, which are quite stringent. So if you don't want to write tests with hypothesis, I'm still pretty comfortable with this.


if weights is not None and x.shape != weights.shape:
try:
x, weights = xp.broadcast_arrays(x, weights)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few thoughts about broadcasting:

  • Technically x = [1, 2, 3] is broadcastable with weights = [2], and it can be interpreted as giving all observations a weight of 2.
  • Technically, x = [1] is broadcastable with weights = [1, 2, 3]: now we have x being broadcast to the shape of weights rather than the (more natural) other way around.
  • Technically x = [] is broadcastable with weights = [1]: weights gets broadcasted to shape (0,), and the weighted mean is NaN.

It's clearly simpler to just accept these sorts of things, but since they're not useful, one could argue that we shouldn't. I'd propose that we just accept them, but if there are strong opinions about not accepting them, LMK.

@@ -475,3 +476,155 @@ def xp_sign(x, xp=None):
sign = xp.where(x < 0, -one, sign)
sign = xp.where(x == 0, 0*one, sign)
return sign


def xp_add_reduced_axes(res, axis, initial_shape, *, xp=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a note on why this is needed? Is it temporary, why can't xp.add not be used, etc.?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type annotations and consistency with other functions in this file would be useful too (at least if you expect this function to stay around for a while).

res should preferably be positional-only.

Copy link
Contributor Author

@mdhaber mdhaber May 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a better name would have been xp_replace_reduced_axes or xp_keepdims: it adds back axes that have been reduced away. However, when there are other comments to respond to, I'll just move the logic back into xp_mean, since I'm not sure if it will be used elsewhere. It can be factored out again as needed. Although the comment wasn't about xp_mean, I can make the first argument of xp_mean positional-only.

@mdhaber mdhaber mentioned this pull request May 19, 2024
74 tasks
@fancidev
Copy link
Contributor

Why was weights explicitly rejected for the Array API? Would you by chance have a link or something for the discussion back then?

@lucascolley
Copy link
Member

Why was weights explicitly rejected for the Array API? Would you by chance have a link or something for the discussion back then?

data-apis/array-api#366

@fancidev
Copy link
Contributor

Thanks for the link @lucascolley .

To align with the naming convention of hmean, pmean, and gmean, would it be more appropriate to call the function amean (a for arithmetic)?

@lucascolley lucascolley changed the title ENH: xp_mean: an array-API compatible mean with weights and nan_policy ENH: xp_mean, an array API compatible mean with weights and nan_policy Jun 2, 2024
@mdhaber mdhaber changed the title ENH: xp_mean, an array API compatible mean with weights and nan_policy ENH: stats._xp_mean, an array API compatible mean with weights and nan_policy Jun 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy._lib scipy.stats
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants