Use Ray _BackwardsCompatibleNumpyRng if possible (#421)

This commit is contained in:
Antoni Baum 2022-01-23 07:14:49 +01:00 committed by GitHub
parent 4e8b6b98b0
commit 113539545c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -35,46 +35,48 @@ except AttributeError:
logger = logging.getLogger(__name__)
try:
from ray.tune.sample import _BackwardsCompatibleNumpyRng
except ImportError:
class _BackwardsCompatibleNumpyRng:
"""Thin wrapper to ensure backwards compatibility between
new and old numpy randomness generators.
"""
class _BackwardsCompatibleNumpyRng:
"""Thin wrapper to ensure backwards compatibility between
new and old numpy randomness generators.
"""
_rng = None
_rng = None
def __init__(
self,
generator_or_seed: Optional[
Union["np_random_generator", np.random.RandomState, int]
] = None,
):
if generator_or_seed is None or isinstance(
generator_or_seed, (np.random.RandomState, np_random_generator)
def __init__(
self,
generator_or_seed: Optional[
Union["np_random_generator", np.random.RandomState, int]
] = None,
):
self._rng = generator_or_seed
elif LEGACY_RNG:
self._rng = np.random.RandomState(generator_or_seed)
else:
self._rng = np.random.default_rng(generator_or_seed)
if generator_or_seed is None or isinstance(
generator_or_seed, (np.random.RandomState, np_random_generator)
):
self._rng = generator_or_seed
elif LEGACY_RNG:
self._rng = np.random.RandomState(generator_or_seed)
else:
self._rng = np.random.default_rng(generator_or_seed)
@property
def legacy_rng(self) -> bool:
return not isinstance(self._rng, np_random_generator)
@property
def legacy_rng(self) -> bool:
return not isinstance(self._rng, np_random_generator)
@property
def rng(self):
# don't set self._rng to np.random to avoid picking issues
return self._rng if self._rng is not None else np.random
@property
def rng(self):
# don't set self._rng to np.random to avoid picking issues
return self._rng if self._rng is not None else np.random
def __getattr__(self, name: str) -> Any:
# https://numpy.org/doc/stable/reference/random/new-or-different.html
if self.legacy_rng:
if name == "integers":
name = "randint"
elif name == "random":
name = "rand"
return getattr(self.rng, name)
def __getattr__(self, name: str) -> Any:
# https://numpy.org/doc/stable/reference/random/new-or-different.html
if self.legacy_rng:
if name == "integers":
name = "randint"
elif name == "random":
name = "rand"
return getattr(self.rng, name)
RandomState = Union[