autogen/flaml/tune/space.py
Chi Wang 0b25e89f29
reproducibility for random sampling (#349)
* reproducibility for random sampling #236

* doc update
2021-12-22 12:12:25 -08:00

543 lines
22 KiB
Python

try:
from ray import __version__ as ray_version
assert ray_version >= "1.0.0"
from ray.tune import sample
from ray.tune.suggest.variant_generator import generate_variants
except (ImportError, AssertionError):
from . import sample
from ..searcher.variant_generator import generate_variants
from typing import Dict, Optional, Any, Tuple, Generator
import numpy as np
import logging
logger = logging.getLogger(__name__)
def generate_variants_compatible(
    unresolved_spec: Dict, constant_grid_search: bool = False, random_state=None
) -> Generator[Tuple[Dict, Dict], None, None]:
    """Call ``generate_variants`` with or without ``random_state``.

    The three-argument form is tried first. If that call raises
    ``TypeError`` the call is repeated with only ``unresolved_spec`` and
    ``constant_grid_search``, and ``random_state`` is dropped.

    Returns:
        Whatever ``generate_variants`` returns for the accepted call.
    """
    common_args = (unresolved_spec, constant_grid_search)
    try:
        variants = generate_variants(*common_args, random_state)
    except TypeError:
        variants = generate_variants(*common_args)
    return variants
def define_by_run_func(trial, space: Dict, path: str = "") -> Optional[Dict[str, Any]]:
    """Define-by-run function to create the search space.

    Walks ``space`` recursively. Every ``sample.Domain`` found is registered
    on ``trial`` through ``suggest_float`` / ``suggest_int`` /
    ``suggest_categorical``; every other non-dict value is treated as a
    constant and collected in the returned dict.

    Args:
        trial: Object exposing ``suggest_float``, ``suggest_int`` and
            ``suggest_categorical`` (presumably an Optuna trial, going by
            the log messages -- confirm against the caller).
        space: Possibly nested dict mapping names to domains, sub-spaces or
            constants.
        path: Key prefix of the enclosing level; nested keys are joined
            with ``"/"``.

    Returns:
        A dict with constant values, keyed by the flattened key.

    Raises:
        ValueError: If a domain is a ``sample.Domain`` but is neither
            ``Float``, ``Integer`` nor ``Categorical``.
    """
    config = {}
    for key, domain in space.items():
        if path:
            key = path + "/" + key
        if isinstance(domain, dict):
            # Nested sub-space: recurse with the extended path.
            config.update(define_by_run_func(trial, domain, key))
            continue
        if not isinstance(domain, sample.Domain):
            # Plain constant.
            config[key] = domain
            continue
        sampler = domain.get_sampler()
        quantize = None
        if isinstance(sampler, sample.Quantized):
            # Unwrap the quantizer and remember its step.
            quantize = sampler.q
            sampler = sampler.sampler
            if isinstance(sampler, sample.LogUniform):
                logger.warning(
                    "Optuna does not handle quantization in loguniform "
                    "sampling. The parameter will be passed but it will "
                    "probably be ignored."
                )
        if isinstance(domain, sample.Float):
            if isinstance(sampler, sample.LogUniform):
                if quantize:
                    logger.warning(
                        "Optuna does not support both quantization and "
                        "sampling from LogUniform. Dropped quantization."
                    )
                trial.suggest_float(key, domain.lower, domain.upper, log=True)
            elif isinstance(sampler, sample.Uniform):
                # Suggest exactly once per key: with a step when quantized,
                # without one otherwise. Suggesting the same name a second
                # time without the step would register a conflicting
                # distribution for the parameter.
                if quantize:
                    trial.suggest_float(key, domain.lower, domain.upper, step=quantize)
                else:
                    trial.suggest_float(key, domain.lower, domain.upper)
        elif isinstance(domain, sample.Integer):
            if isinstance(sampler, sample.LogUniform):
                trial.suggest_int(
                    key, domain.lower, domain.upper - int(bool(not quantize)), log=True
                )
            elif isinstance(sampler, sample.Uniform):
                # Upper bound should be inclusive for quantization and
                # exclusive otherwise
                trial.suggest_int(
                    key,
                    domain.lower,
                    domain.upper - int(bool(not quantize)),
                    step=quantize or 1,
                )
        elif isinstance(domain, sample.Categorical):
            if isinstance(sampler, sample.Uniform):
                if not hasattr(domain, "choices"):
                    # Cache the list of category indices on the domain.
                    domain.choices = list(range(len(domain.categories)))
                choices = domain.choices
                # This choice needs to be removed from the final config
                index = trial.suggest_categorical(key + "_choice_", choices)
                choice = domain.categories[index]
                if isinstance(choice, dict):
                    key += f":{index}"
                    # the suffix needs to be removed from the final config
                    config.update(define_by_run_func(trial, choice, key))
        else:
            raise ValueError(
                "Optuna search does not support parameters of type "
                "`{}` with samplers of type `{}`".format(
                    type(domain).__name__, type(domain.sampler).__name__
                )
            )
    # Return all constants in a dictionary.
    return config
# def convert_key(
# conf: Dict, space: Dict, path: str = ""
# ) -> Optional[Dict[str, Any]]:
# """Convert config keys to define-by-run keys.
# Returns:
# A dict with converted keys.
# """
# config = {}
# for key, domain in space.items():
# value = conf[key]
# if path:
# key = path + '/' + key
# if isinstance(domain, dict):
# config.update(convert_key(conf[key], domain, key))
# elif isinstance(domain, sample.Categorical):
# index = indexof(domain, value)
# config[key + '_choice_'] = index
# if isinstance(value, dict):
# key += f":{index}"
# config.update(convert_key(value, domain.categories[index], key))
# else:
# config[key] = value
# return config
def unflatten_hierarchical(config: Dict, space: Dict) -> Tuple[Dict, Dict]:
    """Rebuild a hierarchical config from flattened define-by-run keys.

    Keys are reduced to their last ``"/"``-separated component. A key of the
    form ``name:index`` holds the nested config of category ``index`` of
    ``space[name]`` and is unflattened recursively. A trailing ``_choice_``
    is stripped, and for categorical domains the stored index is replaced
    by the category it selects.

    Returns:
        A tuple ``(hier, subspace)``: the unflattened config and the part of
        ``space`` that was looked up for its keys.
    """
    choice_suffix = "_choice_"
    hier = {}
    subspace = {}
    for flat_key, value in config.items():
        name = flat_key
        if "/" in name:
            # Keep only the component after the last path separator.
            name = name[name.rfind("/") + 1 :]
        if ":" in name:
            # "<name>:<index>" -> nested config of the chosen category.
            split_at = name.rfind(":")
            parent = name[:split_at]
            picked = int(name[split_at + 1 :])
            nested = unflatten_hierarchical(value, space[parent][picked])
            hier[parent], subspace[parent] = nested
            continue
        if name.endswith(choice_suffix):
            name = name[: -len(choice_suffix)]
        domain = space.get(name)
        if domain is not None:
            subspace[name] = domain
            if isinstance(domain, sample.Domain):
                sampler = domain.sampler
                if isinstance(domain, sample.Categorical):
                    # The stored value is an index into the categories.
                    value = domain.categories[value]
                    if isinstance(value, dict):
                        # Nested categories are filled in by their own
                        # "<name>:<index>" entry.
                        continue
                elif isinstance(sampler, sample.Quantized):
                    step = sampler.q
                    sampler = sampler.sampler
                    if isinstance(sampler, sample.LogUniform):
                        # Re-apply the quantization for log-uniform domains.
                        value = domain.cast(np.round(value / step) * step)
        hier[name] = value
    return hier, subspace
def add_cost_to_space(space: Dict, low_cost_point: Dict, choice_cost: Dict):
    """Update the space in place by adding low_cost_point and choice_cost.

    For every domain with a callable ``get_sampler`` the attribute
    ``bounded`` is set. Categorical domains additionally get ``const``
    (constants of each nested category), ``ordered`` and, when costs are
    given, ``choice_cost``; their categories are reordered by cost, or by
    value when all categories are numeric. ``low_cost_point`` is attached to
    the domain when one is supplied.

    Args:
        space: Possibly nested search space; its domains are mutated.
        low_cost_point: Low-cost value per key. For a categorical this may
            be a list with one entry per category, optionally followed by
            the index of the low-cost category. The list is reordered in
            place to follow the sorted categories.
        choice_cost: Cost information per categorical key.

    Returns:
        A dict with constant values.
    """
    config = {}
    for key in space:
        domain = space[key]
        if not isinstance(domain, sample.Domain):
            if isinstance(domain, dict):
                # Nested sub-space: recurse and keep its constants, if any.
                low_cost = low_cost_point.get(key, {})
                choice_cost_list = choice_cost.get(key, {})
                const = add_cost_to_space(domain, low_cost, choice_cost_list)
                if const:
                    config[key] = const
            else:
                config[key] = domain
            continue
        low_cost = low_cost_point.get(key)
        choice_cost_list = choice_cost.get(key)
        if callable(getattr(domain, "get_sampler", None)):
            sampler = domain.get_sampler()
            if isinstance(sampler, sample.Quantized):
                sampler = sampler.get_sampler()
            domain.bounded = str(sampler) != "Normal"
            if isinstance(domain, sample.Categorical):
                domain.const = []
                for i, cat in enumerate(domain.categories):
                    if isinstance(cat, dict):
                        if isinstance(low_cost, list):
                            low_cost_dict = low_cost[i]
                        else:
                            low_cost_dict = {}
                        if choice_cost_list:
                            choice_cost_dict = choice_cost_list[i]
                        else:
                            choice_cost_dict = {}
                        domain.const.append(
                            add_cost_to_space(cat, low_cost_dict, choice_cost_dict)
                        )
                    else:
                        domain.const.append(None)
                if choice_cost_list:
                    if len(choice_cost_list) == len(domain.categories):
                        domain.choice_cost = choice_cost_list
                    else:
                        domain.choice_cost = choice_cost_list[-1]
                    # sort the choices by cost
                    cost = np.array(domain.choice_cost)
                    ind = np.argsort(cost)
                    domain.categories = [domain.categories[i] for i in ind]
                    domain.choice_cost = cost[ind]
                    domain.const = [domain.const[i] for i in ind]
                    domain.ordered = True
                elif all(isinstance(x, (int, float)) for x in domain.categories):
                    # sort the choices by value
                    ind = np.argsort(domain.categories)
                    domain.categories = [domain.categories[i] for i in ind]
                    domain.ordered = True
                else:
                    domain.ordered = False
                if low_cost and low_cost not in domain.categories:
                    assert isinstance(
                        low_cost, list
                    ), f"low cost {low_cost} not in domain {domain.categories}"
                    if domain.ordered:
                        # Reorder the per-category points to follow the
                        # sorted categories.
                        sorted_points = [low_cost[i] for i in ind]
                        for i, point in enumerate(sorted_points):
                            low_cost[i] = point
                    if len(low_cost) > len(domain.categories):
                        # The extra last entry is the low-cost category
                        # index; remap it to the sorted order.
                        if domain.ordered:
                            low_cost[-1] = int(np.where(ind == low_cost[-1])[0])
                        domain.low_cost_point = low_cost[-1]
                    # This key is done. Move on to the next key instead of
                    # leaving the function: returning here would skip the
                    # remaining keys and drop the constants collected so far.
                    continue
            if low_cost:
                domain.low_cost_point = low_cost
    return config
def normalize(
    config: Dict,
    space: Dict,
    reference_config: Dict,
    normalized_reference_config: Dict,
    recursive: bool = False,
):
    """Normalize config in space according to reference_config.

    Normalize each dimension in config to [0,1].

    Values whose key is missing from ``space``, or whose domain has no
    callable ``get_sampler``, are copied unchanged (nested dict domains are
    normalized recursively when ``recursive`` is true). Categorical values
    are mapped to a position derived from their index, or from the
    reference config when the domain is unordered. Uniform and LogUniform
    values are rescaled using the domain bounds; Normal values are
    standardized as ``(value - mean) / sd``. A key whose sampler is none of
    these three is left out of the result.

    Args:
        config: Values to normalize.
        space: Domains keyed like ``config``.
        reference_config: Reference point, in original units.
        normalized_reference_config: The same reference point, normalized.
        recursive: Whether nested spaces are normalized as well.

    Returns:
        A dict with the normalized values.
    """
    config_norm = {}
    for key, value in config.items():
        domain = space.get(key)
        if domain is None:  # e.g., resource_attr
            config_norm[key] = value
            continue
        if not callable(getattr(domain, "get_sampler", None)):
            if recursive and isinstance(domain, dict):
                # NOTE(review): ``recursive`` is not forwarded here, so the
                # nested call runs with its default (False) -- confirm this
                # is intended.
                config_norm[key] = normalize(value, domain, reference_config[key], {})
            else:
                config_norm[key] = value
            continue
        # domain: sample.Categorical/Integer/Float/Function
        if isinstance(domain, sample.Categorical):
            norm = None
            # value is: a category, a nested dict, or a low_cost_point list
            if value not in domain.categories:
                # nested
                if isinstance(value, list):
                    # low_cost_point list
                    norm = []
                    for i, cat in enumerate(domain.categories):
                        norm.append(
                            normalize(value[i], cat, reference_config[key][i], {})
                            if recursive
                            else value[i]
                        )
                    if len(value) > len(domain.categories):
                        # the low cost index was appended to low_cost_point list
                        index = value[-1]
                        value = domain.categories[index]
                    elif not recursive:
                        # no low cost index. randomly pick one as init point
                        # (the key is left out of the normalized result)
                        continue
                else:
                    # nested dict
                    config_norm[key] = value
                    continue
            # normalize categorical
            n = len(domain.categories)
            if domain.ordered:
                # Centre of the bucket that belongs to this category.
                normalized = (domain.categories.index(value) + 0.5) / n
            elif key in normalized_reference_config:
                # Unordered: same value as the reference keeps the reference
                # position, any other value is shifted by one bucket.
                normalized = (
                    normalized_reference_config[key]
                    if value == reference_config[key]
                    else (normalized_reference_config[key] + 1 / n) % 1
                )
            else:
                normalized = 0.5
            if norm:
                # List case: the normalized index goes after the points.
                norm.append(normalized)
            else:
                norm = normalized
            config_norm[key] = norm
            continue
        # Uniform/LogUniform/Normal/Base
        sampler = domain.get_sampler()
        if isinstance(sampler, sample.Quantized):
            # sampler is sample.Quantized
            quantize = sampler.q
            sampler = sampler.get_sampler()
        else:
            quantize = None
        if str(sampler) == "LogUniform":
            # The upper bound is reduced by one for unquantized integers.
            upper = domain.upper - (
                isinstance(domain, sample.Integer) & (quantize is None)
            )
            config_norm[key] = np.log(value / domain.lower) / np.log(
                upper / domain.lower
            )
        elif str(sampler) == "Uniform":
            # The upper bound is reduced by one for unquantized integers.
            upper = domain.upper - (
                isinstance(domain, sample.Integer) & (quantize is None)
            )
            config_norm[key] = (value - domain.lower) / (upper - domain.lower)
        elif str(sampler) == "Normal":
            # N(mean, sd) -> N(0,1)
            config_norm[key] = (value - sampler.mean) / sampler.sd
        # else:
        # config_norm[key] = value
    return config_norm
def denormalize(
    config: Dict,
    space: Dict,
    reference_config: Dict,
    normalized_reference_config: Dict,
    random_state,
):
    """Map a normalized config back to the original units of ``space``.

    Keys missing from ``space``, dict values, and domains without a callable
    ``get_sampler`` are copied unchanged. Categorical values are mapped back
    to a category, numeric values are rescaled according to their sampler
    and then quantized and/or rounded to ``int`` where the domain asks for
    it.

    Args:
        config: Normalized values.
        space: Domains keyed like ``config``.
        reference_config: Reference point, in original units.
        normalized_reference_config: The same reference point, normalized.
            It must contain every unordered categorical key of ``config``.
        random_state: Object exposing ``choice``; used to pick a category
            for an unordered categorical that moved away from the reference.

    Returns:
        A dict with the denormalized values.
    """
    config_denorm = {}
    for key, value in config.items():
        if key in space:
            # domain: sample.Categorical/Integer/Float/Function
            domain = space[key]
            if isinstance(value, dict) or not callable(
                getattr(domain, "get_sampler", None)
            ):
                config_denorm[key] = value
            else:
                if isinstance(domain, sample.Categorical):
                    # denormalize categorical
                    n = len(domain.categories)
                    if isinstance(value, list):
                        # denormalize list: per-category points followed by
                        # the normalized index. Clamp to n - 1 so that a
                        # normalized index of exactly 1.0 selects the last
                        # point rather than the index slot itself.
                        choice = min(n - 1, int(np.floor(value[-1] * n)))
                        config_denorm[key] = point = value[choice]
                        point["_choice_"] = choice
                        continue
                    if domain.ordered:
                        config_denorm[key] = domain.categories[
                            min(n - 1, int(np.floor(value * n)))
                        ]
                    else:
                        assert key in normalized_reference_config
                        if np.floor(value * n) == np.floor(
                            normalized_reference_config[key] * n
                        ):
                            # Same bucket as the reference: keep its value.
                            config_denorm[key] = reference_config[key]
                        else:  # ****random value each time!****
                            config_denorm[key] = random_state.choice(
                                [
                                    x
                                    for x in domain.categories
                                    if x != reference_config[key]
                                ]
                            )
                    continue
                # Uniform/LogUniform/Normal/Base
                sampler = domain.get_sampler()
                if isinstance(sampler, sample.Quantized):
                    # sampler is sample.Quantized
                    quantize = sampler.q
                    sampler = sampler.get_sampler()
                else:
                    quantize = None
                # Handle Log/Uniform
                if str(sampler) == "LogUniform":
                    # The upper bound is reduced by one for unquantized
                    # integers.
                    upper = domain.upper - (
                        isinstance(domain, sample.Integer) & (quantize is None)
                    )
                    config_denorm[key] = (upper / domain.lower) ** value * domain.lower
                elif str(sampler) == "Uniform":
                    upper = domain.upper - (
                        isinstance(domain, sample.Integer) & (quantize is None)
                    )
                    config_denorm[key] = value * (upper - domain.lower) + domain.lower
                elif str(sampler) == "Normal":
                    # denormalization for 'Normal'
                    config_denorm[key] = value * sampler.sd + sampler.mean
                else:
                    config_denorm[key] = value
                # Handle quantized
                if quantize is not None:
                    config_denorm[key] = (
                        np.round(np.divide(config_denorm[key], quantize)) * quantize
                    )
                # Handle int (4.6 -> 5)
                if isinstance(domain, sample.Integer):
                    config_denorm[key] = int(round(config_denorm[key]))
        else:  # resource_attr
            config_denorm[key] = value
    return config_denorm
def indexof(domain: Dict, config: Dict) -> Optional[int]:
    """Find the index of config in domain.categories.

    The lookup proceeds in three steps:

    1. an explicit ``"_choice_"`` entry in ``config`` wins;
    2. otherwise ``config`` is searched for in ``domain.categories``;
    3. otherwise the first dict category is returned whose keys cover those
       of ``config`` and whose constants (``domain.const[i]``) all match.

    Args:
        domain: Object with ``categories`` and, for dict categories, a
            parallel ``const`` list of constant dicts.
        config: The (possibly partial) config to locate.

    Returns:
        The index of the matching category, or ``None`` if none matches.
    """
    index = config.get("_choice_")
    if index is not None:
        return index
    if config in domain.categories:
        return domain.categories.index(config)
    config_keys = set(config)
    for i, cat in enumerate(domain.categories):
        if not isinstance(cat, dict):
            continue
        if not config_keys.issubset(cat):
            continue
        # assumption: the concatenation of constants is a unique identifier.
        # Use .get so that a config lacking one of the constants simply does
        # not match, instead of raising KeyError.
        if all(config.get(key) == value for key, value in domain.const[i].items()):
            return i
    return None
def complete_config(
    partial_config: Dict,
    space: Dict,
    flow2,
    disturb: bool = False,
    lower: Optional[Dict] = None,
    upper: Optional[Dict] = None,
) -> Tuple[Dict, Dict]:
    """Complete partial config in space.

    The partial config is normalized, optionally perturbed with Gaussian
    noise, and denormalized. Keys of ``space`` that are still missing are
    filled with their domain and resolved through
    ``generate_variants_compatible``. Nested dict values are completed
    recursively.

    Args:
        partial_config: Config that may lack some keys of ``space``.
        space: The search space to complete against.
        flow2: Object providing ``STEPSIZE``, ``rand_vector_gaussian``,
            ``_random`` and ``rs_random``.
        disturb: Whether to perturb the normalized values.
        lower: Optional normalized lower bounds, keyed like ``space``.
        upper: Optional normalized upper bounds, keyed like ``space``.

    Returns:
        config, space.
    """
    config = partial_config.copy()
    normalized = normalize(config, space, partial_config, {})
    if disturb:
        for key, value in normalized.items():
            domain = space.get(key)
            if getattr(domain, "ordered", True) is False:
                # don't change unordered cat choice
                continue
            if not callable(getattr(domain, "get_sampler", None)):
                continue
            if upper and lower:
                up, low = upper[key], lower[key]
                if isinstance(up, list):
                    # NOTE(review): this widens the caller's lists in place.
                    gauss_std = (up[-1] - low[-1]) or flow2.STEPSIZE
                    up[-1] += flow2.STEPSIZE
                    low[-1] -= flow2.STEPSIZE
                else:
                    gauss_std = (up - low) or flow2.STEPSIZE
                    # allowed bound
                    up += flow2.STEPSIZE
                    low -= flow2.STEPSIZE
            elif domain.bounded:
                up, low, gauss_std = 1, 0, 1.0
            else:
                # np.inf rather than the np.Inf alias, which NumPy 2.0
                # removed.
                up, low, gauss_std = np.inf, -np.inf, 1.0
            if domain.bounded:
                # Bounded domains stay inside the normalized unit interval.
                if isinstance(up, list):
                    up[-1] = min(up[-1], 1)
                    low[-1] = max(low[-1], 0)
                else:
                    up = min(up, 1)
                    low = max(low, 0)
            delta = flow2.rand_vector_gaussian(1, gauss_std)[0]
            if isinstance(value, list):
                # points + normalized index
                value[-1] = max(low[-1], min(up[-1], value[-1] + delta))
            else:
                normalized[key] = max(low, min(up, value + delta))
    config = denormalize(normalized, space, config, normalized, flow2._random)
    # Fill in the keys the partial config did not provide.
    for key, value in space.items():
        if key not in config:
            config[key] = value
    # Only the first generated variant is used.
    for _, generated in generate_variants_compatible(
        {"config": config}, random_state=flow2.rs_random
    ):
        config = generated["config"]
        break
    subspace = {}
    for key, domain in space.items():
        value = config[key]
        if isinstance(value, dict):
            if isinstance(domain, sample.Categorical):
                # nested space
                index = indexof(domain, value)
                config[key], subspace[key] = complete_config(
                    value,
                    domain.categories[index],
                    flow2,
                    disturb,
                    lower and lower[key][index],
                    upper and upper[key][index],
                )
                assert (
                    "_choice_" not in subspace[key]
                ), "_choice_ is a reserved key for hierarchical search space"
                subspace[key]["_choice_"] = index
            else:
                config[key], subspace[key] = complete_config(
                    value,
                    space[key],
                    flow2,
                    disturb,
                    lower and lower[key],
                    upper and upper[key],
                )
            continue
        subspace[key] = domain
    return config, subspace