feat: introduce generalimport (#4662)

* introduce generalimport
* pylint
* fix optional deps typing for schema
* leftover
* typo
* typing with faiss
* make Base generation optional too
* handle sqlalchemy
* (almost) all import are optional
* TO REMOVE hijacking CI for tests
* some deps are actually needed
* get feature branch in CI
* get feature branch in CI
* fix array_equal
* pylint
* pandas also required
* improve imports.yml
* fix SquadData
* fix SquadData again
* generalimport imports list
* Update haystack/utils/openai_utils.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update haystack/utils/openai_utils.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* review feedback
* remove todos
* reference main release
* pylint
* circular import
* review feedback
* move is_imported in init
* pylint

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
parent 5b2ef2afd6
commit 28260c5c3f
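For context, the mechanism this commit adopts works roughly like the sketch below. It is not code from the diff: the module name is made up, and it only illustrates generalimport's documented behaviour of replacing a failed import with a lazy fake that errors on first real use.

    from generalimport import generalimport, MissingOptionalDependency

    generalimport("some_optional_extra")  # hypothetical name; Haystack registers its real optional deps in __init__.py below

    import some_optional_extra  # does not fail, even if the package is not installed

    def use_extra():
        # only here, on first actual use, does a missing dependency surface
        return some_optional_extra.do_something()

    try:
        use_extra()
    except MissingOptionalDependency:
        print("Install the optional dependency to use this feature")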
.github/workflows/imports.yml
@@ -37,5 +37,5 @@ jobs:
       - name: Install Haystack with no extras
         run: pip install .
 
-      - name: Try to import
+      - name: Import Haystack
         run: python -c 'import haystack'
@@ -1,21 +1,94 @@
-# pylint: disable=wrong-import-position,wrong-import-order
-
-from typing import Union
-from types import ModuleType
-
-try:
-    from importlib import metadata
-except (ModuleNotFoundError, ImportError):
-    # Python <= 3.7
-    import importlib_metadata as metadata  # type: ignore
-
-__version__: str = str(metadata.version("farm-haystack"))
-
-
-# Logging is not configured here on purpose, see https://github.com/deepset-ai/haystack/issues/2485
-import logging
-
-import pandas as pd
+# pylint: disable=wrong-import-position
+# Logging is not configured here on purpose, see https://github.com/deepset-ai/haystack/issues/2485
+
+import sys
+from importlib import metadata
+
+__version__: str = str(metadata.version("farm-haystack"))
+
+
+from generalimport import generalimport, MissingOptionalDependency, FakeModule
+
+generalimport(
+    # "pydantic", # Required for all dataclasses
+    # "tenacity", # Probably needed because it's a decorator, to be evaluated
+    # "pandas",
+    "aiorwlock",
+    "azure",
+    "beautifulsoup4",
+    "beir",
+    "boilerpy3",
+    "canals",
+    "dill",
+    "docx",
+    "elasticsearch",
+    "events",
+    "faiss",
+    "fitz",
+    "frontmatter",
+    "huggingface_hub",
+    "jsonschema",
+    "langdetect",
+    "magic",
+    "markdown",
+    "mlflow",
+    "mmh3",
+    "more_itertools",
+    "networkx",
+    "nltk",
+    "numpy",
+    "onnxruntime",
+    "onnxruntime_tools",
+    "opensearchpy",
+    "pdf2image",
+    "PIL",
+    "pinecone",
+    "posthog",
+    "protobuf",
+    "psycopg2",
+    "pymilvus",
+    "pytesseract",
+    "quantulum3",
+    "rank_bm25",
+    "rapidfuzz",
+    "ray",
+    "rdflib",
+    "requests",
+    "scipy",
+    "selenium",
+    "sentence_transformers",
+    "seqeval",
+    "sklearn",
+    "SPARQLWrapper",
+    "sqlalchemy",
+    "sseclient",
+    "tenacity",
+    "tika",
+    "tiktoken",
+    "tokenizers",
+    "torch",
+    "tqdm",
+    "transformers",
+    "weaviate",
+    "webdriver_manager",
+    "whisper",
+    "yaml",
+)
+
+
+# TODO: remove this function once this PR is merged and released by generalimport:
+# https://github.com/ManderaGeneral/generalimport/pull/25
+def is_imported(module_name: str) -> bool:
+    """
+    Returns True if the module was actually imported, False, if generalimport mocked it.
+    """
+    module = sys.modules.get(module_name)
+    try:
+        return bool(module) and not isinstance(module, FakeModule)
+    except MissingOptionalDependency:
+        # isinstance() raises MissingOptionalDependency: fake module
+        pass
+    return False
 
 from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult, TableCell
 from haystack.nodes.base import BaseComponent
@@ -23,5 +96,6 @@ from haystack.pipelines.base import Pipeline
 from haystack.environment import set_pytorch_secure_model_loading
 
 
-pd.options.display.max_colwidth = 80
+# Enables torch's secure model loading through setting an env var.
+# Does not use torch.
 set_pytorch_secure_model_loading()
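The is_imported helper defined in the hunk above is what the rest of this diff uses to keep module-level code importable when a dependency is only mocked by generalimport. Restated as a minimal, self-contained sketch of the guard pattern (the pandas and sqlalchemy hunks below are the real instances):

    from haystack import is_imported
    from pandas import DataFrame  # with generalimport active, this does not fail even when pandas is absent

    if not is_imported("pandas"):
        # pandas was only mocked, so swap in a harmless placeholder for type annotations
        DataFrame = object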
@@ -46,7 +46,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         vector_dim: Optional[int] = None,
         embedding_dim: int = 768,
         faiss_index_factory_str: str = "Flat",
-        faiss_index: Optional[faiss.swigfaiss.Index] = None,
+        faiss_index: Optional["faiss.swigfaiss.Index"] = None,
         return_embedding: bool = False,
         index: str = "document",
         similarity: str = "dot_product",
@@ -58,7 +58,7 @@ class PineconeDocumentStore(BaseDocumentStore):
         self,
         api_key: str,
         environment: str = "us-west1-gcp",
-        pinecone_index: Optional[pinecone.Index] = None,
+        pinecone_index: Optional["pinecone.Index"] = None,
         embedding_dim: int = 768,
         return_embedding: bool = False,
         index: str = "document",
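A note on why the two document-store hunks above only add quotes: annotations in a def signature are evaluated when the module is imported, so an unquoted faiss.swigfaiss.Index or pinecone.Index would poke into the mocked module and raise the missing-dependency error even if the store is never used. Quoting the annotation turns it into a forward reference that is left unevaluated. A tiny illustration with a hypothetical module name:

    from typing import Optional

    class SomeStore:
        # the string annotation is never resolved at import time,
        # so "heavy_extra" can stay uninstalled until the store is actually used
        def __init__(self, index: Optional["heavy_extra.Index"] = None):
            self.index = index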
@@ -27,103 +27,115 @@ try:
     from sqlalchemy.ext.declarative import declarative_base
     from sqlalchemy.orm import relationship, sessionmaker, aliased
     from sqlalchemy.sql import case, null
 
 except (ImportError, ModuleNotFoundError) as ie:
     from haystack.utils.import_utils import _optional_component_not_installed
 
     _optional_component_not_installed(__name__, "sql", ie)
 
 
+from haystack import is_imported
 from haystack.schema import Document, Label, Answer
 from haystack.document_stores.base import BaseDocumentStore, FilterType
 from haystack.document_stores.filter_utils import LogicalFilterClause
 
 
+if not is_imported("sqlalchemy"):
+    Base = object
+    ArrayType = object
+    ORMBase = object
+    DocumentORM = object
+    MetaDocumentORM = object
+    LabelORM = object
+    MetaLabelORM = object
+
+else:
+    Base = declarative_base()  # type: Any
+
+    class ArrayType(TypeDecorator):
+        impl = String
+        cache_ok = True
+
+        def process_bind_param(self, value, dialect):
+            return json.dumps(value)
+
+        def process_result_value(self, value, dialect):
+            if value is not None:
+                return json.loads(value)
+            return value
+
+    class ORMBase(Base):
+        __abstract__ = True
+
+        id = Column(String(100), default=lambda: str(uuid4()), primary_key=True)
+        created_at = Column(DateTime, server_default=func.now())
+        updated_at = Column(DateTime, server_default=func.now(), server_onupdate=func.now())
+
+    class DocumentORM(ORMBase):
+        __tablename__ = "document"
+
+        content = Column(JSON, nullable=False)
+        content_type = Column(Text, nullable=True)
+        # primary key in combination with id to allow the same doc in different indices
+        index = Column(String(100), nullable=False, primary_key=True)
+        vector_id = Column(String(100), nullable=True)
+        # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
+        meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
+
+        __table_args__ = (UniqueConstraint("index", "vector_id", name="index_vector_id_uc"),)
+
+    class MetaDocumentORM(ORMBase):
+        __tablename__ = "meta_document"
+
+        name = Column(String(100), index=True)
+        value = Column(ArrayType(1000), index=True)
+        documents = relationship("DocumentORM", back_populates="meta")
+
+        document_id = Column(String(100), nullable=False, index=True)
+        document_index = Column(String(100), nullable=False, index=True)
+        __table_args__ = (  # type: ignore
+            ForeignKeyConstraint(
+                [document_id, document_index],
+                [DocumentORM.id, DocumentORM.index],
+                ondelete="CASCADE",
+                onupdate="CASCADE",
+            ),
+            {},
+        )
+
+    class LabelORM(ORMBase):
+        __tablename__ = "label"
+
+        index = Column(String(100), nullable=False, primary_key=True)
+        query = Column(Text, nullable=False)
+        answer = Column(JSON, nullable=True)
+        document = Column(JSON, nullable=False)
+        no_answer = Column(Boolean, nullable=False)
+        origin = Column(String(100), nullable=False)
+        is_correct_answer = Column(Boolean, nullable=False)
+        is_correct_document = Column(Boolean, nullable=False)
+        pipeline_id = Column(String(500), nullable=True)
+
+        meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined")
+
+    class MetaLabelORM(ORMBase):
+        __tablename__ = "meta_label"
+
+        name = Column(String(100), index=True)
+        value = Column(String(1000), index=True)
+        labels = relationship("LabelORM", back_populates="meta")
+
+        label_id = Column(String(100), nullable=False, index=True)
+        label_index = Column(String(100), nullable=False, index=True)
+        __table_args__ = (  # type: ignore
+            ForeignKeyConstraint(
+                [label_id, label_index], [LabelORM.id, LabelORM.index], ondelete="CASCADE", onupdate="CASCADE"
+            ),
+            {},
+        )
+
+
 logger = logging.getLogger(__name__)
-Base = declarative_base()  # type: Any
-
-
-class ArrayType(TypeDecorator):
-    impl = String
-    cache_ok = True
-
-    def process_bind_param(self, value, dialect):
-        return json.dumps(value)
-
-    def process_result_value(self, value, dialect):
-        if value is not None:
-            return json.loads(value)
-        return value
-
-
-class ORMBase(Base):
-    __abstract__ = True
-
-    id = Column(String(100), default=lambda: str(uuid4()), primary_key=True)
-    created_at = Column(DateTime, server_default=func.now())
-    updated_at = Column(DateTime, server_default=func.now(), server_onupdate=func.now())
-
-
-class DocumentORM(ORMBase):
-    __tablename__ = "document"
-
-    content = Column(JSON, nullable=False)
-    content_type = Column(Text, nullable=True)
-    # primary key in combination with id to allow the same doc in different indices
-    index = Column(String(100), nullable=False, primary_key=True)
-    vector_id = Column(String(100), nullable=True)
-    # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
-    meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
-
-    __table_args__ = (UniqueConstraint("index", "vector_id", name="index_vector_id_uc"),)
-
-
-class MetaDocumentORM(ORMBase):
-    __tablename__ = "meta_document"
-
-    name = Column(String(100), index=True)
-    value = Column(ArrayType(1000), index=True)
-    documents = relationship("DocumentORM", back_populates="meta")
-
-    document_id = Column(String(100), nullable=False, index=True)
-    document_index = Column(String(100), nullable=False, index=True)
-    __table_args__ = (  # type: ignore
-        ForeignKeyConstraint(
-            [document_id, document_index], [DocumentORM.id, DocumentORM.index], ondelete="CASCADE", onupdate="CASCADE"
-        ),
-        {},
-    )
-
-
-class LabelORM(ORMBase):
-    __tablename__ = "label"
-
-    index = Column(String(100), nullable=False, primary_key=True)
-    query = Column(Text, nullable=False)
-    answer = Column(JSON, nullable=True)
-    document = Column(JSON, nullable=False)
-    no_answer = Column(Boolean, nullable=False)
-    origin = Column(String(100), nullable=False)
-    is_correct_answer = Column(Boolean, nullable=False)
-    is_correct_document = Column(Boolean, nullable=False)
-    pipeline_id = Column(String(500), nullable=True)
-
-    meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined")
-
-
-class MetaLabelORM(ORMBase):
-    __tablename__ = "meta_label"
-
-    name = Column(String(100), index=True)
-    value = Column(String(1000), index=True)
-    labels = relationship("LabelORM", back_populates="meta")
-
-    label_id = Column(String(100), nullable=False, index=True)
-    label_index = Column(String(100), nullable=False, index=True)
-    __table_args__ = (  # type: ignore
-        ForeignKeyConstraint(
-            [label_id, label_index], [LabelORM.id, LabelORM.index], ondelete="CASCADE", onupdate="CASCADE"
-        ),
-        {},
-    )
 
 
 class SQLDocumentStore(BaseDocumentStore):
@@ -41,7 +41,7 @@ class BiAdaptiveModel(nn.Module):
         language_model2: LanguageModel,
         prediction_heads: List[PredictionHead],
         embeds_dropout_prob: float = 0.1,
-        device: torch.device = torch.device("cuda"),
+        device: Optional[torch.device] = None,
         lm1_output_types: Optional[Union[str, List[str]]] = None,
         lm2_output_types: Optional[Union[str, List[str]]] = None,
         loss_aggregation_fn: Optional[Callable] = None,
@@ -74,6 +74,9 @@ class BiAdaptiveModel(nn.Module):
                               Note: The loss at this stage is per sample, i.e one tensor of
                               shape (batchsize) per prediction head.
         """
+        if not device:
+            device = torch.device("cuda")
+
         if lm1_output_types is None:
             lm1_output_types = ["per_sequence"]
         if lm2_output_types is None:
@@ -29,6 +29,7 @@ from transformers import PreTrainedTokenizer, RobertaTokenizer, AutoConfig, Auto
 from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
 from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
 
+from haystack import is_imported
 from haystack.errors import ModelingError
 from haystack.modeling.data_handler.samples import SampleBasket
 
@@ -40,12 +41,16 @@ logger = logging.getLogger(__name__)
 SPECIAL_TOKENIZER_CHARS = r"^(##|Ġ|▁)"
 
 
+if not is_imported("transformers"):
+    TOKENIZER_MAPPING_NAMES = {}
+    FEATURE_EXTRACTOR_MAPPING_NAMES = {}
+
 FEATURE_EXTRACTORS = {
     **{key: AutoTokenizer for key in TOKENIZER_MAPPING_NAMES.keys()},
     **{key: AutoFeatureExtractor for key in FEATURE_EXTRACTOR_MAPPING_NAMES.keys()},
 }
 
 
 DEFAULT_EXTRACTION_PARAMS = {
     AutoTokenizer: {
         "max_length": 256,
@@ -43,7 +43,7 @@ class TriAdaptiveModel(nn.Module):
         language_model3: LanguageModel,
         prediction_heads: List[PredictionHead],
         embeds_dropout_prob: float = 0.1,
-        device: torch.device = torch.device("cuda"),
+        device: Optional[torch.device] = None,
         lm1_output_types: Optional[Union[str, List[str]]] = None,
         lm2_output_types: Optional[Union[str, List[str]]] = None,
         lm3_output_types: Optional[Union[str, List[str]]] = None,
@@ -83,6 +83,9 @@ class TriAdaptiveModel(nn.Module):
                               Note: The loss at this stage is per sample, i.e one tensor of
                               shape (batchsize) per prediction head.
         """
+        if not device:
+            device = torch.device("cuda")
+
         if lm1_output_types is None:
             lm1_output_types = ["per_sequence"]
         if lm2_output_types is None:
@@ -55,6 +55,10 @@ def field_singleton_schema(
     known_models: TypeModelSet,
 ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
     try:
+        # Typing with optional dependencies is really tricky. Let's just use Any for now. To be fixed.
+        if isinstance(field.type_, ForwardRef):
+            logger.debug(field.type_)
+            field.type_ = Any
         return _field_singleton_schema(
             field,
             by_alias=by_alias,
@@ -211,7 +215,10 @@ def create_schema_for_node_class(node_class: Type[BaseComponent]) -> Tuple[Dict[
 
     # Create the model with Pydantic and extract the schema
     model = create_model(f"{node_name}ComponentParams", __config__=Config, **param_fields_kwargs)
-    model.update_forward_refs(**model.__dict__)
+    try:
+        model.update_forward_refs(**model.__dict__)
+    except NameError as exc:
+        logger.debug("%s", str(exc))
     params_schema = model.schema()
 
     # Pydantic v1 patch to generate JSON schemas including Optional fields
@@ -20,7 +20,9 @@ from dataclasses import asdict
 
 import mmh3
 import numpy as np
+from numpy import ndarray
 import pandas as pd
+from pandas import DataFrame
 
 from pydantic import BaseConfig, Field
 from pydantic.json import pydantic_encoder
@@ -29,12 +31,19 @@ from pydantic.json import pydantic_encoder
 # See #1598 for the reasons behind this choice & performance considerations
 from pydantic.dataclasses import dataclass
 
+from haystack import is_imported
+
+
 logger = logging.getLogger(__name__)
 
 
+if not is_imported("pandas"):
+    DataFrame = object
+
+
 BaseConfig.arbitrary_types_allowed = True
 
 
 #: Types of content_types supported
 ContentTypes = Literal["text", "table", "image", "audio"]
 FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]
@@ -43,12 +52,12 @@ FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]
 @dataclass
 class Document:
     id: str
-    content: Union[str, pd.DataFrame]
+    content: Union[str, DataFrame]
     content_type: ContentTypes = Field(default="text")
     meta: Dict[str, Any] = Field(default={})
     id_hash_keys: List[str] = Field(default=["content"])
     score: Optional[float] = None
-    embedding: Optional[np.ndarray] = None
+    embedding: Optional[ndarray] = None
 
     # We use a custom init here as we want some custom logic. The annotations above are however still needed in order
     # to use some dataclass magic like "asdict()". See https://www.python.org/dev/peps/pep-0557/#custom-init-method
@@ -56,12 +65,12 @@ class Document:
     # don't need to passed by the user in init and are rather initialized automatically in the init
     def __init__(
         self,
-        content: Union[str, pd.DataFrame],
+        content: Union[str, DataFrame],
         content_type: ContentTypes = "text",
         id: Optional[str] = None,
         score: Optional[float] = None,
         meta: Optional[Dict[str, Any]] = None,
-        embedding: Optional[np.ndarray] = None,
+        embedding: Optional[ndarray] = None,
         id_hash_keys: Optional[List[str]] = None,
     ):
         """
@@ -184,7 +193,7 @@ class Document:
                 continue
             if k == "content":
                 # Convert pd.DataFrame to list of rows for serialization
-                if self.content_type == "table" and isinstance(self.content, pd.DataFrame):
+                if self.content_type == "table" and isinstance(self.content, DataFrame):
                     v = dataframe_to_list(self.content)
             k = k if k not in inv_field_map else inv_field_map[k]
             _doc[k] = v
@@ -230,7 +239,7 @@ class Document:
                 k = field_map[k]
             _new_doc[k] = v
 
-        # Convert list of rows to pd.DataFrame
+        # Convert list of rows to DataFrame
         if _new_doc.get("content_type", None) == "table" and isinstance(_new_doc["content"], list):
             _new_doc["content"] = dataframe_from_list(_new_doc["content"])
 
@@ -358,7 +367,7 @@ class Answer:
     answer: str
     type: Literal["generative", "extractive", "other"] = "extractive"
     score: Optional[float] = None
-    context: Optional[Union[str, pd.DataFrame]] = None
+    context: Optional[Union[str, DataFrame]] = None
     offsets_in_document: Optional[Union[List[Span], List[TableCell]]] = None
     offsets_in_context: Optional[Union[List[Span], List[TableCell]]] = None
     document_ids: Optional[List[str]] = None
@@ -832,7 +841,7 @@ def dataframe_from_list(list_df: List[List]) -> pd.DataFrame:
 
 
 class EvaluationResult:
-    def __init__(self, node_results: Optional[Dict[str, pd.DataFrame]] = None) -> None:
+    def __init__(self, node_results: Optional[Dict[str, DataFrame]] = None) -> None:
         """
         A convenience class to store, pass, and interact with results of a pipeline evaluation run (for example `pipeline.eval()`).
         Detailed results are stored as one dataframe per node. This class makes them more accessible and provides
@@ -902,7 +911,7 @@ class EvaluationResult:
 
         :param node_results: The evaluation Dataframes per pipeline node.
         """
-        self.node_results: Dict[str, pd.DataFrame] = {} if node_results is None else node_results
+        self.node_results: Dict[str, DataFrame] = {} if node_results is None else node_results
 
     def __getitem__(self, key: str):
         return self.node_results.__getitem__(key)
@@ -910,7 +919,7 @@ class EvaluationResult:
     def __delitem__(self, key: str):
         self.node_results.__delitem__(key)
 
-    def __setitem__(self, key: str, value: pd.DataFrame):
+    def __setitem__(self, key: str, value: DataFrame):
         self.node_results.__setitem__(key, value)
 
     def __contains__(self, key: str):
@@ -919,7 +928,7 @@ class EvaluationResult:
     def __len__(self):
         return self.node_results.__len__()
 
-    def append(self, key: str, value: pd.DataFrame):
+    def append(self, key: str, value: DataFrame):
         if value is not None and len(value) > 0:
             if key in self.node_results:
                 self.node_results[key] = pd.concat([self.node_results[key], value])
@@ -1210,7 +1219,7 @@ class EvaluationResult:
 
     def _calculate_node_metrics(
         self,
-        df: pd.DataFrame,
+        df: DataFrame,
         simulated_top_k_reader: int = -1,
         simulated_top_k_retriever: int = -1,
         document_scope: Literal[
@@ -1244,7 +1253,7 @@ class EvaluationResult:
 
         return {**answer_metrics, **document_metrics}
 
-    def _filter_eval_mode(self, df: pd.DataFrame, eval_mode: str) -> pd.DataFrame:
+    def _filter_eval_mode(self, df: DataFrame, eval_mode: str) -> DataFrame:
         if "eval_mode" in df.columns:
             df = df[df["eval_mode"] == eval_mode]
         else:
@@ -1253,7 +1262,7 @@ class EvaluationResult:
 
     def _calculate_answer_metrics(
         self,
-        df: pd.DataFrame,
+        df: DataFrame,
         simulated_top_k_reader: int = -1,
         simulated_top_k_retriever: int = -1,
         answer_scope: Literal["any", "context", "document_id", "document_id_and_context"] = "any",
@@ -1275,11 +1284,11 @@ class EvaluationResult:
 
     def _build_answer_metrics_df(
         self,
-        answers: pd.DataFrame,
+        answers: DataFrame,
         simulated_top_k_reader: int = -1,
         simulated_top_k_retriever: int = -1,
         answer_scope: Literal["any", "context", "document_id", "document_id_and_context"] = "any",
-    ) -> pd.DataFrame:
+    ) -> DataFrame:
         """
         Builds a dataframe containing answer metrics (columns) per multilabel (index).
         Answer metrics are:
@@ -1335,7 +1344,7 @@ class EvaluationResult:
             }
             df_records.append(query_metrics)
 
-        metrics_df = pd.DataFrame.from_records(df_records, index=multilabel_ids)
+        metrics_df = DataFrame.from_records(df_records, index=multilabel_ids)
         return metrics_df
 
     def _get_documents_df(self):
@@ -1350,7 +1359,7 @@ class EvaluationResult:
 
     def _calculate_document_metrics(
         self,
-        df: pd.DataFrame,
+        df: DataFrame,
         simulated_top_k_retriever: int = -1,
         document_relevance_criterion: Literal[
             "document_id",
@@ -1378,7 +1387,7 @@ class EvaluationResult:
 
     def _build_document_metrics_df(
         self,
-        documents: pd.DataFrame,
+        documents: DataFrame,
         simulated_top_k_retriever: int = -1,
         document_relevance_criterion: Literal[
             "document_id",
@@ -1391,7 +1400,7 @@ class EvaluationResult:
             "document_id_and_context_and_answer",
             "document_id_or_answer",
         ] = "document_id_or_answer",
-    ) -> pd.DataFrame:
+    ) -> DataFrame:
         """
         Builds a dataframe containing document metrics (columns) per pair of query and gold document ids (index).
         Document metrics are:
@@ -1539,7 +1548,7 @@ class EvaluationResult:
             }
         )
 
-        metrics_df = pd.DataFrame.from_records(metrics, index=documents["multilabel_id"].unique())
+        metrics_df = DataFrame.from_records(metrics, index=documents["multilabel_id"].unique())
         return metrics_df
 
     def save(self, out_dir: Union[str, Path], **to_csv_kwargs):
@@ -1548,8 +1557,8 @@ class EvaluationResult:
         The result of each node is saved in a separate csv with file name {node_name}.csv to the out_dir folder.
 
         :param out_dir: Path to the target folder the csvs will be saved.
-        :param to_csv_kwargs: kwargs to be passed to pd.DataFrame.to_csv(). See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_csv.html.
-            This method uses different default values than pd.DataFrame.to_csv() for the following parameters:
+        :param to_csv_kwargs: kwargs to be passed to DataFrame.to_csv(). See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_csv.html.
+            This method uses different default values than DataFrame.to_csv() for the following parameters:
             index=False, quoting=csv.QUOTE_NONNUMERIC (to avoid problems with \r chars)
         """
         out_dir = out_dir if isinstance(out_dir, Path) else Path(out_dir)
@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import Optional, Dict, Union, Tuple, List
 
 import requests
+
 from haystack.errors import DatasetsError
 from haystack.schema import Document
 
@@ -6,7 +6,7 @@ import sys
 import json
 from typing import Dict, Union, Tuple, Optional, List
 import requests
-from tenacity import retry, retry_if_exception_type, wait_exponential, stop_after_attempt
+import tenacity
 from transformers import GPT2TokenizerFast
 
 from haystack.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError
@@ -127,10 +127,10 @@ def _openai_text_completion_tokenization_details(model_name: str):
     return tokenizer_name, max_tokens_limit
 
 
-@retry(
-    retry=retry_if_exception_type(OpenAIRateLimitError),
-    wait=wait_exponential(multiplier=OPENAI_BACKOFF),
-    stop=stop_after_attempt(OPENAI_MAX_RETRIES),
+@tenacity.retry(
+    retry=tenacity.retry_if_exception_type(OpenAIRateLimitError),
+    wait=tenacity.wait_exponential(multiplier=OPENAI_BACKOFF),
+    stop=tenacity.stop_after_attempt(OPENAI_MAX_RETRIES),
 )
 def openai_request(
     url: str,
@@ -7,6 +7,7 @@ import pandas as pd
 from tqdm.auto import tqdm
 import mmh3
 
+from haystack import is_imported
 from haystack.schema import Document, Label, Answer
 from haystack.modeling.data_handler.processor import _read_squad_file
 
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
tqdm.pandas()
|
if is_imported("pandas") and is_imported("tqdm"):
|
||||||
|
tqdm.pandas()
|
||||||
|
|
||||||
|
|
||||||
COLUMN_NAMES = ["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"]
|
COLUMN_NAMES = ["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"]
|
||||||
|
@@ -53,6 +53,7 @@ dependencies = [
   "pandas",
   "rank_bm25",
   "scikit-learn>=1.0.0",  # TF-IDF, SklearnQueryClassifier and metrics
+  "generalimport",  # Optional imports
 
   # Utils
   "dill",  # pickle extension for (de-)serialization