| 
									
										
										
										
											2025-02-26 17:49:04 +00:00
										 |  |  | import re | 
					
						
							|  |  |  | from typing import ( | 
					
						
							|  |  |  |     Callable, | 
					
						
							|  |  |  |     Dict, | 
					
						
							|  |  |  |     Generator, | 
					
						
							|  |  |  |     Generic, | 
					
						
							|  |  |  |     Literal, | 
					
						
							|  |  |  |     Optional, | 
					
						
							|  |  |  |     Tuple, | 
					
						
							|  |  |  |     Type, | 
					
						
							|  |  |  |     TypeVar, | 
					
						
							|  |  |  |     overload, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | T = TypeVar("T") | 
					
						
							|  |  |  | R = TypeVar("R") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BaseRegistry(Generic[T]): | 
					
						
							|  |  |  |     """A registry for objects.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     _registry_of_registries: Dict[str, Type["BaseRegistry"]] = {} | 
					
						
							|  |  |  |     _registry_storage: Dict[str, Tuple[T, Optional[str]]] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _add_to_registry_of_registries(cls) -> None: | 
					
						
							|  |  |  |         name = cls.__name__ | 
					
						
							|  |  |  |         if name not in cls._registry_of_registries: | 
					
						
							|  |  |  |             cls._registry_of_registries[name] = cls | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def registries(cls) -> Generator[Tuple[str, Type["BaseRegistry"]], None, None]: | 
					
						
							|  |  |  |         """Yield all registries in the registry of registries.""" | 
					
						
							|  |  |  |         yield from sorted(cls._registry_of_registries.items()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def _get_storage(cls) -> Dict[str, Tuple[T, Optional[str]]]: | 
					
						
							|  |  |  |         if not hasattr(cls, "_registry_storage"): | 
					
						
							|  |  |  |             cls._registry_storage = {} | 
					
						
							|  |  |  |         return cls._registry_storage  # pyright: ignore | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def items(cls) -> Generator[Tuple[str, T], None, None]: | 
					
						
							|  |  |  |         """Yield all items in the registry.""" | 
					
						
							|  |  |  |         yield from sorted((n, t) for (n, (t, _)) in cls._get_storage().items()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def items_with_description(cls) -> Generator[Tuple[str, T, Optional[str]], None, None]: | 
					
						
							|  |  |  |         """Yield all items in the registry with their descriptions.""" | 
					
						
							|  |  |  |         yield from sorted((n, t, d) for (n, (t, d)) in cls._get_storage().items()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def add(cls, name: str, desc: Optional[str] = None) -> Callable[[R], R]: | 
					
						
							|  |  |  |         """Add a class to the registry.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Add the registry to the registry of registries | 
					
						
							|  |  |  |         cls._add_to_registry_of_registries() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def _add( | 
					
						
							|  |  |  |             inner_self: T, | 
					
						
							|  |  |  |             inner_name: str = name, | 
					
						
							|  |  |  |             inner_desc: Optional[str] = desc, | 
					
						
							|  |  |  |             inner_cls: Type[BaseRegistry] = cls, | 
					
						
							|  |  |  |         ) -> T: | 
					
						
							|  |  |  |             """Add a tagger to the registry using tagger_name as the name.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             existing = inner_cls.get(inner_name, raise_on_missing=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if existing and existing != inner_self: | 
					
						
							|  |  |  |                 if inner_self.__module__ == "__main__": | 
					
						
							|  |  |  |                     return inner_self | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 raise ValueError(f"Tagger {inner_name} already exists") | 
					
						
							|  |  |  |             inner_cls._get_storage()[inner_name] = (inner_self, inner_desc) | 
					
						
							|  |  |  |             return inner_self | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return _add  # type: ignore | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def remove(cls, name: str) -> bool: | 
					
						
							|  |  |  |         """Remove a tagger from the registry.""" | 
					
						
							|  |  |  |         if name in cls._get_storage(): | 
					
						
							|  |  |  |             cls._get_storage().pop(name) | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  |         return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def has(cls, name: str) -> bool: | 
					
						
							|  |  |  |         """Check if a tagger exists in the registry.""" | 
					
						
							|  |  |  |         return name in cls._get_storage() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @overload | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def get(cls, name: str) -> T: ... | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @overload | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def get(cls, name: str, raise_on_missing: Literal[True]) -> T: ... | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @overload | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def get(cls, name: str, raise_on_missing: Literal[False]) -> Optional[T]: ... | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def get(cls, name: str, raise_on_missing: bool = True) -> Optional[T]: | 
					
						
							|  |  |  |         """Get a tagger from the registry; raise ValueError if it doesn't exist.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         matches = [registered for registered in cls._get_storage() if re.match(registered, name)] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if len(matches) > 1: | 
					
						
							|  |  |  |             raise ValueError(f"Multiple taggers match {name}: {', '.join(matches)}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         elif len(matches) == 0: | 
					
						
							|  |  |  |             if raise_on_missing: | 
					
						
							|  |  |  |                 tagger_names = ", ".join([tn for tn, _ in cls.items()]) | 
					
						
							|  |  |  |                 raise ValueError(f"Unknown tagger {name}; available taggers: {tagger_names}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             name = matches[0] | 
					
						
							|  |  |  |             t, _ = cls._get_storage()[name] | 
					
						
							| 
									
										
										
										
											2025-03-03 13:42:13 -08:00
										 |  |  |             return t |