mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-12 16:52:20 +00:00
123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
import re
|
|
from typing import (
|
|
Callable,
|
|
Dict,
|
|
Generator,
|
|
Generic,
|
|
Literal,
|
|
Optional,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
overload,
|
|
)
|
|
|
|
T = TypeVar("T")  # type of the objects held by a registry
R = TypeVar("R")  # type of the object passed through the `add` decorator


class BaseRegistry(Generic[T]):
    """A class-level registry mapping names to objects.

    All state lives on the class and every method is a classmethod, so a
    registry is used by subclassing (e.g. ``class Foo(BaseRegistry[X])``) and
    calling ``Foo.add(...)`` / ``Foo.get(...)`` directly on the subclass.

    Registered names double as regular expressions: ``get`` first tries an
    exact name and then falls back to the registered names that fully match
    the requested one. ``has`` and ``remove`` only ever use exact names.
    """

    # Shared by every registry class: registry class name -> registry class.
    _registry_of_registries: Dict[str, Type["BaseRegistry"]] = {}
    # Per-registry storage: name -> (object, optional description). Only
    # declared here; `_get_storage` creates it lazily on each registry class.
    _registry_storage: Dict[str, Tuple[T, Optional[str]]]

    @classmethod
    def _add_to_registry_of_registries(cls) -> None:
        """Record this registry class under its class name (first one wins)."""
        cls._registry_of_registries.setdefault(cls.__name__, cls)

    @classmethod
    def registries(cls) -> Generator[Tuple[str, Type["BaseRegistry"]], None, None]:
        """Yield ``(class name, registry class)`` pairs sorted by class name.

        A registry shows up here once ``add`` has been called on it.
        """
        yield from sorted(cls._registry_of_registries.items())

    @classmethod
    def _get_storage(cls) -> Dict[str, Tuple[T, Optional[str]]]:
        """Return this registry's own storage dict, creating it on first use."""
        # Check the class's own namespace instead of hasattr(): hasattr() also
        # finds a dict inherited from a parent class, which would make a
        # subclass silently share (and mutate) its parent's entries.
        if "_registry_storage" not in cls.__dict__:
            cls._registry_storage = {}
        return cls._registry_storage  # pyright: ignore

    @staticmethod
    def _name_matches(pattern: str, name: str) -> bool:
        """Return True if ``name`` fully matches the registered ``pattern``."""
        try:
            return re.fullmatch(pattern, name) is not None
        except re.error:
            # The registered name is not a valid regular expression, so it can
            # only ever be found through an exact lookup.
            return False

    @classmethod
    def items(cls) -> Generator[Tuple[str, T], None, None]:
        """Yield ``(name, object)`` pairs sorted by name."""
        storage = cls._get_storage()
        for name in sorted(storage):
            yield name, storage[name][0]

    @classmethod
    def items_with_description(cls) -> Generator[Tuple[str, T, Optional[str]], None, None]:
        """Yield ``(name, object, description)`` triples sorted by name."""
        storage = cls._get_storage()
        for name in sorted(storage):
            obj, desc = storage[name]
            yield name, obj, desc

    @classmethod
    def add(cls, name: str, desc: Optional[str] = None) -> Callable[[R], R]:
        """Return a decorator that registers its argument under ``name``.

        Args:
            name: Name (also usable as a regular expression by ``get``) to
                register the object under.
            desc: Optional human-readable description stored with the object.

        Returns:
            A one-argument callable that stores the object it receives and
            returns it unchanged, so it can be used as a decorator.

        Raises:
            ValueError: When the returned callable is applied and ``name``
                already resolves to a different object (unless the new object
                was defined in ``__main__``, see below).
        """
        # Make this registry discoverable through `registries()`.
        cls._add_to_registry_of_registries()

        def _add(obj: R) -> R:
            """Store ``obj`` in the registry and hand it back."""
            existing = cls.get(name, raise_on_missing=False)

            # Compare against None explicitly so that a falsy registered object
            # still counts as "already registered".
            if existing is not None and existing != obj:
                # Objects defined in `__main__` are returned without being
                # registered instead of raising.
                # NOTE(review): presumably this tolerates a module being both
                # run as a script and imported under its real name - confirm.
                if getattr(obj, "__module__", None) == "__main__":
                    return obj
                raise ValueError(f"Tagger {name} already exists")

            cls._get_storage()[name] = (obj, desc)  # type: ignore
            return obj

        return _add

    @classmethod
    def remove(cls, name: str) -> bool:
        """Remove the entry registered under exactly ``name``.

        Returns:
            True if an entry was removed, False if ``name`` was not registered.
        """
        # Stored values are always tuples, so None reliably means "absent".
        return cls._get_storage().pop(name, None) is not None

    @classmethod
    def has(cls, name: str) -> bool:
        """Return True if an entry is registered under exactly ``name``."""
        return name in cls._get_storage()

    @overload
    @classmethod
    def get(cls, name: str) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[True]) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[False]) -> Optional[T]: ...

    @classmethod
    def get(cls, name: str, raise_on_missing: bool = True) -> Optional[T]:
        """Look up the object registered for ``name``.

        An entry registered under exactly ``name`` always wins. Otherwise every
        registered name is treated as a regular expression that must match the
        whole of ``name``.

        Args:
            name: Name to resolve.
            raise_on_missing: If True (default) raise when nothing matches;
                if False return None instead.

        Returns:
            The registered object, or None if nothing matches and
            ``raise_on_missing`` is False.

        Raises:
            ValueError: If more than one registered pattern matches ``name``,
                or if nothing matches and ``raise_on_missing`` is True.
        """
        storage = cls._get_storage()

        # Exact hit: no pattern matching needed, and it keeps names that are
        # prefixes of each other (e.g. "foo" / "foobar") from colliding.
        if name in storage:
            return storage[name][0]

        matches = [registered for registered in storage if cls._name_matches(registered, name)]

        if len(matches) > 1:
            raise ValueError(f"Multiple taggers match {name}: {', '.join(matches)}")

        if not matches:
            if raise_on_missing:
                tagger_names = ", ".join(tn for tn, _ in cls.items())
                raise ValueError(f"Unknown tagger {name}; available taggers: {tagger_names}")
            return None

        return storage[matches[0]][0]
|