Fix NamedEntityExtractor crashing in Python 3.12 if constructed using a string backend argument. (#7743)

This commit is contained in:
Silvano Cerza 2024-05-24 16:41:29 +02:00 committed by GitHub
parent 98fd270428
commit f5becf2ac0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 24 deletions

View File

@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, EnumMeta
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
@ -21,28 +21,7 @@ with LazyImport(message="Run 'pip install spacy'") as spacy_import:
from spacy import Language as SpacyPipeline
class _BackendEnumMeta(EnumMeta):
"""
Metaclass for fine-grained error handling of backend enums.
"""
def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1): # noqa: A002
if names is None:
try:
return EnumMeta.__call__(cls, value, names, module=module, qualname=qualname, type=type, start=start)
except ValueError:
supported_backends = ", ".join(sorted(v.value for v in cls)) # pylint: disable=not-an-iterable
raise ComponentError(
f"Invalid backend `{value}` for named entity extractor. "
f"Supported backends: {supported_backends}"
)
else:
return EnumMeta.__call__( # pylint: disable=too-many-function-args
cls, value, names, module, qualname, type, start
)
class NamedEntityExtractorBackend(Enum, metaclass=_BackendEnumMeta):
class NamedEntityExtractorBackend(Enum):
"""
NLP backend to use for Named Entity Recognition.
"""
@ -53,6 +32,24 @@ class NamedEntityExtractorBackend(Enum, metaclass=_BackendEnumMeta):
#: Uses a spaCy model and pipeline.
SPACY = "spacy"
def __str__(self):
return self.value
@staticmethod
def from_str(string: str) -> "NamedEntityExtractorBackend":
"""
Convert a string to a NamedEntityExtractorBackend enum.
"""
enum_map = {e.value: e for e in NamedEntityExtractorBackend}
mode = enum_map.get(string)
if mode is None:
msg = (
f"Invalid backend '{string}' for named entity extractor. "
f"Supported backends are: {list(enum_map.keys())}"
)
raise ComponentError(msg)
return mode
@dataclass
class NamedEntityAnnotation:
@ -134,7 +131,7 @@ class NamedEntityExtractor:
"""
if isinstance(backend, str):
backend = NamedEntityExtractorBackend(backend)
backend = NamedEntityExtractorBackend.from_str(backend)
self._backend: _NerBackend
device = ComponentDevice.resolve_device(device)

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Fix `NamedEntityExtractor` crashing in Python 3.12 if constructed using a string backend argument.