mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-10-31 01:55:06 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			123 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			123 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import re
 | |
| from typing import (
 | |
|     Callable,
 | |
|     Dict,
 | |
|     Generator,
 | |
|     Generic,
 | |
|     Literal,
 | |
|     Optional,
 | |
|     Tuple,
 | |
|     Type,
 | |
|     TypeVar,
 | |
|     overload,
 | |
| )
 | |
| 
 | |
T = TypeVar("T")  # type of the objects stored in a registry
R = TypeVar("R")  # type passed through unchanged by the decorator from ``add``


class BaseRegistry(Generic[T]):
    """A class-level registry mapping names to objects.

    The registry is used through its classmethods and is meant to be
    subclassed, one subclass per kind of registered object.  Each class owns
    its own storage; entries are never shared with parent or sibling
    registries.

    Registered names double as regular expressions: ``get`` first looks for an
    exact name and otherwise falls back to ``re.match(registered, name)``.
    """

    # Shared by every subclass (mutated in place, never rebound): maps a
    # registry class name to the registry class itself.
    _registry_of_registries: Dict[str, Type["BaseRegistry"]] = {}
    # Annotation only; the dict is created lazily, per class, by
    # ``_get_storage``.  Maps name -> (registered object, description).
    _registry_storage: Dict[str, Tuple[T, Optional[str]]]

    @classmethod
    def _add_to_registry_of_registries(cls) -> None:
        """Record this registry class under its ``__name__`` (first one wins)."""
        cls._registry_of_registries.setdefault(cls.__name__, cls)

    @classmethod
    def registries(cls) -> Generator[Tuple[str, Type["BaseRegistry"]], None, None]:
        """Yield ``(class name, registry class)`` pairs, sorted by name."""
        yield from sorted(cls._registry_of_registries.items())

    @classmethod
    def _get_storage(cls) -> Dict[str, Tuple[T, Optional[str]]]:
        """Return this class's own storage dict, creating it on first use."""
        # Look only at the class's own namespace: ``hasattr`` would also find a
        # dict created on a parent class and make both registries share it.
        if "_registry_storage" not in vars(cls):
            cls._registry_storage = {}
        return cls._registry_storage  # pyright: ignore

    @staticmethod
    def _pattern_matches(pattern: str, name: str) -> bool:
        """Return True if ``pattern``, used as a regex, matches at the start of ``name``."""
        try:
            return re.match(pattern, name) is not None
        except re.error:
            # A registered name that is not a valid regex can only ever be
            # found by its exact spelling; it must not break other lookups.
            return False

    @classmethod
    def items(cls) -> Generator[Tuple[str, T], None, None]:
        """Yield ``(name, object)`` pairs, sorted by name."""
        yield from sorted((n, t) for (n, (t, _)) in cls._get_storage().items())

    @classmethod
    def items_with_description(cls) -> Generator[Tuple[str, T, Optional[str]], None, None]:
        """Yield ``(name, object, description)`` triples, sorted by name."""
        yield from sorted((n, t, d) for (n, (t, d)) in cls._get_storage().items())

    @classmethod
    def add(cls, name: str, desc: Optional[str] = None) -> Callable[[R], R]:
        """Return a decorator that registers its argument under ``name``.

        Args:
            name: Key to register the object under.
            desc: Optional human-readable description stored with the object.

        Returns:
            A callable that stores the object it receives and returns it
            unchanged, so it can be used as a decorator.

        Raises:
            ValueError: Raised by the returned callable when ``name`` is
                already registered to a different object, unless the new
                object's ``__module__`` is ``"__main__"``.
        """

        # Add the registry to the registry of registries
        cls._add_to_registry_of_registries()

        def _add(
            inner_self: T,
            inner_name: str = name,
            inner_desc: Optional[str] = desc,
            inner_cls: Type[BaseRegistry] = cls,
        ) -> T:
            """Store ``inner_self`` in ``inner_cls`` under ``inner_name``."""
            storage = inner_cls._get_storage()

            # Compare against the exact key only.  Going through the
            # regex-based ``get`` would reject a new name merely because an
            # existing name is a prefix of it.
            if inner_name in storage:
                existing, _ = storage[inner_name]
                if existing != inner_self:
                    # An object whose module is "__main__" is returned
                    # unregistered instead of raising; the existing entry is
                    # kept.  Not every object has ``__module__``.
                    if getattr(inner_self, "__module__", None) == "__main__":
                        return inner_self

                    raise ValueError(f"Tagger {inner_name} already exists")

            storage[inner_name] = (inner_self, inner_desc)
            return inner_self

        return _add  # type: ignore

    @classmethod
    def remove(cls, name: str) -> bool:
        """Remove the entry stored under exactly ``name``.

        Returns:
            True if an entry was removed, False if ``name`` was not registered.
        """
        # Stored values are always tuples, so ``None`` reliably means "absent".
        return cls._get_storage().pop(name, None) is not None

    @classmethod
    def has(cls, name: str) -> bool:
        """Return True if an entry is stored under exactly ``name``."""
        return name in cls._get_storage()

    @overload
    @classmethod
    def get(cls, name: str) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[True]) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[False]) -> Optional[T]: ...

    @classmethod
    def get(cls, name: str, raise_on_missing: bool = True) -> Optional[T]:
        """Look up the object registered for ``name``.

        An exact key wins.  Otherwise every registered key is tried as a regex
        with ``re.match`` (anchored at the start of ``name`` only).

        Args:
            name: Name to look up.
            raise_on_missing: When False, return ``None`` instead of raising
                if nothing matches.

        Returns:
            The registered object, or ``None`` when nothing matches and
            ``raise_on_missing`` is False.

        Raises:
            ValueError: If more than one registered pattern matches ``name``,
                or if nothing matches and ``raise_on_missing`` is True.
        """
        storage = cls._get_storage()

        # Exact names take precedence so that ``has(name)`` implies
        # ``get(name)`` succeeds, even when another key is a prefix of ``name``
        # or ``name`` contains regex metacharacters.
        if name in storage:
            return storage[name][0]

        matches = [registered for registered in storage if cls._pattern_matches(registered, name)]

        if len(matches) > 1:
            raise ValueError(f"Multiple taggers match {name}: {', '.join(matches)}")

        if not matches:
            if raise_on_missing:
                tagger_names = ", ".join([tn for tn, _ in cls.items()])
                raise ValueError(f"Unknown tagger {name}; available taggers: {tagger_names}")
            return None

        return storage[matches[0]][0]
 | 
