mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-04 13:57:42 +00:00
543 lines
22 KiB
Python
543 lines
22 KiB
Python
try:
|
|
from ray import __version__ as ray_version
|
|
|
|
assert ray_version >= "1.0.0"
|
|
from ray.tune import sample
|
|
from ray.tune.suggest.variant_generator import generate_variants
|
|
except (ImportError, AssertionError):
|
|
from . import sample
|
|
from ..searcher.variant_generator import generate_variants
|
|
from typing import Dict, Optional, Any, Tuple, Generator
|
|
import numpy as np
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def generate_variants_compatible(
|
|
unresolved_spec: Dict, constant_grid_search: bool = False, random_state=None
|
|
) -> Generator[Tuple[Dict, Dict], None, None]:
|
|
try:
|
|
return generate_variants(unresolved_spec, constant_grid_search, random_state)
|
|
except TypeError:
|
|
return generate_variants(unresolved_spec, constant_grid_search)
|
|
|
|
|
|
def define_by_run_func(trial, space: Dict, path: str = "") -> Optional[Dict[str, Any]]:
|
|
"""Define-by-run function to create the search space.
|
|
|
|
Returns:
|
|
A dict with constant values.
|
|
"""
|
|
config = {}
|
|
for key, domain in space.items():
|
|
if path:
|
|
key = path + "/" + key
|
|
if isinstance(domain, dict):
|
|
config.update(define_by_run_func(trial, domain, key))
|
|
continue
|
|
if not isinstance(domain, sample.Domain):
|
|
config[key] = domain
|
|
continue
|
|
sampler = domain.get_sampler()
|
|
quantize = None
|
|
if isinstance(sampler, sample.Quantized):
|
|
quantize = sampler.q
|
|
sampler = sampler.sampler
|
|
if isinstance(sampler, sample.LogUniform):
|
|
logger.warning(
|
|
"Optuna does not handle quantization in loguniform "
|
|
"sampling. The parameter will be passed but it will "
|
|
"probably be ignored."
|
|
)
|
|
if isinstance(domain, sample.Float):
|
|
if isinstance(sampler, sample.LogUniform):
|
|
if quantize:
|
|
logger.warning(
|
|
"Optuna does not support both quantization and "
|
|
"sampling from LogUniform. Dropped quantization."
|
|
)
|
|
trial.suggest_float(key, domain.lower, domain.upper, log=True)
|
|
elif isinstance(sampler, sample.Uniform):
|
|
if quantize:
|
|
trial.suggest_float(key, domain.lower, domain.upper, step=quantize)
|
|
trial.suggest_float(key, domain.lower, domain.upper)
|
|
elif isinstance(domain, sample.Integer):
|
|
if isinstance(sampler, sample.LogUniform):
|
|
trial.suggest_int(
|
|
key, domain.lower, domain.upper - int(bool(not quantize)), log=True
|
|
)
|
|
elif isinstance(sampler, sample.Uniform):
|
|
# Upper bound should be inclusive for quantization and
|
|
# exclusive otherwise
|
|
trial.suggest_int(
|
|
key,
|
|
domain.lower,
|
|
domain.upper - int(bool(not quantize)),
|
|
step=quantize or 1,
|
|
)
|
|
elif isinstance(domain, sample.Categorical):
|
|
if isinstance(sampler, sample.Uniform):
|
|
if not hasattr(domain, "choices"):
|
|
domain.choices = list(range(len(domain.categories)))
|
|
choices = domain.choices
|
|
# This choice needs to be removed from the final config
|
|
index = trial.suggest_categorical(key + "_choice_", choices)
|
|
choice = domain.categories[index]
|
|
if isinstance(choice, dict):
|
|
key += f":{index}"
|
|
# the suffix needs to be removed from the final config
|
|
config.update(define_by_run_func(trial, choice, key))
|
|
else:
|
|
raise ValueError(
|
|
"Optuna search does not support parameters of type "
|
|
"`{}` with samplers of type `{}`".format(
|
|
type(domain).__name__, type(domain.sampler).__name__
|
|
)
|
|
)
|
|
# Return all constants in a dictionary.
|
|
return config
|
|
|
|
|
|
# def convert_key(
|
|
# conf: Dict, space: Dict, path: str = ""
|
|
# ) -> Optional[Dict[str, Any]]:
|
|
# """Convert config keys to define-by-run keys.
|
|
|
|
# Returns:
|
|
# A dict with converted keys.
|
|
# """
|
|
# config = {}
|
|
# for key, domain in space.items():
|
|
# value = conf[key]
|
|
# if path:
|
|
# key = path + '/' + key
|
|
# if isinstance(domain, dict):
|
|
# config.update(convert_key(conf[key], domain, key))
|
|
# elif isinstance(domain, sample.Categorical):
|
|
# index = indexof(domain, value)
|
|
# config[key + '_choice_'] = index
|
|
# if isinstance(value, dict):
|
|
# key += f":{index}"
|
|
# config.update(convert_key(value, domain.categories[index], key))
|
|
# else:
|
|
# config[key] = value
|
|
# return config
|
|
|
|
|
|
def unflatten_hierarchical(config: Dict, space: Dict) -> Tuple[Dict, Dict]:
|
|
"""Unflatten hierarchical config."""
|
|
hier = {}
|
|
subspace = {}
|
|
for key, value in config.items():
|
|
if "/" in key:
|
|
key = key[key.rfind("/") + 1 :]
|
|
if ":" in key:
|
|
pos = key.rfind(":")
|
|
true_key = key[:pos]
|
|
choice = int(key[pos + 1 :])
|
|
hier[true_key], subspace[true_key] = unflatten_hierarchical(
|
|
value, space[true_key][choice]
|
|
)
|
|
else:
|
|
if key.endswith("_choice_"):
|
|
key = key[:-8]
|
|
domain = space.get(key)
|
|
if domain is not None:
|
|
subspace[key] = domain
|
|
if isinstance(domain, sample.Domain):
|
|
sampler = domain.sampler
|
|
if isinstance(domain, sample.Categorical):
|
|
value = domain.categories[value]
|
|
if isinstance(value, dict):
|
|
continue
|
|
elif isinstance(sampler, sample.Quantized):
|
|
q = sampler.q
|
|
sampler = sampler.sampler
|
|
if isinstance(sampler, sample.LogUniform):
|
|
value = domain.cast(np.round(value / q) * q)
|
|
hier[key] = value
|
|
return hier, subspace
|
|
|
|
|
|
def add_cost_to_space(space: Dict, low_cost_point: Dict, choice_cost: Dict):
|
|
"""Update the space in place by adding low_cost_point and choice_cost.
|
|
|
|
Returns:
|
|
A dict with constant values.
|
|
"""
|
|
config = {}
|
|
for key in space:
|
|
domain = space[key]
|
|
if not isinstance(domain, sample.Domain):
|
|
if isinstance(domain, dict):
|
|
low_cost = low_cost_point.get(key, {})
|
|
choice_cost_list = choice_cost.get(key, {})
|
|
const = add_cost_to_space(domain, low_cost, choice_cost_list)
|
|
if const:
|
|
config[key] = const
|
|
else:
|
|
config[key] = domain
|
|
continue
|
|
low_cost = low_cost_point.get(key)
|
|
choice_cost_list = choice_cost.get(key)
|
|
if callable(getattr(domain, "get_sampler", None)):
|
|
sampler = domain.get_sampler()
|
|
if isinstance(sampler, sample.Quantized):
|
|
sampler = sampler.get_sampler()
|
|
domain.bounded = str(sampler) != "Normal"
|
|
if isinstance(domain, sample.Categorical):
|
|
domain.const = []
|
|
for i, cat in enumerate(domain.categories):
|
|
if isinstance(cat, dict):
|
|
if isinstance(low_cost, list):
|
|
low_cost_dict = low_cost[i]
|
|
else:
|
|
low_cost_dict = {}
|
|
if choice_cost_list:
|
|
choice_cost_dict = choice_cost_list[i]
|
|
else:
|
|
choice_cost_dict = {}
|
|
domain.const.append(
|
|
add_cost_to_space(cat, low_cost_dict, choice_cost_dict)
|
|
)
|
|
else:
|
|
domain.const.append(None)
|
|
if choice_cost_list:
|
|
if len(choice_cost_list) == len(domain.categories):
|
|
domain.choice_cost = choice_cost_list
|
|
else:
|
|
domain.choice_cost = choice_cost_list[-1]
|
|
# sort the choices by cost
|
|
cost = np.array(domain.choice_cost)
|
|
ind = np.argsort(cost)
|
|
domain.categories = [domain.categories[i] for i in ind]
|
|
domain.choice_cost = cost[ind]
|
|
domain.const = [domain.const[i] for i in ind]
|
|
domain.ordered = True
|
|
elif all(
|
|
isinstance(x, int) or isinstance(x, float) for x in domain.categories
|
|
):
|
|
# sort the choices by value
|
|
ind = np.argsort(domain.categories)
|
|
domain.categories = [domain.categories[i] for i in ind]
|
|
domain.ordered = True
|
|
else:
|
|
domain.ordered = False
|
|
if low_cost and low_cost not in domain.categories:
|
|
assert isinstance(
|
|
low_cost, list
|
|
), f"low cost {low_cost} not in domain {domain.categories}"
|
|
if domain.ordered:
|
|
sorted_points = [low_cost[i] for i in ind]
|
|
for i, point in enumerate(sorted_points):
|
|
low_cost[i] = point
|
|
if len(low_cost) > len(domain.categories):
|
|
if domain.ordered:
|
|
low_cost[-1] = int(np.where(ind == low_cost[-1])[0])
|
|
domain.low_cost_point = low_cost[-1]
|
|
return
|
|
if low_cost:
|
|
domain.low_cost_point = low_cost
|
|
return config
|
|
|
|
|
|
def normalize(
|
|
config: Dict,
|
|
space: Dict,
|
|
reference_config: Dict,
|
|
normalized_reference_config: Dict,
|
|
recursive: bool = False,
|
|
):
|
|
"""Normalize config in space according to reference_config.
|
|
|
|
Normalize each dimension in config to [0,1].
|
|
"""
|
|
config_norm = {}
|
|
for key, value in config.items():
|
|
domain = space.get(key)
|
|
if domain is None: # e.g., resource_attr
|
|
config_norm[key] = value
|
|
continue
|
|
if not callable(getattr(domain, "get_sampler", None)):
|
|
if recursive and isinstance(domain, dict):
|
|
config_norm[key] = normalize(value, domain, reference_config[key], {})
|
|
else:
|
|
config_norm[key] = value
|
|
continue
|
|
# domain: sample.Categorical/Integer/Float/Function
|
|
if isinstance(domain, sample.Categorical):
|
|
norm = None
|
|
# value is: a category, a nested dict, or a low_cost_point list
|
|
if value not in domain.categories:
|
|
# nested
|
|
if isinstance(value, list):
|
|
# low_cost_point list
|
|
norm = []
|
|
for i, cat in enumerate(domain.categories):
|
|
norm.append(
|
|
normalize(value[i], cat, reference_config[key][i], {})
|
|
if recursive
|
|
else value[i]
|
|
)
|
|
if len(value) > len(domain.categories):
|
|
# the low cost index was appended to low_cost_point list
|
|
index = value[-1]
|
|
value = domain.categories[index]
|
|
elif not recursive:
|
|
# no low cost index. randomly pick one as init point
|
|
continue
|
|
else:
|
|
# nested dict
|
|
config_norm[key] = value
|
|
continue
|
|
# normalize categorical
|
|
n = len(domain.categories)
|
|
if domain.ordered:
|
|
normalized = (domain.categories.index(value) + 0.5) / n
|
|
elif key in normalized_reference_config:
|
|
normalized = (
|
|
normalized_reference_config[key]
|
|
if value == reference_config[key]
|
|
else (normalized_reference_config[key] + 1 / n) % 1
|
|
)
|
|
else:
|
|
normalized = 0.5
|
|
if norm:
|
|
norm.append(normalized)
|
|
else:
|
|
norm = normalized
|
|
config_norm[key] = norm
|
|
continue
|
|
# Uniform/LogUniform/Normal/Base
|
|
sampler = domain.get_sampler()
|
|
if isinstance(sampler, sample.Quantized):
|
|
# sampler is sample.Quantized
|
|
quantize = sampler.q
|
|
sampler = sampler.get_sampler()
|
|
else:
|
|
quantize = None
|
|
if str(sampler) == "LogUniform":
|
|
upper = domain.upper - (
|
|
isinstance(domain, sample.Integer) & (quantize is None)
|
|
)
|
|
config_norm[key] = np.log(value / domain.lower) / np.log(
|
|
upper / domain.lower
|
|
)
|
|
elif str(sampler) == "Uniform":
|
|
upper = domain.upper - (
|
|
isinstance(domain, sample.Integer) & (quantize is None)
|
|
)
|
|
config_norm[key] = (value - domain.lower) / (upper - domain.lower)
|
|
elif str(sampler) == "Normal":
|
|
# N(mean, sd) -> N(0,1)
|
|
config_norm[key] = (value - sampler.mean) / sampler.sd
|
|
# else:
|
|
# config_norm[key] = value
|
|
return config_norm
|
|
|
|
|
|
def denormalize(
|
|
config: Dict,
|
|
space: Dict,
|
|
reference_config: Dict,
|
|
normalized_reference_config: Dict,
|
|
random_state,
|
|
):
|
|
config_denorm = {}
|
|
for key, value in config.items():
|
|
if key in space:
|
|
# domain: sample.Categorical/Integer/Float/Function
|
|
domain = space[key]
|
|
if isinstance(value, dict) or not callable(
|
|
getattr(domain, "get_sampler", None)
|
|
):
|
|
config_denorm[key] = value
|
|
else:
|
|
if isinstance(domain, sample.Categorical):
|
|
# denormalize categorical
|
|
n = len(domain.categories)
|
|
if isinstance(value, list):
|
|
# denormalize list
|
|
choice = int(np.floor(value[-1] * n))
|
|
config_denorm[key] = point = value[choice]
|
|
point["_choice_"] = choice
|
|
continue
|
|
if domain.ordered:
|
|
config_denorm[key] = domain.categories[
|
|
min(n - 1, int(np.floor(value * n)))
|
|
]
|
|
else:
|
|
assert key in normalized_reference_config
|
|
if np.floor(value * n) == np.floor(
|
|
normalized_reference_config[key] * n
|
|
):
|
|
config_denorm[key] = reference_config[key]
|
|
else: # ****random value each time!****
|
|
config_denorm[key] = random_state.choice(
|
|
[
|
|
x
|
|
for x in domain.categories
|
|
if x != reference_config[key]
|
|
]
|
|
)
|
|
continue
|
|
# Uniform/LogUniform/Normal/Base
|
|
sampler = domain.get_sampler()
|
|
if isinstance(sampler, sample.Quantized):
|
|
# sampler is sample.Quantized
|
|
quantize = sampler.q
|
|
sampler = sampler.get_sampler()
|
|
else:
|
|
quantize = None
|
|
# Handle Log/Uniform
|
|
if str(sampler) == "LogUniform":
|
|
upper = domain.upper - (
|
|
isinstance(domain, sample.Integer) & (quantize is None)
|
|
)
|
|
config_denorm[key] = (upper / domain.lower) ** value * domain.lower
|
|
elif str(sampler) == "Uniform":
|
|
upper = domain.upper - (
|
|
isinstance(domain, sample.Integer) & (quantize is None)
|
|
)
|
|
config_denorm[key] = value * (upper - domain.lower) + domain.lower
|
|
elif str(sampler) == "Normal":
|
|
# denormalization for 'Normal'
|
|
config_denorm[key] = value * sampler.sd + sampler.mean
|
|
else:
|
|
config_denorm[key] = value
|
|
# Handle quantized
|
|
if quantize is not None:
|
|
config_denorm[key] = (
|
|
np.round(np.divide(config_denorm[key], quantize)) * quantize
|
|
)
|
|
# Handle int (4.6 -> 5)
|
|
if isinstance(domain, sample.Integer):
|
|
config_denorm[key] = int(round(config_denorm[key]))
|
|
else: # resource_attr
|
|
config_denorm[key] = value
|
|
return config_denorm
|
|
|
|
|
|
def indexof(domain: Dict, config: Dict) -> int:
|
|
"""Find the index of config in domain.categories."""
|
|
index = config.get("_choice_")
|
|
if index is not None:
|
|
return index
|
|
if config in domain.categories:
|
|
return domain.categories.index(config)
|
|
for i, cat in enumerate(domain.categories):
|
|
if not isinstance(cat, dict):
|
|
continue
|
|
# print(len(cat), len(config))
|
|
# if len(cat) != len(config):
|
|
# continue
|
|
# print(cat.keys())
|
|
if not set(config.keys()).issubset(set(cat.keys())):
|
|
continue
|
|
# print(domain.const[i])
|
|
if all(config[key] == value for key, value in domain.const[i].items()):
|
|
# assumption: the concatenation of constants is a unique identifier
|
|
return i
|
|
return None
|
|
|
|
|
|
def complete_config(
|
|
partial_config: Dict,
|
|
space: Dict,
|
|
flow2,
|
|
disturb: bool = False,
|
|
lower: Optional[Dict] = None,
|
|
upper: Optional[Dict] = None,
|
|
) -> Tuple[Dict, Dict]:
|
|
"""Complete partial config in space.
|
|
|
|
Returns:
|
|
config, space.
|
|
"""
|
|
config = partial_config.copy()
|
|
normalized = normalize(config, space, partial_config, {})
|
|
# print("normalized", normalized)
|
|
if disturb:
|
|
for key, value in normalized.items():
|
|
domain = space.get(key)
|
|
if getattr(domain, "ordered", True) is False:
|
|
# don't change unordered cat choice
|
|
continue
|
|
if not callable(getattr(domain, "get_sampler", None)):
|
|
continue
|
|
if upper and lower:
|
|
up, low = upper[key], lower[key]
|
|
if isinstance(up, list):
|
|
gauss_std = (up[-1] - low[-1]) or flow2.STEPSIZE
|
|
up[-1] += flow2.STEPSIZE
|
|
low[-1] -= flow2.STEPSIZE
|
|
else:
|
|
gauss_std = (up - low) or flow2.STEPSIZE
|
|
# allowed bound
|
|
up += flow2.STEPSIZE
|
|
low -= flow2.STEPSIZE
|
|
elif domain.bounded:
|
|
up, low, gauss_std = 1, 0, 1.0
|
|
else:
|
|
up, low, gauss_std = np.Inf, -np.Inf, 1.0
|
|
if domain.bounded:
|
|
if isinstance(up, list):
|
|
up[-1] = min(up[-1], 1)
|
|
low[-1] = max(low[-1], 0)
|
|
else:
|
|
up = min(up, 1)
|
|
low = max(low, 0)
|
|
delta = flow2.rand_vector_gaussian(1, gauss_std)[0]
|
|
if isinstance(value, list):
|
|
# points + normalized index
|
|
value[-1] = max(low[-1], min(up[-1], value[-1] + delta))
|
|
else:
|
|
normalized[key] = max(low, min(up, value + delta))
|
|
config = denormalize(normalized, space, config, normalized, flow2._random)
|
|
# print("denormalized", config)
|
|
for key, value in space.items():
|
|
if key not in config:
|
|
config[key] = value
|
|
for _, generated in generate_variants_compatible(
|
|
{"config": config}, random_state=flow2.rs_random
|
|
):
|
|
config = generated["config"]
|
|
break
|
|
subspace = {}
|
|
for key, domain in space.items():
|
|
value = config[key]
|
|
if isinstance(value, dict):
|
|
if isinstance(domain, sample.Categorical):
|
|
# nested space
|
|
index = indexof(domain, value)
|
|
# point = partial_config.get(key)
|
|
# if isinstance(point, list): # low cost point list
|
|
# point = point[index]
|
|
# else:
|
|
# point = {}
|
|
config[key], subspace[key] = complete_config(
|
|
value,
|
|
domain.categories[index],
|
|
flow2,
|
|
disturb,
|
|
lower and lower[key][index],
|
|
upper and upper[key][index],
|
|
)
|
|
assert (
|
|
"_choice_" not in subspace[key]
|
|
), "_choice_ is a reserved key for hierarchical search space"
|
|
subspace[key]["_choice_"] = index
|
|
else:
|
|
config[key], subspace[key] = complete_config(
|
|
value,
|
|
space[key],
|
|
flow2,
|
|
disturb,
|
|
lower and lower[key],
|
|
upper and upper[key],
|
|
)
|
|
continue
|
|
subspace[key] = domain
|
|
return config, subspace
|