ENH: _array_api.Generator: unified RNG interface #20549

Draft: wants to merge 1 commit into main
59 changes: 59 additions & 0 deletions scipy/_lib/_array_api.py
@@ -10,6 +10,7 @@

import os
import warnings
import random

import numpy as np

@@ -354,3 +355,61 @@ def xp_unsupported_param_msg(param):

def is_complex(x, xp):
    return xp.isdtype(x.dtype, 'complex floating')


# Random Number Generation

class Generator:
    def __new__(cls, xp, seed=None):
        if is_numpy(xp) or is_cupy(xp):
            # could maybe just return xp.random.default_rng(seed)
            return super().__new__(_Generator_numpy_cupy)
        elif is_torch(xp):
            return super().__new__(_Generator_torch)
        # elif is_jax(xp):
        #     return super().__new__(_Generator_jax)
Comment on lines +369 to +370
Contributor Author

uncomment after gh-20085 merges.

        else:
            message = f"`Generator` for {xp} is not implemented."
            raise NotImplementedError(message)
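
Aside: the __new__-based dispatch above means Generator(xp, seed) returns an instance of a backend-specific subclass, and Python then runs that subclass's __init__ with the same arguments. A minimal standalone sketch of the pattern, with hypothetical names that are not part of this PR:

class Base:
    def __new__(cls, kind):
        # pick the subclass from the argument; because the result is an
        # instance of Base, Python calls subclass.__init__ with the same args
        subclass = {"a": _ImplA, "b": _ImplB}[kind]
        return super().__new__(subclass)

class _ImplA(Base):
    def __init__(self, kind):
        self.kind = kind

class _ImplB(Base):
    def __init__(self, kind):
        self.kind = kind

assert type(Base("a")) is _ImplA  # dispatch happened inside Base.__new__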


class _Generator_jax(Generator):
    def __init__(self, xp, seed=None):
        # draw a random 64-bit seed if none was provided
        low = -2 ** 63
        seed = random.randint(low, -low - 1) if seed is None else seed
        self._xp = xp
        self._key = xp.random.key(seed)

    def _next(self):
        key, subkey = self._xp.random.split(self._key)
        self._key = subkey
        return key

    def random(self, shape=None, dtype=None):
        key = self._next()
        shape = () if shape is None else shape
        return self._xp.random.uniform(key, shape, dtype=dtype)
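
For context on the class above: JAX random generation is functional, so the generator holds a key and splits it on every draw instead of mutating hidden state; one half becomes the new stored state and the other is spent on the draw (which half plays which role is only a convention). A minimal sketch of those semantics, runnable with plain jax:

import jax

key = jax.random.key(0)  # new-style typed key (jax.random.PRNGKey on older versions)
key, subkey = jax.random.split(key)     # advance the state, get a key to spend
x = jax.random.uniform(subkey, (2, 3))
# the same subkey always reproduces the same draw
assert (x == jax.random.uniform(subkey, (2, 3))).all()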


class _Generator_torch(Generator):
    def __init__(self, xp, seed=None):
        rng = xp.Generator()
        # rng.seed() seeds from entropy and returns the seed it used,
        # so the stream stays reproducible either way
        seed = rng.seed() if seed is None else seed
        rng.manual_seed(seed)
        self._rng = rng
        self._xp = xp

    def random(self, shape=None, dtype=None):
        shape = () if shape is None else shape
        return self._xp.rand(shape, generator=self._rng, dtype=dtype)
Contributor

I'm seeing a few new test failures in the array API suite with these changes and torch on the GPU, when running SCIPY_DEVICE=cuda python dev.py test -j 32 -b all:

FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-None-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-None-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-shape1-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-shape1-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-None-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-None-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-shape1-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-shape1-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

A small patch seems to return us to baseline:

--- a/scipy/_lib/_array_api.py
+++ b/scipy/_lib/_array_api.py
@@ -393,7 +393,7 @@ class _Generator_jax(Generator):
 
 class _Generator_torch(Generator):
     def __init__(self, xp, seed=None):
-        rng = xp.Generator()
+        rng = xp.Generator(device=SCIPY_DEVICE)
         seed = rng.seed() if seed is None else seed
         rng.manual_seed(seed)
         self._rng = rng

Contributor Author

Yup, that's why I suggested that we might want to add a device argument. This would be relevant for JAX, too.
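
One possible shape for that suggestion, sketched below; the device parameter, its default, and threading it through Generator.__new__ are assumptions rather than part of this PR:

class _Generator_torch(Generator):
    def __init__(self, xp, seed=None, device=None):
        # assumed signature change: Generator.__new__ would need the same
        # extra parameter so the argument reaches __init__ here
        rng = xp.Generator() if device is None else xp.Generator(device=device)
        seed = rng.seed() if seed is None else seed
        rng.manual_seed(seed)
        self._rng = rng
        self._xp = xp

    def random(self, shape=None, dtype=None):
        shape = () if shape is None else shape
        # allocate the output on the generator's device to avoid the
        # device-mismatch RuntimeError quoted above
        return self._xp.rand(shape, generator=self._rng, dtype=dtype,
                             device=self._rng.device)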



class _Generator_numpy_cupy(Generator):
    def __init__(self, xp, seed=None):
        self._xp = xp
        self._rng = xp.random.default_rng(seed)

    def random(self, shape=None, dtype=None):
        # (for NumPy) ensure the output is not a Python float
        temp = self._rng.random(size=shape, dtype=dtype)
        return self._xp.asarray(temp, dtype=dtype)[()]
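
Taken together, intended usage looks like the following sketch, written against the compat NumPy namespace that the test file imports (assuming is_numpy recognizes it, as the tests rely on); other backends dispatch the same way:

import scipy._lib.array_api_compat.numpy as xp
from scipy._lib._array_api import Generator

rng = Generator(xp, seed=0)                # __new__ picks _Generator_numpy_cupy
x1 = rng.random((2, 3), dtype=xp.float64)

rng = Generator(xp, seed=0)                # same seed, same stream
x2 = rng.random((2, 3), dtype=xp.float64)
assert xp.all(x1 == x2)

# a scalar draw has shape () and a dtype rather than being a bare Python float
assert Generator(xp, seed=0).random(dtype=xp.float64).shape == ()
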
26 changes: 25 additions & 1 deletion scipy/_lib/tests/test_array_api.py
@@ -3,7 +3,8 @@

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import (
    _GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy,
    Generator
)
import scipy._lib.array_api_compat.numpy as np_compat

@@ -107,3 +108,26 @@ def test_check_scalar(self, xp):
        with pytest.raises(AssertionError, match="Types do not match."):
            xp_assert_equal(xp.asarray(0.), xp.float64(0))
        xp_assert_equal(xp.float64(0), xp.asarray(0.))

    @pytest.mark.usefixtures("skip_xp_backends")
    @pytest.mark.skip_xp_backends('array_api_strict',
                                  reasons=["`array_api_strict` doesn't have RNG"])
    @array_api_compatible
    @pytest.mark.parametrize('seed', (None, 0))
    @pytest.mark.parametrize('shape', (None, (2, 3)))
    @pytest.mark.parametrize('dtype', ('float32', 'float64'))
    def test_Generator(self, xp, seed, shape, dtype):
        dtype = getattr(xp, dtype)

        rng = Generator(xp, seed=seed)
        x1 = rng.random(shape, dtype=dtype)
        assert x1.dtype == dtype
        # parenthesized: `==` binds tighter than the conditional expression
        assert x1.shape == (() if shape is None else shape)

        rng = Generator(xp, seed=seed)
        x2 = rng.random(shape, dtype=dtype)

        if seed is None:
            assert xp.all(x1 != x2)
        else:
            xp_assert_equal(x1, x2)
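
If the device argument discussed above lands, the test matrix could grow a device dimension along these lines (a hypothetical sketch: the device parameter and reusing SCIPY_DEVICE from scipy._lib._array_api are assumptions):

    @pytest.mark.parametrize('device', (None, SCIPY_DEVICE))
    def test_Generator_device(self, xp, device):
        # seeded draws should stay reproducible regardless of device
        x1 = Generator(xp, seed=0, device=device).random((2, 3), dtype=xp.float64)
        x2 = Generator(xp, seed=0, device=device).random((2, 3), dtype=xp.float64)
        xp_assert_equal(x1, x2)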