mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-03 22:06:16 +00:00
332 lines
12 KiB
Python
332 lines
12 KiB
Python
"""
|
|
Utilities to work with an OmegaConf structured config object
|
|
|
|
From Dolma Toolkit: https://github.com/allenai/dolma/blob/64886d9db15bd99acea9e28740ae20a510875dfb/python/dolma/cli/__init__.py
|
|
|
|
Author: Luca Soldaini (@soldni)
|
|
""" # noqa: E501
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
from collections.abc import Iterable
|
|
from copy import deepcopy
|
|
from dataclasses import Field
|
|
from dataclasses import field as dataclass_field
|
|
from dataclasses import is_dataclass
|
|
from logging import warning
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
Literal,
|
|
Optional,
|
|
Protocol,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
get_args,
|
|
get_origin,
|
|
)
|
|
|
|
import smart_open
|
|
from necessary import necessary
|
|
from omegaconf import MISSING, DictConfig, ListConfig
|
|
from omegaconf import OmegaConf as om
|
|
from omegaconf.errors import OmegaConfBaseException
|
|
from rich.console import Console
|
|
from rich.syntax import Syntax
|
|
from yaml import safe_load
|
|
|
|
from .errors import DolmaRefineError
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["field", "namespace_to_nested_omegaconf", "print_config", "make_cli", "read_config", "to_native_types"]


# Type variables used in the signatures of this module.
T = TypeVar("T", bound=Any)  # unconstrained; used for field defaults and structured config types
D = TypeVar("D", bound="DataClass")  # a type satisfying the `DataClass` protocol
A = TypeVar("A", bound="ArgumentParser")  # an `ArgumentParser` or one of its subclasses
|
|
|
|
|
|
def _field_nargs(default: Any) -> Union[Literal["?"], Literal["*"]]:
|
|
# return '+' if _default is iterable but not string/bytes, else 1
|
|
if isinstance(default, (str, bytes)):
|
|
return "?"
|
|
|
|
if isinstance(default, Iterable):
|
|
return "*"
|
|
|
|
return "?"
|
|
|
|
|
|
def field(default: T = MISSING, help: Optional[str] = None, **extra: Any) -> T:
    """Declare a dataclass field that carries the metadata used to build CLI options.

    Args:
        default: Default value of the field. Each dataclass instance receives its own deep copy.
        help: Help text stored under the ``"help"`` metadata key.
        **extra: Additional metadata entries; they override the built-in keys on name clashes.

    Returns:
        The object produced by ``dataclasses.field`` (annotated as ``T`` so it can be assigned
        to an attribute of the default's type).
    """
    metadata: Dict[str, Any] = {
        "help": help,
        "type": type(default),
        "default": default,
        "nargs": _field_nargs(default),
    }
    # caller-provided entries win over the ones computed above
    metadata.update(extra)

    def _fresh_default() -> T:
        # deep copy so that mutable defaults are never shared between instances
        return deepcopy(default)

    return dataclass_field(default_factory=_fresh_default, metadata=metadata)
|
|
|
|
|
|
class DataClass(Protocol):
    """Structural type matching any object that exposes ``__dataclass_fields__``."""

    # maps each field name to its `dataclasses.Field` descriptor
    __dataclass_fields__: Dict[str, Field]
|
|
|
|
|
|
def read_config(path: Union[None, str]) -> Dict[str, Any]:
    """Read a YAML configuration file into a dictionary.

    Args:
        path: Location of the config file, opened through ``smart_open``. ``None`` means
            that no config file was provided.

    Returns:
        The parsed top-level mapping. An empty dictionary is returned when ``path`` is ``None``
        or when the file contains no YAML document (empty or comments only).

    Raises:
        DolmaRefineError: If the file does not exist, cannot be read or parsed, or its content
            cannot be converted to a dictionary.
    """
    if path is None:
        return {}

    try:
        with smart_open.open(path, mode="rt") as f:
            content = safe_load(f)
    except FileNotFoundError as ex:
        raise DolmaRefineError(f"Config file not found: {path}") from ex
    except Exception as ex:
        raise DolmaRefineError(f"Error while reading config file: {path}") from ex

    # safe_load returns None for an empty document; treat it like a missing config
    if content is None:
        return {}

    try:
        return dict(content)
    except (TypeError, ValueError) as ex:
        # top-level YAML value is not a mapping (e.g. a scalar or a plain list)
        raise DolmaRefineError(f"Error while reading config file: {path}") from ex
|
|
|
|
|
|
def save_config(config: Union[dict, DictConfig, list, ListConfig, DataClass], path: str) -> None:
    """Write a configuration to ``path`` as YAML.

    Args:
        config: Configuration to save. Plain lists/dicts and dataclasses are converted to
            OmegaConf containers first; any other value is passed to ``OmegaConf.to_yaml`` as is.
        path: Destination, opened for text writing through ``smart_open``.
    """
    if isinstance(config, (list, dict)):
        container = om.create(config)
    elif is_dataclass(config):
        container = om.structured(config)
    else:
        container = config

    with smart_open.open(path, mode="wt") as stream:
        yaml_text = om.to_yaml(container)
        stream.write(yaml_text)
|
|
|
|
|
|
def _make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None) -> A:
    """Register one command line option per field of ``config`` on ``parser``.

    Fields whose type is a dataclass (possibly wrapped in ``Optional``) are expanded
    recursively, with option names joined by dots (``--outer.inner``). Boolean fields get a
    ``--name`` / ``--no-name`` pair. Every option defaults to ``MISSING``.

    Args:
        parser: Parser that receives the new options.
        config: Dataclass type whose fields are turned into options.
        prefix: Dotted path of the enclosing fields, if any.

    Returns:
        The same ``parser`` object that was passed in.
    """
    for name, dt_field in config.__dataclass_fields__.items():
        meta = dt_field.metadata

        # the class annotation wins; otherwise use the type recorded in the field metadata
        field_type = config.__annotations__.get(name, meta.get("type", MISSING))
        if field_type is MISSING:
            warning(f"No type annotation for field {name} in {config.__name__}")
            continue

        option = f"{prefix}.{name}" if prefix else name

        # Optional[X] is Union[X, None]: when exactly one non-None member remains, unwrap it
        if get_origin(field_type) == Union:
            members = [arg for arg in get_args(field_type) if arg is not type(None)]  # noqa: E721
            if len(members) == 1:
                field_type = members[0]

        if is_dataclass(field_type):
            # nested dataclass: its fields are added under the dotted prefix
            _make_parser(parser, field_type, prefix=option)
            continue

        help_text = meta.get("help")

        if field_type is not bool:
            # regular option; nargs comes from the field metadata when available
            parser.add_argument(
                f"--{option}",
                help=help_text,
                nargs=meta.get("nargs", "?"),
                default=MISSING,
            )
            continue

        # booleans: an enabling flag and a disabling flag writing to the same destination
        bool_flags = (
            (f"--{option}", help_text, "store_true"),
            (f"--no-{option}", f"Disable {option}", "store_false"),
        )
        for flag, flag_help, action in bool_flags:
            parser.add_argument(flag, help=flag_help, dest=option, action=action, default=MISSING)

    return parser
|
|
|
|
|
|
def make_nested_dict(key: str, value: Any, d: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Store ``value`` under a dotted ``key`` inside a nested dictionary.

    Args:
        key: Key whose dot-separated parts name successive nesting levels.
        value: Value to store. It is skipped when equal to ``MISSING`` or when it is an
            empty dict or list.
        d: Dictionary to extend. A new one is created when it is ``None`` or empty.

    Returns:
        The dictionary that was (possibly) updated.
    """
    target = d or {}

    head, dot, tail = key.partition(".")
    if dot:
        # descend one level, reusing whatever is already stored under the first part
        value = make_nested_dict(tail, value, target.get(head))

    # keep only values that were actually provided and are not empty containers
    if value != MISSING and not (isinstance(value, (dict, list)) and len(value) == 0):
        target[head] = value

    return target
|
|
|
|
|
|
def to_native_types(
    obj: Any, resolve: bool = True, throw_on_missing: bool = True, enum_to_str: bool = True
) -> Any:
    """Converts an OmegaConf object to native types (dicts, lists, etc.)

    Args:
        obj: Value to convert. Objects with a ``to_dict`` method, dataclasses, OmegaConf
            containers, and plain dicts/lists are converted; anything else is returned unchanged.
        resolve: Forwarded to ``OmegaConf.to_container``.
        throw_on_missing: Forwarded to ``OmegaConf.to_container``.
        enum_to_str: Forwarded to ``OmegaConf.to_container``.

    Returns:
        The converted value; dicts and lists are converted recursively with the same options.
    """

    if hasattr(obj, "to_dict"):
        # huggingface objects have a to_dict method, we prefer that
        obj = obj.to_dict()
    elif is_dataclass(obj):
        # OmegaConf.to_container only accepts config objects, so the dataclass is first
        # wrapped in a structured config and converted by the branch below
        obj = om.structured(obj)

    if isinstance(obj, (DictConfig, ListConfig)):
        obj = om.to_container(obj, resolve=resolve, throw_on_missing=throw_on_missing, enum_to_str=enum_to_str)

    def _convert(item: Any) -> Any:
        # nested values must be converted with the options the caller asked for
        return to_native_types(
            item, resolve=resolve, throw_on_missing=throw_on_missing, enum_to_str=enum_to_str
        )

    if isinstance(obj, dict):
        return {k: _convert(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_convert(v) for v in obj]
    return obj
|
|
|
|
|
|
def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config: Optional[dict] = None) -> T:
    """Combine parsed CLI arguments and config-file values into a structured config.

    Args:
        args: Parsed arguments; dotted attribute names are expanded into nested dictionaries.
        structured: Dataclass type defining the schema and the default values.
        config: Values read from a config file, if any. CLI values are merged on top of them.

    Returns:
        The merged configuration (a ``DictConfig`` built from ``structured``).

    Raises:
        DolmaRefineError: If the merge does not produce a ``DictConfig`` or if resolving the
            merged configuration fails.
    """
    cli_values: Dict[str, Any] = {}
    for dotted_key, cli_value in vars(args).items():
        cli_values = make_nested_dict(dotted_key, cli_value, cli_values)

    file_layer = om.create(config or {})
    cli_layer = om.create(cli_values)
    # pyright: om.create might return a DictConfig or a ListConfig
    untyped_config: DictConfig = om.merge(file_layer, cli_layer)  # pyright: ignore

    # resolve any interpolations before merging into the schema
    om.resolve(untyped_config)

    # schema and defaults come from the dataclass
    schema: DictConfig = om.structured(structured)

    # overlay the values from the config file and the command line on the schema
    merged_config = om.merge(schema, untyped_config)

    if not isinstance(merged_config, DictConfig):
        raise DolmaRefineError(f"Expected a DictConfig, got {type(merged_config).__name__}")

    # resolve cross references in the merged config; failures become a DolmaRefineError
    try:
        om.resolve(merged_config)
    except OmegaConfBaseException as ex:
        raise DolmaRefineError(f"Invalid error while parsing key `{ex.full_key}`: {type(ex).__name__}") from ex

    return merged_config  # pyright: ignore
|
|
|
|
|
|
def print_config(config: Any, console: Optional[Console] = None) -> None:
    """Print a configuration as syntax-highlighted YAML with sorted keys.

    Args:
        config: An OmegaConf container, or any value accepted by ``OmegaConf.create``.
        console: Rich console to print to; a new ``Console`` is created when not provided.
    """
    if isinstance(config, (DictConfig, ListConfig)):
        container = config
    else:
        container = om.create(config)

    out = console or Console()

    yaml_text = om.to_yaml(container, sort_keys=True).strip()
    out.print(Syntax(code=yaml_text, lexer="yaml", theme="ansi_dark"))
|
|
|
|
|
|
def _patch_old_omegaconf():
    """Monkey patch omegaconf below version 2.4.0 to support custom resolver returning
    lists or dicts. Applies patch https://github.com/omry/omegaconf/pull/1093

    The function replaces ``omegaconf._impl._resolve_container_value`` and records that it
    did so in the ``__patched__`` attribute set on itself, so repeated calls do nothing.
    """

    # NOTE(review): presumably `necessary(..., soft=True)` returns True when the installed
    # omegaconf is at least 2.4.0 instead of raising -- confirm against the `necessary` docs
    if necessary(("omegaconf", "2.4.0"), soft=True):
        # no need to patch
        return

    if getattr(_patch_old_omegaconf, "__patched__", False):
        # already patched
        return

    # imported here so that omegaconf private modules are only touched when patching is needed
    from omegaconf import _impl  # pylint: disable=import-outside-toplevel
    from omegaconf import (  # pylint: disable=import-outside-toplevel
        Container,
        Node,
        ValueNode,
    )
    from omegaconf._utils import (  # noqa: F401  # pylint: disable=import-outside-toplevel
        _ensure_container,
        _get_value,
        is_primitive_container,
        is_structured_config,
    )
    from omegaconf.errors import (  # pylint: disable=import-outside-toplevel
        InterpolationToMissingValueError,
    )
    from omegaconf.nodes import (  # pylint: disable=import-outside-toplevel
        InterpolationResultNode,
    )

    def _resolve_container_value(cfg: Container, key: Any) -> None:
        # replacement for `_impl._resolve_container_value`: resolves the child of `cfg` at `key` in place
        node = cfg._get_child(key)  # pylint: disable=protected-access
        assert isinstance(node, Node)
        if node._is_interpolation():  # pylint: disable=protected-access
            try:
                resolved = node._dereference_node()  # pylint: disable=protected-access
            except InterpolationToMissingValueError:
                # the interpolation points to a missing value: mark this node as missing too
                node._set_value(MISSING)  # pylint: disable=protected-access
            else:
                if isinstance(resolved, Container):
                    _impl._resolve(resolved)  # pylint: disable=protected-access
                if isinstance(resolved, InterpolationResultNode):
                    # a resolver result holding a list/dict/structured config becomes a container
                    resolved_value = _get_value(resolved)
                    if is_primitive_container(resolved_value) or is_structured_config(resolved_value):
                        resolved = _ensure_container(resolved_value)
                if isinstance(resolved, Container) and isinstance(node, ValueNode):
                    # a value node cannot hold a container, so the child is replaced in the parent
                    cfg[key] = resolved
                else:
                    node._set_value(_get_value(resolved))  # pylint: disable=protected-access
        else:
            _impl._resolve(node)  # pylint: disable=protected-access

    # set new function and mark as patched
    setattr(_impl, "_resolve_container_value", _resolve_container_value)
    setattr(_patch_old_omegaconf, "__patched__", True)
|
|
|
|
|
|
# actually executes the patch; this runs as a module-level side effect on import
_patch_old_omegaconf()
|
|
|
|
|
|
def make_cli(config_cls: Type[D], _config_flag: str = "config", _dryrun_flag: str = "dryrun") -> D:
    """Create a CLI parser for a dataclass and parse the arguments into a structured config object.

    The parser gets one option per field of ``config_cls`` plus two reserved options: one
    taking the path to a config file and one enabling dry-run mode. Values from the command
    line are merged on top of the values read from the config file. The resulting
    configuration is always printed.

    Args:
        config_cls: Dataclass type describing the configuration.
        _config_flag: Name of the config-file option; its first letter is the short flag.
        _dryrun_flag: Name of the dry-run option; its first letter is the short flag.

    Returns:
        The merged structured configuration.

    Raises:
        DolmaRefineError: If ``config_cls`` has an attribute named like one of the reserved flags.
        SystemExit: With code 0, after printing the configuration, when dry-run mode is on.
    """

    for reserved in (_config_flag, _dryrun_flag):
        if hasattr(config_cls, reserved):
            raise DolmaRefineError(f"`{reserved}` is a reserved attribute; remove it from `{config_cls.__name__}`")

    parser = ArgumentParser()
    parser.add_argument(
        f"-{_config_flag[0]}", f"--{_config_flag}", help="Path to config file", default=None, type=str
    )
    parser.add_argument(
        f"-{_dryrun_flag[0]}",
        f"--{_dryrun_flag}",
        help="Dry run mode: print config and exit",
        action="store_true",
        default=False,
    )

    parser = _make_parser(parser, config_cls)
    args = parser.parse_args()

    parsed_config: Dict[str, Any] = {}
    if (config_path := getattr(args, _config_flag)) is not None:
        parsed_config = read_config(config_path)
    # the reserved options are not part of the config schema: drop them before merging
    delattr(args, _config_flag)

    only_dryrun = getattr(args, _dryrun_flag, False)
    delattr(args, _dryrun_flag)

    full_config = namespace_to_nested_omegaconf(args, config_cls, parsed_config)

    print_config(full_config)

    if only_dryrun:
        # the `exit` builtin is injected by the `site` module and is not guaranteed to exist;
        # raising SystemExit terminates with the same exit code without relying on it
        raise SystemExit(0)

    return full_config
|