ENH: stats._xp_mean, an array API compatible mean with weights and nan_policy #20743
(xp_mean_1samp, tuple(), dict(), 1, 1, False, lambda x: (x,)),
(xp_mean_2samp, tuple(), dict(), 2, 1, True, lambda x: (x,)),
Most `scipy.stats` functions use the `_axis_nan_policy` decorator to implement `nan_policy`, `keepdims`, and tuple `axis`. I've implemented all these features natively for improved performance (e.g. `nan_policy='omit'` would otherwise loop over each slice), and the function still passes all the tests, which are quite stringent. So if you don't want to write tests with `hypothesis`, I'm still pretty comfortable with this.
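To illustrate what "natively" can mean for `nan_policy='omit'`, here is a minimal NumPy-only sketch that avoids a per-slice loop by zeroing out both the value and the weight wherever a NaN occurs. The name `mean_omit_nan` is hypothetical, and this is not necessarily how the PR implements it; it only shows the vectorized idea, including tuple `axis` and `keepdims`.

```python
import numpy as np


def mean_omit_nan(x, weights=None, axis=None, keepdims=False):
    """Hypothetical sketch: weighted mean that ignores NaNs without looping."""
    x = np.asarray(x, dtype=float)
    w = np.ones_like(x) if weights is None else np.asarray(weights, dtype=float)
    x, w = np.broadcast_arrays(x, w)

    # An observation is dropped if either its value or its weight is NaN.
    mask = np.isnan(x) | np.isnan(w)
    x = np.where(mask, 0.0, x)
    w = np.where(mask, 0.0, w)

    # Slices with no valid observations give 0/0, i.e. NaN; silence the warning.
    with np.errstate(invalid="ignore", divide="ignore"):
        num = np.sum(x * w, axis=axis, keepdims=keepdims)
        den = np.sum(w, axis=axis, keepdims=keepdims)
        return num / den


data = np.array([[1.0, np.nan, 3.0],
                 [4.0, 5.0, np.nan]])
row_means = mean_omit_nan(data, axis=1)           # one mean per row
overall = mean_omit_nan(data, axis=(0, 1))        # tuple axis
weighted = mean_omit_nan([1.0, np.nan, 3.0], weights=[1.0, 1.0, 3.0])
```

For unweighted input this should agree with `np.nanmean` along the same axis.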
scipy/_lib/_array_api.py
if weights is not None and x.shape != weights.shape:
    try:
        x, weights = xp.broadcast_arrays(x, weights)
A few thoughts about broadcasting:
- Technically, `x = [1, 2, 3]` is broadcastable with `weights = [2]`, and it can be interpreted as giving all observations a weight of `2`.
- Technically, `x = [1]` is broadcastable with `weights = [1, 2, 3]`: now we have `x` being broadcast to the shape of `weights` rather than the (more natural) other way around.
- Technically, `x = []` is broadcastable with `weights = [1]`: `weights` gets broadcast to shape `(0,)`, and the weighted mean is NaN.
It's clearly simpler to just accept these sorts of things, but since they're not useful, one could argue that we shouldn't. I'd propose that we just accept them, but if there are strong opinions about not accepting them, LMK.
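The three cases above can be checked with plain NumPy. This is only a sketch of the broadcasting behavior: `weighted_mean` is a hypothetical stand-in, not the function under review, and it assumes the weighted mean is computed as `sum(x * weights) / sum(weights)` after broadcasting.

```python
import numpy as np


def weighted_mean(x, weights):
    """Hypothetical stand-in: broadcast x and weights, then average."""
    x = np.asarray(x, dtype=float)
    weights = np.asarray(weights, dtype=float)
    x, weights = np.broadcast_arrays(x, weights)
    # An empty input gives 0/0, which is NaN; suppress the runtime warning.
    with np.errstate(invalid="ignore"):
        return np.sum(x * weights) / np.sum(weights)


# Case 1: a single weight applied to every observation.
case1 = weighted_mean([1, 2, 3], [2])

# Case 2: x is broadcast to the shape of weights.
case2 = weighted_mean([1], [1, 2, 3])

# Case 3: weights is broadcast to shape (0,), so the result is NaN.
case3 = weighted_mean([], [1])
case3_shape = np.broadcast_shapes((0,), (1,))
```

Accepting all three requires no special-casing, which is the argument for simply allowing them.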
scipy/_lib/_array_api.py
@@ -475,3 +476,155 @@ def xp_sign(x, xp=None):
    sign = xp.where(x < 0, -one, sign)
    sign = xp.where(x == 0, 0*one, sign)
    return sign


def xp_add_reduced_axes(res, axis, initial_shape, *, xp=None):
Could you add a note on why this is needed? Is it temporary, why can't `xp.add` be used, etc.?
Type annotations and consistency with other functions in this file would be useful too (at least if you expect this function to stay around for a while). `res` should preferably be positional-only.
Perhaps a better name would have been `xp_replace_reduced_axes` or `xp_keepdims`: it adds back axes that have been reduced away. However, when there are other comments to respond to, I'll just move the logic back into `xp_mean`, since I'm not sure if it will be used elsewhere. It can be factored out again as needed. Although the comment wasn't about `xp_mean`, I can make the first argument of `xp_mean` positional-only.
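For readers wondering what "adds back axes that have been reduced away" looks like in practice, here is a rough NumPy-only sketch. The name `restore_reduced_axes` and its body are assumptions for illustration, not the code from the diff; the idea is just to reshape the reduced result so that every reduced axis reappears with length 1, with the array argument made positional-only as suggested.

```python
import numpy as np


def restore_reduced_axes(res, /, axis, initial_shape):
    """Hypothetical sketch: reinsert length-1 axes that a reduction removed."""
    ndim = len(initial_shape)
    if axis is None:
        axes = tuple(range(ndim))      # reduction over all axes
    elif isinstance(axis, int):
        axes = (axis,)
    else:
        axes = tuple(axis)
    # Normalize negative axes so they can be compared with enumerate indices.
    axes = {a % ndim for a in axes}
    final_shape = tuple(1 if i in axes else n
                        for i, n in enumerate(initial_shape))
    return np.reshape(res, final_shape)


x = np.arange(24.0).reshape(2, 3, 4)
reduced = np.mean(x, axis=(0, 2))                  # shape (3,)
restored = restore_reduced_axes(reduced, (0, 2), x.shape)
```

The result should match what `keepdims=True` produces for the same reduction.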
Why was …
Thanks for the link, @lucascolley. To align with the naming convention of …
Reference issue
Toward gh-20544
What does this implement/fix?
This PR adds `_xp_mean`, an array-API compatible function which combines the features of `np.mean`, `np.average`, and `np.nanmean` in an interface that fits with `scipy.stats`. This will be needed for making functions like `pmean`, `hmean`, and `gmean` array-API compatible.
Additional information
Potential reviewers: would you be willing to write some unit tests with `hypothesis`? For such a fundamental function, it's particularly important that it works flawlessly!
If it doesn't sound too crazy, I'd suggest that this and similar `var` and `std` functions be added publicly to `scipy.stats` because they provide functionality that does not exist with the array API (e.g. `weights`, which has been explicitly rejected, and `nan_policy`, which has not been standardized and may not follow SciPy's convention). Even considering NumPy alone, it would be useful to have a single function that has all the functionality of `mean`, `average`, and `nanmean` in an interface consistent with the rest of `scipy.stats`.
Not pursuing these things right now. Let's just get this in so we can finish the other mean functions.
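As background on why the other mean functions depend on this one: the geometric, harmonic, and power means can each be written as a transformation of an ordinary (optionally weighted) arithmetic mean. The sketch below uses NumPy only, and the `*_sketch` names and the `weighted_mean` helper are hypothetical; it is not the `scipy.stats` implementation, just the textbook identities.

```python
import numpy as np


def weighted_mean(x, weights=None, axis=None):
    """Hypothetical helper: np.mean or np.average depending on weights."""
    x = np.asarray(x, dtype=float)
    if weights is None:
        return np.mean(x, axis=axis)
    return np.average(x, weights=weights, axis=axis)


def gmean_sketch(x, weights=None, axis=None):
    # Geometric mean: exp of the mean of the logs.
    return np.exp(weighted_mean(np.log(x), weights, axis))


def hmean_sketch(x, weights=None, axis=None):
    # Harmonic mean: reciprocal of the mean of the reciprocals.
    return 1.0 / weighted_mean(1.0 / np.asarray(x, dtype=float), weights, axis)


def pmean_sketch(x, p, weights=None, axis=None):
    # Power mean with nonzero exponent p.
    return weighted_mean(np.asarray(x, dtype=float) ** p, weights, axis) ** (1.0 / p)


g = gmean_sketch([1.0, 4.0])
h = hmean_sketch([1.0, 3.0])
q = pmean_sketch([3.0, 4.0], 2)
```

Once the underlying mean handles `weights`, `axis`, and `nan_policy` across array backends, wrappers of this shape inherit that behavior.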