mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-30 00:30:23 +00:00
Use Ray _BackwardsCompatibleNumpyRng if possible (#421)
This commit is contained in:
parent
4e8b6b98b0
commit
113539545c
@ -35,46 +35,48 @@ except AttributeError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from ray.tune.sample import _BackwardsCompatibleNumpyRng
|
||||
except ImportError:
|
||||
class _BackwardsCompatibleNumpyRng:
|
||||
"""Thin wrapper to ensure backwards compatibility between
|
||||
new and old numpy randomness generators.
|
||||
"""
|
||||
|
||||
class _BackwardsCompatibleNumpyRng:
|
||||
"""Thin wrapper to ensure backwards compatibility between
|
||||
new and old numpy randomness generators.
|
||||
"""
|
||||
_rng = None
|
||||
|
||||
_rng = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
generator_or_seed: Optional[
|
||||
Union["np_random_generator", np.random.RandomState, int]
|
||||
] = None,
|
||||
):
|
||||
if generator_or_seed is None or isinstance(
|
||||
generator_or_seed, (np.random.RandomState, np_random_generator)
|
||||
def __init__(
|
||||
self,
|
||||
generator_or_seed: Optional[
|
||||
Union["np_random_generator", np.random.RandomState, int]
|
||||
] = None,
|
||||
):
|
||||
self._rng = generator_or_seed
|
||||
elif LEGACY_RNG:
|
||||
self._rng = np.random.RandomState(generator_or_seed)
|
||||
else:
|
||||
self._rng = np.random.default_rng(generator_or_seed)
|
||||
if generator_or_seed is None or isinstance(
|
||||
generator_or_seed, (np.random.RandomState, np_random_generator)
|
||||
):
|
||||
self._rng = generator_or_seed
|
||||
elif LEGACY_RNG:
|
||||
self._rng = np.random.RandomState(generator_or_seed)
|
||||
else:
|
||||
self._rng = np.random.default_rng(generator_or_seed)
|
||||
|
||||
@property
|
||||
def legacy_rng(self) -> bool:
|
||||
return not isinstance(self._rng, np_random_generator)
|
||||
@property
|
||||
def legacy_rng(self) -> bool:
|
||||
return not isinstance(self._rng, np_random_generator)
|
||||
|
||||
@property
|
||||
def rng(self):
|
||||
# don't set self._rng to np.random to avoid picking issues
|
||||
return self._rng if self._rng is not None else np.random
|
||||
@property
|
||||
def rng(self):
|
||||
# don't set self._rng to np.random to avoid picking issues
|
||||
return self._rng if self._rng is not None else np.random
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
# https://numpy.org/doc/stable/reference/random/new-or-different.html
|
||||
if self.legacy_rng:
|
||||
if name == "integers":
|
||||
name = "randint"
|
||||
elif name == "random":
|
||||
name = "rand"
|
||||
return getattr(self.rng, name)
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
# https://numpy.org/doc/stable/reference/random/new-or-different.html
|
||||
if self.legacy_rng:
|
||||
if name == "integers":
|
||||
name = "randint"
|
||||
elif name == "random":
|
||||
name = "rand"
|
||||
return getattr(self.rng, name)
|
||||
|
||||
|
||||
RandomState = Union[
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user