mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-11-04 03:56:16 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			123 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			123 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import re
 | 
						|
from typing import (
 | 
						|
    Callable,
 | 
						|
    Dict,
 | 
						|
    Generator,
 | 
						|
    Generic,
 | 
						|
    Literal,
 | 
						|
    Optional,
 | 
						|
    Tuple,
 | 
						|
    Type,
 | 
						|
    TypeVar,
 | 
						|
    overload,
 | 
						|
)
 | 
						|
 | 
						|
# Type of the objects stored in a registry (the parameter of `BaseRegistry[T]`).
T = TypeVar("T")
 | 
						|
# Type of the object passed through the decorator returned by `BaseRegistry.add`.
R = TypeVar("R")
 | 
						|
 | 
						|
 | 
						|
class BaseRegistry(Generic[T]):
    """A name -> object registry, with one independent store per registry class.

    Subclass it to create a registry, then use ``add`` as a decorator::

        class TaggerRegistry(BaseRegistry[type]):
            pass

        @TaggerRegistry.add("my_tagger", "what it does")
        class MyTagger: ...

    Every class that has ``add`` called on it is also recorded, by class name,
    in a registry of registries shared by all subclasses.
    """

    # Deliberately shared by all subclasses: registry class name -> registry class.
    _registry_of_registries: Dict[str, Type["BaseRegistry"]] = {}
    # Annotation only; the dict itself is created lazily, per class, by `_get_storage`.
    _registry_storage: Dict[str, Tuple[T, Optional[str]]]

    @classmethod
    def _add_to_registry_of_registries(cls) -> None:
        """Record this registry class under its class name; the first one registered wins."""
        cls._registry_of_registries.setdefault(cls.__name__, cls)

    @classmethod
    def registries(cls) -> Generator[Tuple[str, Type["BaseRegistry"]], None, None]:
        """Yield ``(name, registry_class)`` for all known registries, sorted by name."""
        yield from sorted(cls._registry_of_registries.items())

    @classmethod
    def _get_storage(cls) -> Dict[str, Tuple[T, Optional[str]]]:
        """Return this class's own ``name -> (object, description)`` dict, creating it if needed."""
        # Inspect the class's own namespace instead of using hasattr(): hasattr()
        # also finds a storage dict created on a parent class, which would make
        # unrelated registries silently share entries depending on call order.
        if "_registry_storage" not in cls.__dict__:
            cls._registry_storage = {}
        return cls._registry_storage  # pyright: ignore

    @staticmethod
    def _pattern_matches(pattern: str, name: str) -> bool:
        """Return True if ``pattern``, used as a regex, matches at the start of ``name``.

        A registered name that is not a valid regular expression never matches
        here; it can still be looked up by its exact name.
        """
        try:
            return re.match(pattern, name) is not None
        except re.error:
            return False

    @classmethod
    def items(cls) -> Generator[Tuple[str, T], None, None]:
        """Yield ``(name, object)`` for all items in the registry, sorted by name."""
        storage = cls._get_storage()
        for registered in sorted(storage):
            yield registered, storage[registered][0]

    @classmethod
    def items_with_description(cls) -> Generator[Tuple[str, T, Optional[str]], None, None]:
        """Yield ``(name, object, description)`` for all items, sorted by name."""
        storage = cls._get_storage()
        for registered in sorted(storage):
            obj, description = storage[registered]
            yield registered, obj, description

    @classmethod
    def add(cls, name: str, desc: Optional[str] = None) -> Callable[[R], R]:
        """Return a decorator that registers its argument under ``name``.

        Args:
            name: Key to register the object under.
            desc: Optional human-readable description stored alongside it.

        Returns:
            A decorator that stores the decorated object and returns it unchanged.

        Raises:
            ValueError: (from the decorator) if a different object is already
                registered under exactly ``name``, unless the new object was
                defined in ``__main__``, in which case the existing
                registration is kept and the new object is returned as is.
        """

        # Add the registry to the registry of registries
        cls._add_to_registry_of_registries()

        def _add(
            inner_self: T,
            inner_name: str = name,
            inner_desc: Optional[str] = desc,
            inner_cls: Type[BaseRegistry] = cls,
        ) -> T:
            """Store ``inner_self`` under ``inner_name`` in ``inner_cls``."""
            storage = inner_cls._get_storage()

            # Duplicate detection uses the exact key. Going through `get` here
            # would apply regex prefix matching, so registering "foo_bar" after
            # "foo" would be rejected as a duplicate.
            if inner_name in storage:
                existing, _ = storage[inner_name]
                if existing != inner_self:
                    if getattr(inner_self, "__module__", None) == "__main__":
                        return inner_self
                    raise ValueError(f"Tagger {inner_name} already exists")

            storage[inner_name] = (inner_self, inner_desc)
            return inner_self

        return _add  # type: ignore

    @classmethod
    def remove(cls, name: str) -> bool:
        """Remove the entry registered under exactly ``name``; return whether it existed."""
        missing = object()
        return cls._get_storage().pop(name, missing) is not missing

    @classmethod
    def has(cls, name: str) -> bool:
        """Check whether an entry is registered under exactly ``name``."""
        return name in cls._get_storage()

    @overload
    @classmethod
    def get(cls, name: str) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[True]) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[False]) -> Optional[T]: ...

    @classmethod
    def get(cls, name: str, raise_on_missing: bool = True) -> Optional[T]:
        """Get an object from the registry.

        An entry registered under exactly ``name`` is returned first. Otherwise
        every registered name is tried as a regular expression against the
        start of ``name`` and the single matching entry is returned.

        Args:
            name: Name to look up.
            raise_on_missing: If true, raise when nothing matches; otherwise
                return ``None``.

        Raises:
            ValueError: if several registered patterns match ``name``, or if
                nothing matches and ``raise_on_missing`` is true.
        """
        storage = cls._get_storage()

        # Exact hit first: keeps names containing regex metacharacters
        # retrievable and avoids false ambiguity between e.g. "foo" and "foo_bar".
        if name in storage:
            return storage[name][0]

        matches = [registered for registered in storage if cls._pattern_matches(registered, name)]

        if len(matches) > 1:
            raise ValueError(f"Multiple taggers match {name}: {', '.join(matches)}")

        if not matches:
            if raise_on_missing:
                tagger_names = ", ".join(sorted(storage))
                raise ValueError(f"Unknown tagger {name}; available taggers: {tagger_names}")
            return None

        return storage[matches[0]][0]
 |