New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: stats: add array-API support to kstat/kstatvar #20634
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, I'm realizing that I probably shouldn't have put these on the initial list - they really should get native support for an axis
argument before we add array API support. Would you add keyword argument axis
with default None
(for backward compatibility) to the signature and use it as appropriate?
Do you plan to add native |
I was going to go for some of the lower-hanging fruit first so feel free to go for it in the mean time |
Added in gh-20651. Sorry for the merge conflicts! |
The conflicts are actually not as bad as I expected. It probably is worth it to merge and fix them rather than starting over. Please also convert the new tests added in gh-20651. |
[skip ci]
Ok almost there just down to the final failures of this style:
hopefully will be obvious when it isn't so late! |
Line 296 and the next few don't look like they've been converted. |
Ha yes thanks, temporary post coursework submission blindness... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The statistic conversion looks pretty good, but the tests still need to become @array_api_compatible
. We can add tests with the intent of checking the behavior for NumPy lists, but we also need to test behavior with backends other than NumPy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.random.seed(32149) | ||
data = np.random.randn(12345) | ||
moments = [stats.kstat(data, n) for n in [1, 2, 3, 4]] | ||
data = xp.asarray(np.random.randn(12345), dtype=xp.float64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even torch
will produce a float64
Tensor if it is generated from a NumPy float64
array, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I can't quite remember why I did this, will have a look in the follow up
data[6] = np.nan | ||
def test_nan_input(self, xp): | ||
data = xp.arange(10.) | ||
data[6] = xp.nan |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine for now, but nice to save @lucascolley some trouble by generating these without mutation as we go forward. Is there a better way than?
data[6] = xp.nan | |
data = xp.where(data==6, xp.nan, data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Obvously here this is fine but I do wonder if we encountered mutations in the in loop of a solver, for example, if all the new allocations might be a bit painful performance wise
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I wouldn't change existing code like that all the time. Just easy cases like these tests, where it could make the difference between the test running with JAX or not. Cases do need to be considered individually.
towards #20544