Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: stats.gstd: add array API support #20285

Closed
wants to merge 2 commits into from
Closed

Conversation

mdhaber
Copy link
Contributor

@mdhaber mdhaber commented Mar 19, 2024

Reference issue

Towards gh-18867

What does this implement/fix?

Adds array API support to gstd.

Additional information

Can we deprecate masked array support for this function to simplify things?

I think this would pass tests, but:

import torch
from scipy._lib._array_api import array_namespace
x = torch.asarray([1, 2, 3., 3, 4.5])
xp = array_namespace(x)
xp.std(x, correction=1)

results in:

Traceback (most recent call last):
  File "/Users/matthaberland/miniforge3/envs/scipy-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-17-a926c1507628>", line 1, in <module>
    xp.std(x, correction=1)
  File "/Users/matthaberland/Desktop/scipy/scipy/_lib/array_api_compat/array_api_compat/torch/_aliases.py", line 396, in std
    res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
                                                        ^^^^^^^^^^^
UnboundLocalError: cannot access local variable '_correction' where it is not associated with a value

Maybe it's a real bug, or maybe I have an old version of something?

@mdhaber mdhaber added scipy.stats enhancement A new feature or improvement array types Items related to array API support and input array validation (see gh-18286) labels Mar 19, 2024
Comment on lines -3233 to -3234
if ((a_nan_any and np.less_equal(np.nanmin(a), 0)) or
(not a_nan_any and np.less_equal(a, 0).any())):
Copy link
Contributor Author

@mdhaber mdhaber Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the reason for tiptoeing around NaNs here. AFAICT, the behavior is the same even if NaNs are present, and no warnings are emitted. Am I missing something?

scipy/stats/_stats_py.py Outdated Show resolved Hide resolved
@@ -3240,7 +3244,7 @@ def gstd(a, axis=0, ddof=1):
raise ValueError(w) from w
else:
# Remaining warnings don't need to be exceptions.
return np.exp(np.std(log(a, where=~a_nan), axis=axis, ddof=ddof))
return xp.exp(xp.std(log(a), axis=axis, **ddof))
Copy link
Contributor Author

@mdhaber mdhaber Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here. I don't know of circumstances under which taking the log of nan emits a warning or returns anything other than nan. Am I missing something?

@@ -3213,25 +3214,28 @@ def gstd(a, axis=0, ddof=1):
fill_value=999999)

"""
a = np.asanyarray(a)
log = ma.log if isinstance(a, ma.MaskedArray) else np.log
if isinstance(a, ma.MaskedArray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could just error out on MaskedArrays when the environment variable is set and skip any corresponding tests when it is set.

Copy link
Contributor Author

@mdhaber mdhaber Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I see. Basically just:

        xp = array_namespace(a)  # errors on masked arrays, returns array_api_compat numpy otherwise 
        log = xp.log  # see if this works as-is, if not, use ternary to get `ma.log` 
        ddof = {'correction': ddof}  # this should work, I think
        a = xp.asarray(a)  # this would probably have to branch so as not to change masked arrays?

What do you think about deprecating the current MaskedArray behavior then?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ralf has

Once we're more confident that everything is in good shape, see if we should start emitting warnings for input types that may see a change in behavior

under "desirable for later" on the linked tracker. We definitely want to deprecate at some point! But if making this experimental mode the default is going to be a SciPy 2.0 thing (as I think it might be shaping up to be, we still have a long way to go), then maybe there is no rush.

I suppose, no harm in deprecating here in isolation though!

@lucascolley
Copy link
Member

lucascolley commented Mar 19, 2024

I get a different error


Edit: I've just reproduced your error on array-api-compat main without SciPy, will report upstream

@mdhaber
Copy link
Contributor Author

mdhaber commented Mar 19, 2024

Thanks @lucascolley.

Do you know if there's a canonical way to get the finfo/iinfo attributes of non-NumPy dtypes? I'm thinking for array x e.g.

np.finfo(getattr(np, str(x.dtype)))

Sometimes it's needed to get a default numerical tolerance.

Also, for operations that previously have converted to float64 when the input was of integer dtype, do we now just let the operation produce nonsense results, error out, or do the conversion? (And if we're supposed to error out or do the conversion, how do we determine the dtype kind? 'int' in str(x.dtype)?)

@lucascolley
Copy link
Member

Do you know if there's a canonical way to get the finfo/iinfo attributes of non-NumPy dtypes?

They are in the spec, https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html.

Also, for operations that previously have converted to float64 when the input was of integer dtype, do we now just let the operation produce nonsense results, error out, or do the conversion?

I think we just preserve backwards compatibility here. This effort is about preserving array types, not dtypes. (Separately deprecating integer inputs is orthogonal, I remember that happening for some other function recently.)

xp.isdtype(x.dtype, 'integral')

See https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html

@mdhaber
Copy link
Contributor Author

mdhaber commented Mar 20, 2024

Thanks @lucascolley I thought I knew that those weren't present, so I didn't look!

@mdhaber
Copy link
Contributor Author

mdhaber commented Mar 20, 2024

@lucascolley looks like the upstream issue was fixed. Can we just point the submodule to the latest commit?

@lucascolley
Copy link
Member

lucascolley commented Mar 20, 2024

UnboundLocalError: cannot access local variable '_correction'

If you pull in a new commit of array_api_compat@main this should be fixed (maybe want to wait for the next release, which is planned for about ~1 week away, by the sounds of it)

@mdhaber
Copy link
Contributor Author

mdhaber commented Mar 20, 2024

I'll go ahead and do that to get CI working, then we can pull in the released version before merging.

@mdhaber
Copy link
Contributor Author

mdhaber commented Mar 22, 2024

Looks like pulling that in alone doesn't work, probably because of #19900 (comment). Maybe it will be easier when NumPy 2.0 is released.

@lucascolley
Copy link
Member

lucascolley commented Mar 22, 2024

Hmm, I'm pretty sure it should work. No harm in waiting though 🤷‍♂️

probably because of #19900 (comment)

This shouldn't be an issue anymore now that we pull array-api-strict from pip.

@mdhaber
Copy link
Contributor Author

mdhaber commented Mar 22, 2024

Hmm, I'm pretty sure it should work.

I get an error when trying to run tests.

/Users/matthaberland/miniforge3/envs/scipy-dev/bin/python /Applications/PyCharm.app/Contents/plugins/python/helpers/pycharm/_jb_pytest_runner.py --target tests/test_stats.py::TestGeometricStandardDeviation 
Testing started at 8:34 AM ...
Launching pytest with arguments tests/test_stats.py::TestGeometricStandardDeviation --no-header --no-summary -q in /Users/matthaberland/Desktop/scipy/scipy/stats

ImportError while loading conftest '/Users/matthaberland/Desktop/scipy/scipy/conftest.py'.
../conftest.py:15: in <module>
    from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE
../_lib/_array_api.py:17: in <module>
    from scipy._lib.array_api_compat import (
../_lib/array_api_compat/array_api_compat/numpy/__init__.py:18: in <module>
    __import__(__package__ + '.fft')
E   ModuleNotFoundError: No module named 'scipy._lib.array_api_compat.numpy.fft'

I manually bisected this and found that 02059d0d8f025724c08ff757baa6b48a579651ea doesn't give this error and 0b6ddcd26cac700e6ff9d5fb42d2476f6cc8aac6 does (commit history).

@lucascolley
Copy link
Member

aha, this is meson.build stuff - see my recent changes in gh-20085

@mdhaber
Copy link
Contributor Author

mdhaber commented Mar 22, 2024

I might put this on hold or close and reopen - I think the underlying function and tests need work before this should be merged.

There was a lot of care taken to ensure that the function raised errors with certain input. For example,

from scipy import stats
stats.gstd([1, 2, 3, np.inf])
# ValueError: Infinite value encountered. The geometric standard deviation is defined for strictly positive values only.

According to the error message itself, the geometric standard deviation is defined for strictly positive values. That may need to be adjusted, because inf is strictly positive. More importantly, I think it should just return nan.

This is particularly important when the input is an array: we would not want an infinite value in one slice to cause an error for all slices.

I think we would prefer for the behavior to be consistent with that of gmean:

stats.gmean([1, 2, 3, np.inf])  # inf
stats.gmean([-1, 1, 2, 3])  # nan

which itself follows NumPy's lead in returning infinities or NaNs as appropriate.

np.std([1, 2, 3, np.inf])  # nan

Furthermore, the approach was to try/except RuntimeWarnings (for speed), but the warnings emitted are backend dependent. torch, for example, doesn't emit the warnings the function is looking for.

I think simplifying the implementation would allow us to simplify the tests considerably. Testing a few cases against a reference implementation plus one comprehensive property-based test would do the trick. That would really simplify this PR, too.

I'll work with the rest of the stats folks to decide what we should do here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy.stats
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants