Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: stats.variation: add array-API support #20647

Merged
merged 4 commits into from May 6, 2024
Merged

Conversation

j-bowhay
Copy link
Member

@j-bowhay j-bowhay commented May 5, 2024

towards #20544

@github-actions github-actions bot added scipy.stats CI Items related to the CI tools such as CircleCI, GitHub Actions or Azure enhancement A new feature or improvement labels May 5, 2024
scipy/stats/_variation.py Outdated Show resolved Hide resolved
scipy/stats/_variation.py Outdated Show resolved Hide resolved
@j-bowhay j-bowhay marked this pull request as ready for review May 5, 2024 16:58
@j-bowhay j-bowhay mentioned this pull request May 5, 2024
68 tasks
# torch
sup.filter(UserWarning, "std*")
y = variation(x)
xp_assert_equal(y, xp.asarray(xp.nan, dtype=x.dtype))
Copy link
Contributor

@mdhaber mdhaber May 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal, but you could change np.zeros(5) to [0.]*5 to avoid the explicit dtype, if we prefer testing the default type. Then again, one could argue that now this tests both the default type and float64 (if they are different), and coverage is a good thing. Either way; just an observation.

Copy link
Contributor

@mdhaber mdhaber May 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant that since you changed np.zeros to [0.]*, you should be able to do:

Suggested change
xp_assert_equal(y, xp.asarray(xp.nan, dtype=x.dtype))
xp_assert_equal(y, xp.asarray(xp.nan))

Copy link
Contributor

@mdhaber mdhaber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems straightforward, and I'm not seeing anything wrong. I'll take another pass after the merge conflict is resolved and we hear back about what we should do with the JAX-incompatible code.

result[i] = np.copysign(result[i], mean_a[i])
return result[()]
std_a = xp.std(a, axis=axis, correction=0)
result = xp.where(std_a > 0, xp.copysign(xp.asarray(xp.inf), mean_a), NaN)
Copy link
Contributor

@mdhaber mdhaber May 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Looks like we need an xp_copysign in the meantime (j-bowhay#4). I'd recommend installing array_api_strict.

@mdhaber mdhaber merged commit 44c3b6f into scipy:main May 6, 2024
30 checks passed
@j-bowhay j-bowhay added this to the 1.14.0 milestone May 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CI Items related to the CI tools such as CircleCI, GitHub Actions or Azure enhancement A new feature or improvement scipy.stats
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants