diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index 93fe20ece..cbb96ab46 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass -from enum import Enum, EnumMeta +from enum import Enum from typing import Any, Dict, List, Optional, Union from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict @@ -21,28 +21,7 @@ with LazyImport(message="Run 'pip install spacy'") as spacy_import: from spacy import Language as SpacyPipeline -class _BackendEnumMeta(EnumMeta): - """ - Metaclass for fine-grained error handling of backend enums. - """ - - def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1): # noqa: A002 - if names is None: - try: - return EnumMeta.__call__(cls, value, names, module=module, qualname=qualname, type=type, start=start) - except ValueError: - supported_backends = ", ".join(sorted(v.value for v in cls)) # pylint: disable=not-an-iterable - raise ComponentError( - f"Invalid backend `{value}` for named entity extractor. " - f"Supported backends: {supported_backends}" - ) - else: - return EnumMeta.__call__( # pylint: disable=too-many-function-args - cls, value, names, module, qualname, type, start - ) - - -class NamedEntityExtractorBackend(Enum, metaclass=_BackendEnumMeta): +class NamedEntityExtractorBackend(Enum): """ NLP backend to use for Named Entity Recognition. """ @@ -53,6 +32,24 @@ class NamedEntityExtractorBackend(Enum, metaclass=_BackendEnumMeta): #: Uses a spaCy model and pipeline. SPACY = "spacy" + def __str__(self): + return self.value + + @staticmethod + def from_str(string: str) -> "NamedEntityExtractorBackend": + """ + Convert a string to a NamedEntityExtractorBackend enum. + """ + enum_map = {e.value: e for e in NamedEntityExtractorBackend} + mode = enum_map.get(string) + if mode is None: + msg = ( + f"Invalid backend '{string}' for named entity extractor. " + f"Supported backends are: {list(enum_map.keys())}" + ) + raise ComponentError(msg) + return mode + @dataclass class NamedEntityAnnotation: @@ -134,7 +131,7 @@ class NamedEntityExtractor: """ if isinstance(backend, str): - backend = NamedEntityExtractorBackend(backend) + backend = NamedEntityExtractorBackend.from_str(backend) self._backend: _NerBackend device = ComponentDevice.resolve_device(device) diff --git a/releasenotes/notes/fix-named-entity-entity-extractor-backend-enum-2e64e25f8d7f1b08.yaml b/releasenotes/notes/fix-named-entity-entity-extractor-backend-enum-2e64e25f8d7f1b08.yaml new file mode 100644 index 000000000..0df95101c --- /dev/null +++ b/releasenotes/notes/fix-named-entity-entity-extractor-backend-enum-2e64e25f8d7f1b08.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fix `NamedEntityExtractor` crashing in Python 3.12 if constructed using a string backend argument.