2023-02-27 09:55:03 +01:00
|
|
|
import warnings
|
2022-03-31 12:36:45 +02:00
|
|
|
from datetime import timedelta
|
2022-06-10 18:22:48 +02:00
|
|
|
from typing import Any, List, Optional, Dict, Union
|
2022-03-15 11:17:26 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
from uuid import UUID
|
2021-12-22 17:20:23 +01:00
|
|
|
import gc
|
2022-01-14 13:48:58 +01:00
|
|
|
import logging
|
2022-01-26 18:12:55 +01:00
|
|
|
from pathlib import Path
|
2022-03-21 22:24:09 +07:00
|
|
|
import os
|
2022-09-22 17:46:49 +02:00
|
|
|
import re
|
2023-01-16 15:36:14 +01:00
|
|
|
from functools import wraps
|
2022-03-21 22:24:09 +07:00
|
|
|
|
2022-03-31 12:36:45 +02:00
|
|
|
import requests_cache
|
2022-01-25 20:36:28 +01:00
|
|
|
import responses
|
2022-01-14 13:48:58 +01:00
|
|
|
from sqlalchemy import create_engine, text
|
2022-03-21 11:58:51 +01:00
|
|
|
import posthog
|
2020-05-04 18:00:07 +02:00
|
|
|
|
2021-11-12 16:44:28 +01:00
|
|
|
import numpy as np
|
2021-09-22 16:56:51 +02:00
|
|
|
import psutil
|
2020-05-04 18:00:07 +02:00
|
|
|
import pytest
|
2021-06-14 17:53:43 +02:00
|
|
|
|
2023-01-16 15:36:14 +01:00
|
|
|
from haystack import Answer, BaseComponent, __version__ as haystack_version
|
2022-09-23 13:26:49 +02:00
|
|
|
from haystack.document_stores import (
|
|
|
|
BaseDocumentStore,
|
|
|
|
InMemoryDocumentStore,
|
|
|
|
ElasticsearchDocumentStore,
|
|
|
|
WeaviateDocumentStore,
|
|
|
|
MilvusDocumentStore,
|
|
|
|
PineconeDocumentStore,
|
|
|
|
OpenSearchDocumentStore,
|
|
|
|
FAISSDocumentStore,
|
|
|
|
)
|
2022-09-21 14:53:42 +02:00
|
|
|
from haystack.nodes import (
|
|
|
|
BaseReader,
|
|
|
|
BaseRetriever,
|
|
|
|
OpenAIAnswerGenerator,
|
|
|
|
BaseGenerator,
|
|
|
|
BaseSummarizer,
|
|
|
|
BaseTranslator,
|
2022-09-21 19:08:54 +02:00
|
|
|
DenseRetriever,
|
2022-09-23 13:26:49 +02:00
|
|
|
Seq2SeqGenerator,
|
|
|
|
RAGenerator,
|
|
|
|
SentenceTransformersRanker,
|
|
|
|
TransformersDocumentClassifier,
|
|
|
|
FilterRetriever,
|
|
|
|
BM25Retriever,
|
|
|
|
TfidfRetriever,
|
2022-07-05 11:31:11 +02:00
|
|
|
DensePassageRetriever,
|
|
|
|
EmbeddingRetriever,
|
|
|
|
MultihopEmbeddingRetriever,
|
|
|
|
TableTextRetriever,
|
2022-09-23 13:26:49 +02:00
|
|
|
FARMReader,
|
|
|
|
TransformersReader,
|
|
|
|
TableReader,
|
|
|
|
RCIReader,
|
|
|
|
TransformersSummarizer,
|
|
|
|
QuestionGenerator,
|
2023-02-21 14:27:40 +01:00
|
|
|
PromptTemplate,
|
2022-07-05 11:31:11 +02:00
|
|
|
)
|
2022-03-15 11:17:26 +01:00
|
|
|
from haystack.modeling.infer import Inferencer, QAInferencer
|
2022-12-20 11:21:26 +01:00
|
|
|
from haystack.nodes.prompt import PromptNode, PromptModel
|
2023-02-21 14:27:40 +01:00
|
|
|
from haystack.schema import Document, FilterType
|
2022-09-23 13:26:49 +02:00
|
|
|
from haystack.utils.import_utils import _optional_component_not_installed
|
|
|
|
|
|
|
|
try:
|
|
|
|
from elasticsearch import Elasticsearch
|
|
|
|
import weaviate
|
|
|
|
except (ImportError, ModuleNotFoundError) as ie:
|
|
|
|
_optional_component_not_installed("test", "test", ie)
|
|
|
|
|
2022-07-14 19:03:33 +01:00
|
|
|
from .mocks import pinecone as pinecone_mock
|
|
|
|
|
2020-05-04 18:00:07 +02:00
|
|
|
|
2022-01-14 13:48:58 +01:00
|
|
|
# To manually run the tests with default PostgreSQL instead of SQLite, switch the value below
SQL_TYPE = "sqlite"

# Sample data directory located next to this file
SAMPLES_PATH = Path(__file__).parent / "samples"

# Settings for the DC API; they are consumed by `deepset_cloud_fixture`
DC_API_ENDPOINT = "https://DC_API/v1"
DC_TEST_INDEX = "document_retrieval_1"
DC_API_KEY = "NO_KEY"
MOCK_DC = True

# Set metadata fields used during testing for PineconeDocumentStore meta_config
META_FIELDS = [
    "meta_field",
    "name",
    "date_field",
    "numeric_field",
    "f1",
    "f3",
    "meta_id",
    "meta_field_for_count",
    "meta_key_1",
    "meta_key_2",
]

# Disable telemetry reports when running tests
posthog.disabled = True

# Cache requests (e.g. huggingface model) to circumvent load protection.
# Only huggingface.co responses are cached (for one hour); every other URL is not cached.
# See https://requests-cache.readthedocs.io/en/stable/user_guide/filtering.html
requests_cache.install_cache(
    urls_expire_after={
        "huggingface.co": timedelta(hours=1),
        "*": requests_cache.DO_NOT_CACHE,
    }
)
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2023-02-27 09:55:03 +01:00
|
|
|
def fail_at_version(target_major, target_minor):
    """
    Reminder to remove deprecated features.

    If you use this decorator, please open an issue in the repo to keep track of
    the deprecated feature that must be removed, and assign it to the milestone of
    the target version. If the milestone doesn't exist, either create it or notify
    someone who has the permissions to do so. This ensures the feature is actually
    removed for that release.

    The decorated test fails as soon as the current Haystack version is greater than
    or equal to `target_major.target_minor`.
    If the current version is tagged `rc0`, the test only issues a warning instead of
    failing: `rc0` marks the development version in `main`, so tests would otherwise
    fail continuously there.

    ```python
    from ..conftest import fail_at_version

    @fail_at_version(1, 10) # Will fail once Haystack version is greater than or equal to 1.10
    def test_test():
        assert True
    ```

    :param target_major: major version at which the feature must be removed
    :param target_minor: minor version at which the feature must be removed
    :return: a decorator that wraps the test function with the version check
    """

    def decorator(function):
        # The installed version is parsed once, when the test function is decorated
        current_major, current_minor = map(int, haystack_version.split(".")[:2])
        if "rc" in haystack_version:
            current_rc = int(haystack_version.split("rc")[1])
        else:
            current_rc = -1

        @wraps(function)
        def wrapper(*args, **kwargs):
            removal_due = (current_major, current_minor) >= (target_major, target_minor)
            if removal_due:
                message = f"This feature is marked for removal in v{target_major}.{target_minor}"
                if current_rc != 0:
                    pytest.fail(reason=message)
                else:
                    # Development version (rc0): warn only
                    warnings.warn(message)
            return function(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
def pytest_collection_modifyitems(config, items):
    """
    Pytest collection hook: add keyword-based markers and skip tests of disabled document stores.

    For every collected item:
    1. markers are added when the item's node id contains one of the known keywords;
    2. the document store required by the test is inferred (see `infer_required_doc_store`)
       and the test is skipped when that store is not listed in `--document_store_type`.

    Example: pytest -v test_document_store.py --document_store_type="memory" => skip all tests
    marked with "elasticsearch"

    :param config: the pytest config object, used to read the `--document_store_type` option
    :param items: the collected test items; they are modified in place
    """
    # add pytest markers for tests that are not explicitly marked but include some keywords
    name_to_markers = {
        "generator": [pytest.mark.generator],
        "summarizer": [pytest.mark.summarizer],
        "tika": [pytest.mark.tika, pytest.mark.integration],
        "parsr": [pytest.mark.parsr, pytest.mark.integration],
        "ocr": [pytest.mark.ocr, pytest.mark.integration],
        "elasticsearch": [pytest.mark.elasticsearch],
        "faiss": [pytest.mark.faiss],
        "milvus": [pytest.mark.milvus],
        "weaviate": [pytest.mark.weaviate],
        "pinecone": [pytest.mark.pinecone],
        # FIXME GraphDB can't be treated as a regular docstore, it fails most of their tests
        "graphdb": [pytest.mark.integration],
    }

    # The CLI option does not depend on the item: read and parse it only once
    document_store_types_to_run = {
        docstore.strip() for docstore in config.getoption("--document_store_type").split(",")
    }

    for item in items:
        node_id = item.nodeid.lower()
        for name, markers in name_to_markers.items():
            if name in node_id:
                for marker in markers:
                    item.add_marker(marker)

        # Keywords such as "faiss-dpr" are split so that each part can be matched on its own
        keywords = [part for keyword in item.keywords for part in keyword.split("-")]

        required_doc_store = infer_required_doc_store(item, keywords)

        if required_doc_store and required_doc_store not in document_store_types_to_run:
            skip_docstore = pytest.mark.skip(
                reason=f'{required_doc_store} is disabled. Enable via pytest --document_store_type="{required_doc_store}"'
            )
            item.add_marker(skip_docstore)
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
|
2022-09-22 17:46:49 +02:00
|
|
|
def infer_required_doc_store(item, keywords):
    """
    Infer which document store a collected test needs.

    Assumption: a test runs only with one document store. If the keywords contain
    several document store names, the following heuristics are applied in order:

    1. if the test was parameterized, use the store named in the parameter id
    2. if the test name contains a document store name, use that
    3. otherwise fall back to one of the markers

    Candidates are always checked in alphabetical order, so the result does not depend
    on set iteration order (which varies with string hash randomization between runs).

    :param item: the collected test item; `item.name` and, if present, `item.callspec.id` are read
    :param keywords: the keywords of the item, already split on "-"
    :return: the name of the required document store, or None if no keyword names one
    """
    # Kept sorted on purpose: it defines the (deterministic) lookup order below
    all_doc_stores = ("elasticsearch", "faiss", "memory", "milvus", "pinecone", "sql", "weaviate")
    docstore_markers = set(keywords).intersection(all_doc_stores)

    if len(docstore_markers) > 1:
        # if parameterized infer the docstore from the parameter
        if hasattr(item, "callspec"):
            for doc_store in all_doc_stores:
                # callspec.id contains the parameter values of the test
                if re.search(f"(^|-){doc_store}($|[-_])", item.callspec.id):
                    return doc_store

        # if still not found, infer the docstore from the test name
        for doc_store in all_doc_stores:
            if doc_store in item.name:
                return doc_store

    # if still not found or there is only one, use the alphabetically first marker
    return min(docstore_markers) if docstore_markers else None
|
|
|
|
|
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
#
|
|
|
|
# Empty mocks, as a base for unit tests.
|
|
|
|
#
|
|
|
|
# Monkeypatch the methods you need with either a mock implementation
|
|
|
|
# or a unittest.mock.MagicMock object (https://docs.python.org/3/library/unittest.mock.html)
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
|
|
class MockNode(BaseComponent):
    """Empty pipeline node: both run methods accept anything and return None."""

    outgoing_edges = 1

    def run(self, *a, **k):
        return None

    def run_batch(self, *a, **k):
        return None
|
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
|
|
|
|
class MockDocumentStore(BaseDocumentStore):
    """
    Empty document store: every method accepts any arguments and returns None.

    Monkeypatch the methods a test needs with a mock implementation.
    """

    outgoing_edges = 1

    def _create_document_field_map(self, *a, **k):
        return None

    def delete_documents(self, *a, **k):
        return None

    def delete_labels(self, *a, **k):
        return None

    def get_all_documents(self, *a, **k):
        return None

    def get_all_documents_generator(self, *a, **k):
        # Deliberately a plain function (no `yield`): it returns None like the others
        return None

    def get_all_labels(self, *a, **k):
        return None

    def get_document_by_id(self, *a, **k):
        return None

    def get_document_count(self, *a, **k):
        return None

    def get_documents_by_id(self, *a, **k):
        return None

    def get_label_count(self, *a, **k):
        return None

    def query_by_embedding(self, *a, **k):
        return None

    def write_documents(self, *a, **k):
        return None

    def write_labels(self, *a, **k):
        return None

    def delete_index(self, *a, **k):
        return None

    def update_document_meta(self, *a, **kw):
        return None
|
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
|
|
|
|
class MockRetriever(BaseRetriever):
    """Retriever stub that never finds any document."""

    outgoing_edges = 1

    def retrieve(
        self,
        query: str,
        filters: Optional[FilterType] = None,
        top_k: Optional[int] = None,
        index: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        scale_score: Optional[bool] = None,
        document_store: Optional[BaseDocumentStore] = None,
    ) -> List[Document]:
        """Ignore all arguments and return an empty list of documents."""
        no_documents: List[Document] = []
        return no_documents

    def retrieve_batch(
        self,
        queries: List[str],
        filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
        top_k: Optional[int] = None,
        index: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        batch_size: Optional[int] = None,
        scale_score: Optional[bool] = None,
        document_store: Optional[BaseDocumentStore] = None,
    ) -> List[List[Document]]:
        """Ignore all arguments and return a list holding one empty result list."""
        single_empty_result: List[List[Document]] = [[]]
        return single_empty_result
|
2022-05-11 11:11:00 +02:00
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
|
2022-09-21 14:53:42 +02:00
|
|
|
class MockSeq2SegGenerator(BaseGenerator):
    """Generator stub whose `predict` does nothing."""

    def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -> Dict:
        """Ignore the inputs and return None."""
        return None
|
|
|
|
|
|
|
|
|
|
|
|
class MockSummarizer(BaseSummarizer):
    """Summarizer stub: both prediction methods do nothing."""

    def predict_batch(
        self,
        documents: Union[List[Document], List[List[Document]]],
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """Ignore the inputs and return None."""
        return None

    def predict(self, documents: List[Document]) -> List[Document]:
        """Ignore the inputs and return None."""
        return None
|
|
|
|
|
|
|
|
|
|
|
|
class MockTranslator(BaseTranslator):
    """Translator stub: both translation methods do nothing."""

    def translate(
        self,
        results: List[Dict[str, Any]] = None,
        query: Optional[str] = None,
        documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None,
        dict_key: Optional[str] = None,
    ) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]:
        """Ignore the inputs and return None."""
        return None

    def translate_batch(
        self,
        queries: Optional[List[str]] = None,
        documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
        batch_size: Optional[int] = None,
    ) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]:
        """Ignore the inputs and return None."""
        return None
|
|
|
|
|
|
|
|
|
2022-09-21 19:08:54 +02:00
|
|
|
class MockDenseRetriever(MockRetriever, DenseRetriever):
    """Dense retriever stub that produces random embeddings instead of running a model."""

    def __init__(self, document_store: BaseDocumentStore, embedding_dim: int = 768):
        """
        Store the settings on the instance; the parent initializers are not called.

        :param document_store: the document store this retriever is attached to
        :param embedding_dim: the size of the random embeddings to generate
        """
        self.embedding_dim = embedding_dim
        self.document_store = document_store

    def embed_queries(self, queries):
        """Return a random array with one row of `embedding_dim` values per query."""
        n_rows = len(queries)
        return np.random.rand(n_rows, self.embedding_dim)

    def embed_documents(self, documents):
        """Return a random array with one row of `embedding_dim` values per document."""
        n_rows = len(documents)
        return np.random.rand(n_rows, self.embedding_dim)
|
2022-05-04 17:39:06 +02:00
|
|
|
|
|
|
|
|
2022-09-21 14:53:42 +02:00
|
|
|
class MockQuestionGenerator(QuestionGenerator):
    """Question generator stub that skips the parent initialization."""

    def __init__(self):
        # Intentionally empty: the parent initializer is not called
        return None

    def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -> Dict:
        """Ignore the inputs and return None."""
        return None
|
|
|
|
|
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
class MockReader(BaseReader):
    """Reader stub: both prediction methods do nothing."""

    outgoing_edges = 1

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
        """Ignore the inputs and return None."""
        return None

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        """Ignore the inputs and return None."""
        return None
|
|
|
|
|
|
|
|
|
2023-02-21 14:27:40 +01:00
|
|
|
class MockPromptNode(PromptNode):
    """Prompt node stub that returns canned values instead of invoking a model."""

    def __init__(self):
        # The parent initializer is not called; only this attribute is set
        self.default_prompt_template = None

    def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
        """Ignore the inputs and return a list with a single empty string."""
        return [""]

    def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
        """
        Return the template registered under the given name.

        Only "think-step-by-step" is known; any other name yields an empty template.
        """
        if prompt_template_name != "think-step-by-step":
            return PromptTemplate(name="", prompt_text="")

        think_step_by_step_text = (
            "You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
            "correctly, you have access to the following tools:\n\n"
            "$tool_names_with_descriptions\n\n"
            "To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
            "selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
            "for a final answer, respond with the `Final Answer:`\n\n"
            "Use the following format:\n\n"
            "Question: the question to be answered\n"
            "Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
            "Tool: [$tool_names]\n"
            "Tool Input: the input for the tool\n"
            "Observation: the tool will respond with the result\n"
            "...\n"
            "Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
            "Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
            "---\n\n"
            "Question: $query\n"
            "Thought: Let's think step-by-step, I first need to $generated_text"
        )
        return PromptTemplate(name="think-step-by-step", prompt_text=think_step_by_step_text)
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
#
|
|
|
|
# Document collections
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]:
    """Return five sample documents given in every supported input format."""
    # metafield at the top level for backward compatibility
    flat_dict_doc = {
        "content": "My name is Paul and I live in New York",
        "meta_field": "test2",
        "name": "filename2",
        "date_field": "2019-10-01",
        "numeric_field": 5.0,
    }
    # "dict" format
    nested_dict_doc = {
        "content": "My name is Carla and I live in Berlin",
        "meta": {"meta_field": "test1", "name": "filename1", "date_field": "2020-03-01", "numeric_field": 5.5},
    }
    # Document objects
    christelle = Document(
        content="My name is Christelle and I live in Paris",
        meta={"meta_field": "test3", "name": "filename3", "date_field": "2018-10-01", "numeric_field": 4.5},
    )
    camila = Document(
        content="My name is Camila and I live in Madrid",
        meta={"meta_field": "test4", "name": "filename4", "date_field": "2021-02-01", "numeric_field": 3.0},
    )
    matteo = Document(
        content="My name is Matteo and I live in Rome",
        meta={"meta_field": "test5", "name": "filename5", "date_field": "2019-01-01", "numeric_field": 0.0},
    )
    return [flat_dict_doc, nested_dict_doc, christelle, camila, matteo]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs(docs_all_formats) -> List[Document]:
    """Return the sample documents with every dict entry converted to a `Document`."""
    documents = []
    for entry in docs_all_formats:
        if isinstance(entry, dict):
            entry = Document.from_dict(entry)
        documents.append(entry)
    return documents
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs_with_ids(docs) -> List[Document]:
    """Return the sample documents with fixed, ascending UUID strings as ids."""
    # Should be already sorted; sorting again guards against accidental reordering
    uuids = sorted(
        [
            UUID("190a2421-7e48-4a49-a639-35a86e202dfb"),
            UUID("20ff1706-cb55-4704-8ae8-a3459774c8dc"),
            UUID("5078722f-07ae-412d-8ccb-b77224c4bacb"),
            UUID("81d8ca45-fad1-4d1c-8028-d818ef33d755"),
            UUID("f985789f-1673-4d8f-8d5f-2b8d3a9e8e23"),
        ]
    )
    # The documents are modified in place and the same list is returned
    for document, fixed_id in zip(docs, uuids):
        document.id = str(fixed_id)
    return docs
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs_with_random_emb(docs) -> List[Document]:
    """Return the sample documents, each given a random embedding of 768 values."""
    # The documents are modified in place and the same list is returned
    for document in docs:
        document.embedding = np.random.random([768])
    return docs
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs_with_true_emb():
    """Return two documents whose embeddings are loaded from the sample embedding files."""
    embeddings_dir = SAMPLES_PATH / "embeddings"
    content_and_file = [
        ("The capital of Germany is the city state of Berlin.", "embedding_1.txt"),
        ("Berlin is the capital and largest city of Germany by both area and population.", "embedding_2.txt"),
    ]
    return [
        Document(content=content, embedding=np.loadtxt(embeddings_dir / file_name))
        for content, file_name in content_and_file
    ]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def gc_cleanup(request):
    """
    Force a garbage collection after every test.

    Applied automatically to all tests in order to reduce the memory footprint for CI.
    """
    yield
    # Teardown: runs once the test has finished
    gc.collect()
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def deepset_cloud_fixture():
    """
    Prepare the `responses` library for tests that talk to the DC API.

    With `MOCK_DC` enabled, two GET endpoints are mocked (the status of the test index and
    the pipeline listing); both only match requests carrying the expected bearer token.
    Otherwise requests to the DC API endpoint are passed through to the network.
    """
    if not MOCK_DC:
        responses.add_passthru(DC_API_ENDPOINT)
        return

    workspace_url = f"{DC_API_ENDPOINT}/workspaces/default"
    auth_header = {"authorization": f"Bearer {DC_API_KEY}"}

    # Status of the test index
    responses.add(
        method=responses.GET,
        url=f"{workspace_url}/indexes/{DC_TEST_INDEX}",
        match=[responses.matchers.header_matcher(auth_header)],
        json={"indexing": {"status": "INDEXED", "pending_file_count": 0, "total_file_count": 31}},
        status=200,
    )

    # Pipeline listing containing a single deployed pipeline
    deployed_pipeline = {
        "name": DC_TEST_INDEX,
        "status": "DEPLOYED",
        "indexing": {"status": "INDEXED", "pending_file_count": 0, "total_file_count": 31},
    }
    responses.add(
        method=responses.GET,
        url=f"{workspace_url}/pipelines",
        match=[responses.matchers.header_matcher(auth_header)],
        json={"data": [deployed_pipeline], "has_more": False, "total": 1},
    )
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def rag_generator():
    """Return a token-type `RAGenerator` based on facebook/rag-token-nq, limited to 20 tokens."""
    generator = RAGenerator(
        model_name_or_path="facebook/rag-token-nq",
        generator_type="token",
        max_length=20,
    )
    return generator
|
2020-10-30 18:06:02 +01:00
|
|
|
|
|
|
|
|
2022-07-08 13:59:27 +02:00
|
|
|
@pytest.fixture
def openai_generator():
    """
    Return an `OpenAIAnswerGenerator` using text-babbage-001 with `top_k=1`.

    When `haystack_azure_conf()` returns a configuration, the generator targets Azure with
    the key, base URL and deployment name from it. Otherwise the key is read from the
    `OPENAI_API_KEY` environment variable (empty string if unset).
    """
    azure_conf = haystack_azure_conf()
    if not azure_conf:
        openai_key = os.environ.get("OPENAI_API_KEY", "")
        return OpenAIAnswerGenerator(api_key=openai_key, model="text-babbage-001", top_k=1)

    return OpenAIAnswerGenerator(
        api_key=azure_conf["api_key"],
        azure_base_url=azure_conf["azure_base_url"],
        azure_deployment_name=azure_conf["azure_deployment_name"],
        model="text-babbage-001",
        top_k=1,
    )
|
2022-07-08 13:59:27 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def question_generator():
    """Return a `QuestionGenerator` based on valhalla/t5-small-e2e-qg."""
    generator = QuestionGenerator(model_name_or_path="valhalla/t5-small-e2e-qg")
    return generator
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def lfqa_generator(request):
    """
    Return a `Seq2SeqGenerator` producing answers of 100 to 200 tokens.

    The model name is taken from `request.param`.
    """
    model_name = request.param
    return Seq2SeqGenerator(model_name_or_path=model_name, min_length=100, max_length=200)
|
2021-06-14 17:53:43 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def reader_without_normalized_scores():
    """Return a CPU-only `FARMReader` with `use_confidence_scores` switched off."""
    reader_settings = dict(
        model_name_or_path="deepset/bert-medium-squad2-distilled",
        use_gpu=False,
        top_k_per_sample=5,
        num_processes=0,
        use_confidence_scores=False,
    )
    return FARMReader(**reader_settings)
|
|
|
|
|
2021-09-27 10:52:07 +02:00
|
|
|
|
2022-10-26 19:04:18 +02:00
|
|
|
@pytest.fixture(params=["farm", "transformers"], scope="module")
def reader(request):
    """
    Return a CPU-only reader for deepset/bert-medium-squad2-distilled.

    Parametrized over the implementation: "farm" yields a `FARMReader`, "transformers" a
    `TransformersReader`. The instance is shared by all tests of a module.
    """
    model_name = "deepset/bert-medium-squad2-distilled"
    reader_type = request.param

    if reader_type == "farm":
        return FARMReader(model_name_or_path=model_name, use_gpu=False, top_k_per_sample=5, num_processes=0)

    if reader_type == "transformers":
        return TransformersReader(model_name_or_path=model_name, tokenizer=model_name, use_gpu=-1)
|
2020-07-10 10:54:56 +02:00
|
|
|
|
2021-10-15 16:34:48 +02:00
|
|
|
|
2022-10-26 20:57:28 +02:00
|
|
|
@pytest.fixture(params=["tapas_small", "tapas_base", "tapas_scored", "rci"])
def table_reader_and_param(request):
    """
    Return a tuple `(table reader, parameter name)`.

    The three "tapas_*" parameters yield a `TableReader` with the matching model, "rci"
    yields an `RCIReader`.
    """
    reader_type = request.param
    tapas_models = {
        "tapas_small": "google/tapas-small-finetuned-wtq",
        "tapas_base": "google/tapas-base-finetuned-wtq",
        "tapas_scored": "deepset/tapas-large-nq-hn-reader",
    }

    if reader_type in tapas_models:
        return TableReader(model_name_or_path=tapas_models[reader_type]), reader_type

    if reader_type == "rci":
        rci_reader = RCIReader(
            row_model_name_or_path="michaelrglass/albert-base-rci-wikisql-row",
            column_model_name_or_path="michaelrglass/albert-base-rci-wikisql-col",
        )
        return rci_reader, reader_type
|
2021-10-15 16:34:48 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def ranker_two_logits():
    """Return a `SentenceTransformersRanker` based on deepset/gbert-base-germandpr-reranking."""
    model_name = "deepset/gbert-base-germandpr-reranking"
    return SentenceTransformersRanker(model_name_or_path=model_name)
|
2021-12-06 17:13:57 +01:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def ranker():
    """Return a `SentenceTransformersRanker` based on cross-encoder/ms-marco-MiniLM-L-12-v2."""
    model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    return SentenceTransformersRanker(model_name_or_path=model_name)
|
2021-07-13 21:44:26 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def document_classifier():
    """Return a CPU-only emotion classifier configured with `top_k=2`."""
    classifier = TransformersDocumentClassifier(
        model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
        use_gpu=False,
        top_k=2,
    )
    return classifier
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def zero_shot_document_classifier():
    """Return a CPU-only zero-shot classifier with the labels "negative" and "positive"."""
    candidate_labels = ["negative", "positive"]
    classifier = TransformersDocumentClassifier(
        model_name_or_path="cross-encoder/nli-distilroberta-base",
        use_gpu=False,
        task="zero-shot-classification",
        labels=candidate_labels,
    )
    return classifier
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def batched_document_classifier():
    """Return a CPU-only emotion classifier that works in batches of 16."""
    classifier = TransformersDocumentClassifier(
        model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
        use_gpu=False,
        batch_size=16,
    )
    return classifier
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def indexing_document_classifier():
    """
    Return a CPU-only emotion classifier that works in batches of 16.

    The classifier is configured with `classification_field="class_field"`.
    """
    classifier_settings = dict(
        model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
        use_gpu=False,
        batch_size=16,
        classification_field="class_field",
    )
    return TransformersDocumentClassifier(**classifier_settings)
|
2021-10-01 11:22:56 +02:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-11-22 09:24:52 +01:00
|
|
|
@pytest.fixture(params=["es_filter_only", "bm25", "dpr", "embedding", "tfidf", "table_text_retriever"])
def retriever(request, document_store):
    """Return a retriever of the parametrized type, attached to the `document_store` fixture."""
    retriever_type = request.param
    return get_retriever(retriever_type, document_store)
|
|
|
|
|
|
|
|
|
2021-10-13 14:23:23 +02:00
|
|
|
# @pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf"])
|
|
|
|
@pytest.fixture(params=["tfidf"])
def retriever_with_docs(request, document_store_with_docs):
    """Return a retriever of the parametrized type, attached to a store that already holds documents."""
    retriever_type = request.param
    return get_retriever(retriever_type, document_store_with_docs)
|
|
|
|
|
|
|
|
|
|
|
|
def get_retriever(retriever_type, document_store):
    """
    Build the retriever registered under `retriever_type` on top of `document_store`.

    All retrievers that accept `use_gpu` are created with `use_gpu=False`. The API based
    ones ("openai", "azure", "cohere") read their credentials from environment variables.

    :param retriever_type: the name of the retriever configuration to build
    :param document_store: the document store handed to the retriever
    :return: the newly created retriever
    :raises Exception: if no configuration exists for `retriever_type`
    """
    if retriever_type == "dpr":
        return DensePassageRetriever(
            document_store=document_store,
            query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
            passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
            use_gpu=False,
            embed_title=True,
        )

    if retriever_type == "mdr":
        return MultihopEmbeddingRetriever(
            document_store=document_store,
            embedding_model="deutschmann/mdr_roberta_q_encoder",  # or "facebook/dpr-ctx_encoder-single-nq-base"
            use_gpu=False,
        )

    if retriever_type == "tfidf":
        return TfidfRetriever(document_store=document_store)

    if retriever_type == "embedding":
        return EmbeddingRetriever(
            document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False
        )

    if retriever_type == "embedding_sbert":
        return EmbeddingRetriever(
            document_store=document_store,
            embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
            model_format="sentence_transformers",
            use_gpu=False,
        )

    if retriever_type == "retribert":
        return EmbeddingRetriever(
            document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
        )

    if retriever_type == "openai":
        return EmbeddingRetriever(
            document_store=document_store,
            embedding_model="text-embedding-ada-002",
            use_gpu=False,
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    if retriever_type == "azure":
        return EmbeddingRetriever(
            document_store=document_store,
            embedding_model="text-embedding-ada-002",
            use_gpu=False,
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            azure_base_url=os.getenv("AZURE_OPENAI_BASE_URL"),
            azure_deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME_EMBED"),
        )

    if retriever_type == "cohere":
        return EmbeddingRetriever(
            document_store=document_store,
            embedding_model="small",
            use_gpu=False,
            api_key=os.environ.get("COHERE_API_KEY", ""),
        )

    if retriever_type == "dpr_lfqa":
        return DensePassageRetriever(
            document_store=document_store,
            query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
            passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
            use_gpu=False,
            embed_title=True,
        )

    if retriever_type == "bm25":
        return BM25Retriever(document_store=document_store)

    if retriever_type == "es_filter_only":
        return FilterRetriever(document_store=document_store)

    if retriever_type == "table_text_retriever":
        return TableTextRetriever(
            document_store=document_store,
            query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
            passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
            table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
            use_gpu=False,
        )

    raise Exception(f"No retriever fixture for '{retriever_type}'")
|
2020-12-14 18:15:44 +01:00
|
|
|
|
|
|
|
|
2022-07-14 19:03:33 +01:00
|
|
|
# FIXME Fix this in the docstore tests refactoring
|
|
|
|
from inspect import getmembers, isclass, isfunction
|
|
|
|
|
|
|
|
|
|
|
|
def mock_pinecone(monkeypatch):
    """
    Patch the `pinecone` package with the stand-ins defined in `pinecone_mock`.

    Every function found in `pinecone_mock` is patched first, then every class, each under
    the same name on `pinecone`.

    :param monkeypatch: pytest's `monkeypatch` fixture; used so the patches are undone automatically.
    """
    # Functions first, then classes - the same order the patches were always applied in.
    for predicate in (isfunction, isclass):
        for member_name, member in getmembers(pinecone_mock, predicate):
            # raising=False: also set names that do not exist on the real `pinecone` module.
            monkeypatch.setattr(f"pinecone.{member_name}", member, raising=False)
|
|
|
|
|
|
|
|
|
2022-11-15 09:54:55 +01:00
|
|
|
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate", "pinecone"])
def document_store_with_docs(request, docs, tmp_path, monkeypatch):
    """
    Yield a document store that already contains `docs`, once per listed store type.

    The embedding dimension is taken from the test's `embedding_dim` marker and falls back to 768.
    After the test, the store's index is deleted.
    """
    store_type = request.param
    if store_type == "pinecone":
        # Pinecone is never contacted for real in these tests.
        mock_pinecone(monkeypatch)

    fallback_marker = pytest.mark.embedding_dim(768)
    dim_marker = request.node.get_closest_marker("embedding_dim", fallback_marker)
    store = get_document_store(
        document_store_type=store_type,
        embedding_dim=dim_marker.args[0],
        tmp_path=tmp_path,
    )
    store.write_documents(docs)

    yield store

    # Teardown: remove the index the test worked on.
    store.delete_index(store.index)
|
2020-12-14 18:15:44 +01:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2021-06-22 16:08:23 +02:00
|
|
|
@pytest.fixture
def document_store(request, tmp_path, monkeypatch: pytest.MonkeyPatch):
    """
    Yield an empty document store of the type given by `request.param`.

    The fixture declares no params itself, so the requesting test has to provide the store type
    through parametrization. The embedding dimension is taken from the test's `embedding_dim`
    marker and falls back to 768. After the test, the store's index is deleted.
    """
    store_type = request.param
    if store_type == "pinecone":
        # Pinecone is never contacted for real in these tests.
        mock_pinecone(monkeypatch)

    fallback_marker = pytest.mark.embedding_dim(768)
    dim_marker = request.node.get_closest_marker("embedding_dim", fallback_marker)
    store = get_document_store(
        document_store_type=store_type,
        embedding_dim=dim_marker.args[0],
        tmp_path=tmp_path,
    )

    yield store

    # Teardown: remove the index the test worked on.
    store.delete_index(store.index)
|
2022-03-21 22:24:09 +07:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-11-15 09:54:55 +01:00
|
|
|
@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch", "pinecone"])
def document_store_dot_product(request, tmp_path, monkeypatch):
    """
    Yield an empty document store configured with `dot_product` similarity, once per listed store type.

    The embedding dimension is taken from the test's `embedding_dim` marker and falls back to 768.
    After the test, the store's index is deleted.
    """
    store_type = request.param
    if store_type == "pinecone":
        # Pinecone is never contacted for real in these tests.
        mock_pinecone(monkeypatch)

    fallback_marker = pytest.mark.embedding_dim(768)
    dim_marker = request.node.get_closest_marker("embedding_dim", fallback_marker)
    store = get_document_store(
        document_store_type=store_type,
        embedding_dim=dim_marker.args[0],
        similarity="dot_product",
        tmp_path=tmp_path,
    )

    yield store

    # Teardown: remove the index the test worked on.
    store.delete_index(store.index)
|
2022-01-12 19:28:20 +01:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-01-14 13:48:58 +01:00
|
|
|
@pytest.fixture
def sql_url(tmp_path):
    """Fixture wrapper around `get_sql_url()` for the current test's temporary directory."""
    url = get_sql_url(tmp_path)
    return url
|
2022-01-14 13:48:58 +01:00
|
|
|
|
|
|
|
|
|
|
|
def get_sql_url(tmp_path):
    """
    Return the SQL connection URL for the configured `SQL_TYPE`.

    :param tmp_path: directory that holds the SQLite file; unused when `SQL_TYPE` is "postgres".
    :return: the fixed local PostgreSQL URL for "postgres", otherwise a SQLite URL inside `tmp_path`.
    """
    if SQL_TYPE != "postgres":
        return f"sqlite:///{tmp_path}/haystack_test.db"
    return "postgresql://postgres:postgres@127.0.0.1/postgres"
|
|
|
|
|
|
|
|
|
|
|
|
def setup_postgres():
    """
    Reset the `public` schema of the local PostgreSQL test database.

    Connects to `postgres@127.0.0.1` in AUTOCOMMIT mode, drops the `public` schema if it exists
    (a failure of the drop is logged, not raised) and creates it again.

    NOTE: this function does not start PostgreSQL; an instance has to be reachable already.
    Disabled code that used to live here launched one with
    `docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres`.
    """
    engine = create_engine("postgresql://postgres:postgres@127.0.0.1/postgres", isolation_level="AUTOCOMMIT")
    try:
        with engine.connect() as connection:
            try:
                connection.execute(text("DROP SCHEMA IF EXISTS public CASCADE"))
            except Exception as e:
                # Best effort: keep going and let CREATE SCHEMA below surface any real problem.
                logging.error(e)
            connection.execute(text("CREATE SCHEMA public;"))
            # SET SESSION only applies to this connection's session.
            connection.execute(text('SET SESSION idle_in_transaction_session_timeout = "1s";'))
    finally:
        # The engine is private to this call: release its pooled connection right away
        # instead of leaving it open until the engine is garbage collected.
        engine.dispose()
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-01-14 13:48:58 +01:00
|
|
|
def teardown_postgres():
    """
    Drop the `public` schema (and everything in it) of the local PostgreSQL test database.

    Connects to `postgres@127.0.0.1` in AUTOCOMMIT mode. Errors raised by the DROP statement propagate.
    """
    engine = create_engine("postgresql://postgres:postgres@127.0.0.1/postgres", isolation_level="AUTOCOMMIT")
    try:
        # The `with` block closes the connection on exit, so no explicit close() is needed.
        with engine.connect() as connection:
            connection.execute(text("DROP SCHEMA public CASCADE"))
    finally:
        # The engine is private to this call: release its pooled connection right away
        # instead of leaving it open until the engine is garbage collected.
        engine.dispose()
|
|
|
|
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
def get_document_store(
    document_store_type,
    tmp_path,
    embedding_dim=768,
    embedding_field="embedding",
    index="haystack_test",
    similarity: str = "cosine",
    recreate_index: bool = True,
):  # cosine is default similarity as dot product is not supported by Weaviate
    """
    Build the document store used by the tests for the requested backend.

    :param document_store_type: one of "memory", "elasticsearch", "faiss", "milvus", "weaviate",
        "pinecone" or "opensearch_faiss".
    :param tmp_path: directory handed to `get_sql_url()` for the SQL-backed stores (faiss, milvus).
    :param embedding_dim: embedding dimension passed to the store.
    :param embedding_field: embedding field name; passed to every store except Weaviate.
    :param index: index name passed to the store.
    :param similarity: similarity function passed to the store.
    :param recreate_index: forwarded to the stores that are given it here
        (elasticsearch, milvus, weaviate, pinecone, opensearch_faiss).
    :return: the freshly constructed document store.
    :raises Exception: if `document_store_type` is not one of the known types.
    """
    if document_store_type == "memory":
        return InMemoryDocumentStore(
            return_embedding=True,
            embedding_dim=embedding_dim,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
            use_bm25=True,
            bm25_parameters={"k1": 1.2, "b": 0.75},  # parameters similar to those of Elasticsearch
        )

    if document_store_type == "elasticsearch":
        # make sure we start from a fresh index
        return ElasticsearchDocumentStore(
            index=index,
            return_embedding=True,
            embedding_dim=embedding_dim,
            embedding_field=embedding_field,
            similarity=similarity,
            recreate_index=recreate_index,
        )

    if document_store_type == "faiss":
        return FAISSDocumentStore(
            embedding_dim=embedding_dim,
            sql_url=get_sql_url(tmp_path),
            return_embedding=True,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
            isolation_level="AUTOCOMMIT",
        )

    if document_store_type == "milvus":
        return MilvusDocumentStore(
            embedding_dim=embedding_dim,
            sql_url=get_sql_url(tmp_path),
            return_embedding=True,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
            isolation_level="AUTOCOMMIT",
            recreate_index=recreate_index,
        )

    if document_store_type == "weaviate":
        return WeaviateDocumentStore(
            index=index,
            similarity=similarity,
            embedding_dim=embedding_dim,
            recreate_index=recreate_index,
        )

    if document_store_type == "pinecone":
        return PineconeDocumentStore(
            # Fall back to a placeholder key when the variable is unset or empty.
            api_key=os.environ.get("PINECONE_API_KEY") or "fake-haystack-test-key",
            embedding_dim=embedding_dim,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
            recreate_index=recreate_index,
            metadata_config={"indexed": META_FIELDS},
        )

    if document_store_type == "opensearch_faiss":
        return OpenSearchDocumentStore(
            index=index,
            return_embedding=True,
            embedding_dim=embedding_dim,
            embedding_field=embedding_field,
            similarity=similarity,
            recreate_index=recreate_index,
            port=9201,
            knn_engine="faiss",
        )

    raise Exception(f"No document store fixture for '{document_store_type}'")
|
2021-09-22 16:56:51 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def adaptive_model_qa(num_processes):
    """
    PyTest Fixture for a Question Answering Inferencer based on PyTorch.

    Loads "deepset/bert-medium-squad2-distilled" on CPU with `num_processes` workers and yields it.
    After the test, logs an error if the current process still has child processes.
    """
    load_options = {
        "task_type": "question_answering",
        "batch_size": 16,
        "num_processes": num_processes,
        "gpu": False,
    }
    inferencer = Inferencer.load("deepset/bert-medium-squad2-distilled", **load_options)
    yield inferencer

    # check if all workers (sub processes) are closed
    leftover_children = psutil.Process().children()
    if leftover_children:
        logging.error("Not all the subprocesses are closed! %s are still running.", len(leftover_children))
|
2021-09-22 16:56:51 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def bert_base_squad2(request):
    """Return a `QAInferencer` loaded from "deepset/minilm-uncased-squad2" with no worker processes."""
    load_options = {
        "task_type": "question_answering",
        "batch_size": 4,
        "num_processes": 0,
        "multithreading_rust": False,
        "use_fast": True,  # TODO parametrize this to test slow as well
    }
    inferencer = QAInferencer.load("deepset/minilm-uncased-squad2", **load_options)
    return inferencer
|
2022-12-20 11:21:26 +01:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def prompt_node():
    """Return a `PromptNode` backed by "google/flan-t5-small" on CPU."""
    node = PromptNode("google/flan-t5-small", devices=["cpu"])
    return node
|
|
|
|
|
|
|
|
|
2023-03-02 09:55:09 +01:00
|
|
|
def haystack_azure_conf():
    """
    Read the Azure OpenAI settings from the environment.

    :return: a dict with `api_key`, `azure_base_url` and `azure_deployment_name` when all three
        environment variables are set to non-empty values, otherwise an empty dict.
    """
    env_var_by_option = {
        "api_key": "AZURE_OPENAI_API_KEY",
        "azure_base_url": "AZURE_OPENAI_BASE_URL",
        "azure_deployment_name": "AZURE_OPENAI_DEPLOYMENT_NAME",
    }
    conf = {option: os.environ.get(env_var) for option, env_var in env_var_by_option.items()}
    # A partial configuration is treated like no configuration at all.
    return conf if all(conf.values()) else {}
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def haystack_openai_config(request):
    """
    Return the keyword arguments for the provider named by `request.param`.

    - "openai": the API key from `OPENAI_API_KEY` plus the "text-embedding-ada-002" embedding model,
      or an empty dict when the key is unset or empty.
    - "azure": whatever `haystack_azure_conf()` returns.
    - anything else: an empty dict.
    """
    provider = request.param
    if provider == "azure":
        return haystack_azure_conf()
    if provider != "openai":
        return {}

    openai_key = os.environ.get("OPENAI_API_KEY")
    if openai_key:
        return {"api_key": openai_key, "embedding_model": "text-embedding-ada-002"}
    return {}
|
|
|
|
|
2023-03-02 09:55:09 +01:00
|
|
|
|
2022-12-20 11:21:26 +01:00
|
|
|
@pytest.fixture
def prompt_model(request):
    """
    Return the `PromptModel` for the provider named by `request.param`.

    "openai" and "azure" use "text-davinci-003" with the API key from the matching environment
    variable, or the placeholder "KEY_NOT_FOUND" when it is unset or empty. Azure additionally
    receives `haystack_azure_conf()` as `model_kwargs`. Any other value yields a
    "google/flan-t5-base" model on CPU.
    """
    provider = request.param
    if provider == "openai":
        openai_key = os.environ.get("OPENAI_API_KEY") or "KEY_NOT_FOUND"
        return PromptModel("text-davinci-003", api_key=openai_key)
    if provider == "azure":
        azure_key = os.environ.get("AZURE_OPENAI_API_KEY") or "KEY_NOT_FOUND"
        return PromptModel("text-davinci-003", api_key=azure_key, model_kwargs=haystack_azure_conf())
    return PromptModel("google/flan-t5-base", devices=["cpu"])
|
2023-03-02 09:55:09 +01:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def azure_conf():
    """Fixture wrapper exposing the result of `haystack_azure_conf()` to tests."""
    conf = haystack_azure_conf()
    return conf
|