2022-03-31 12:36:45 +02:00
|
|
|
from datetime import timedelta
|
2022-06-10 18:22:48 +02:00
|
|
|
from typing import Any, List, Optional, Dict, Union
|
2022-03-15 11:17:26 +01:00
|
|
|
|
2020-09-15 15:04:46 +02:00
|
|
|
import subprocess
|
2022-06-10 18:22:48 +02:00
|
|
|
from uuid import UUID
|
2020-05-04 18:00:07 +02:00
|
|
|
import time
|
2020-09-15 15:04:46 +02:00
|
|
|
from subprocess import run
|
|
|
|
from sys import platform
|
2021-12-22 17:20:23 +01:00
|
|
|
import gc
|
2022-01-14 13:48:58 +01:00
|
|
|
import uuid
|
|
|
|
import logging
|
2022-01-26 18:12:55 +01:00
|
|
|
from pathlib import Path
|
2022-03-21 22:24:09 +07:00
|
|
|
import os
|
|
|
|
|
2022-03-31 12:36:45 +02:00
|
|
|
import requests_cache
|
2022-01-25 20:36:28 +01:00
|
|
|
import responses
|
2022-01-14 13:48:58 +01:00
|
|
|
from sqlalchemy import create_engine, text
|
2022-03-21 11:58:51 +01:00
|
|
|
import posthog
|
2020-05-04 18:00:07 +02:00
|
|
|
|
2021-11-12 16:44:28 +01:00
|
|
|
import numpy as np
|
2021-09-22 16:56:51 +02:00
|
|
|
import psutil
|
2020-05-04 18:00:07 +02:00
|
|
|
import pytest
|
2020-09-15 15:04:46 +02:00
|
|
|
import requests
|
2021-06-14 17:53:43 +02:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
from haystack.nodes.base import BaseComponent
|
2022-03-15 11:17:26 +01:00
|
|
|
|
2022-02-24 17:43:38 +01:00
|
|
|
try:
|
|
|
|
from milvus import Milvus
|
|
|
|
|
|
|
|
milvus1 = True
|
|
|
|
except ImportError:
|
|
|
|
milvus1 = False
|
|
|
|
from pymilvus import utility
|
|
|
|
|
2022-01-26 18:12:55 +01:00
|
|
|
try:
|
|
|
|
from elasticsearch import Elasticsearch
|
|
|
|
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
|
|
|
import weaviate
|
|
|
|
from haystack.document_stores.weaviate import WeaviateDocumentStore
|
2022-03-21 22:24:09 +07:00
|
|
|
from haystack.document_stores import MilvusDocumentStore, PineconeDocumentStore
|
2022-01-26 18:12:55 +01:00
|
|
|
from haystack.document_stores.graphdb import GraphDBKnowledgeGraph
|
|
|
|
from haystack.document_stores.faiss import FAISSDocumentStore
|
|
|
|
from haystack.document_stores.sql import SQLDocumentStore
|
|
|
|
|
|
|
|
except (ImportError, ModuleNotFoundError) as ie:
|
|
|
|
from haystack.utils.import_utils import _optional_component_not_installed
|
2022-02-03 13:43:18 +01:00
|
|
|
|
|
|
|
_optional_component_not_installed("test", "test", ie)
|
2022-01-26 18:12:55 +01:00
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
from haystack.document_stores import BaseDocumentStore, DeepsetCloudDocumentStore, InMemoryDocumentStore
|
2022-02-24 17:43:38 +01:00
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
from haystack.nodes import BaseReader, BaseRetriever
|
2022-01-26 18:12:55 +01:00
|
|
|
from haystack.nodes.answer_generator.transformers import Seq2SeqGenerator
|
2022-03-15 11:17:26 +01:00
|
|
|
from haystack.nodes.answer_generator.transformers import RAGenerator
|
2021-10-25 15:50:23 +02:00
|
|
|
from haystack.nodes.ranker import SentenceTransformersRanker
|
|
|
|
from haystack.nodes.document_classifier.transformers import TransformersDocumentClassifier
|
2022-04-29 10:16:02 +02:00
|
|
|
from haystack.nodes.retriever.sparse import FilterRetriever, BM25Retriever, TfidfRetriever
|
2022-07-05 11:31:11 +02:00
|
|
|
from haystack.nodes.retriever.dense import (
|
|
|
|
DensePassageRetriever,
|
|
|
|
EmbeddingRetriever,
|
|
|
|
MultihopEmbeddingRetriever,
|
|
|
|
TableTextRetriever,
|
|
|
|
)
|
2022-01-26 18:12:55 +01:00
|
|
|
from haystack.nodes.reader.farm import FARMReader
|
|
|
|
from haystack.nodes.reader.transformers import TransformersReader
|
|
|
|
from haystack.nodes.reader.table import TableReader, RCIReader
|
|
|
|
from haystack.nodes.summarizer.transformers import TransformersSummarizer
|
|
|
|
from haystack.nodes.translator import TransformersTranslator
|
|
|
|
from haystack.nodes.question_generator import QuestionGenerator
|
2020-07-10 10:54:56 +02:00
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
from haystack.modeling.infer import Inferencer, QAInferencer
|
|
|
|
|
|
|
|
from haystack.schema import Document
|
|
|
|
|
2020-05-04 18:00:07 +02:00
|
|
|
|
2022-01-14 13:48:58 +01:00
|
|
|
# To manually run the tests with default PostgreSQL instead of SQLite, switch the lines below
SQL_TYPE = "sqlite"
# SQL_TYPE = "postgres"

# Directory with sample files next to this conftest (e.g. the "embeddings" sub-folder read by docs_with_true_emb)
SAMPLES_PATH = Path(__file__).parent / "samples"

# to run tests against Deepset Cloud set MOCK_DC to False and set the following params
DC_API_ENDPOINT = "https://DC_API/v1"
DC_TEST_INDEX = "document_retrieval_1"
DC_API_KEY = "NO_KEY"
# True: deepset_cloud_fixture registers canned `responses` mocks; False: requests to DC_API_ENDPOINT pass through
MOCK_DC = True

# Disable telemetry reports when running tests
posthog.disabled = True

# Cache requests (e.g. huggingface model) to circumvent load protection
# See https://requests-cache.readthedocs.io/en/stable/user_guide/filtering.html
# huggingface.co responses expire after one hour; every other URL ("*") is never cached.
requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE})
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2021-01-22 14:39:24 +01:00
|
|
|
def _sql_session_rollback(self, attr):
|
|
|
|
"""
|
|
|
|
Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
|
|
|
|
errors where an intended operation is still in a transaction, but not committed to the database.
|
|
|
|
"""
|
|
|
|
method = object.__getattribute__(self, attr)
|
|
|
|
if callable(method):
|
|
|
|
try:
|
|
|
|
self.session.rollback()
|
|
|
|
except AttributeError:
|
|
|
|
pass
|
|
|
|
|
|
|
|
return method
|
|
|
|
|
|
|
|
|
|
|
|
# Patch the class itself (process-wide, as soon as this conftest is imported): every attribute lookup on any
# SQLDocumentStore instance now goes through _sql_session_rollback, which rolls back the session before
# returning a callable attribute.
SQLDocumentStore.__getattribute__ = _sql_session_rollback
|
|
|
|
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
def pytest_collection_modifyitems(config, items):
    """Add keyword-derived markers to collected tests and skip tests for disabled backends.

    For every collected item:
      * markers from ``name_to_markers`` are added when the key appears in the lower-cased node id;
      * tests whose keywords name a document store that is not listed in ``--document_store_type``
        are skipped;
      * ``milvus1`` / ``milvus`` tests are skipped depending on which Milvus client could be imported
        (module-level ``milvus1`` flag);
      * ``pinecone`` tests are skipped when ``PINECONE_API_KEY`` is not set.

    :param config: the pytest config object; must have ``--document_store_type`` registered.
    :param items: the collected test items; modified in place by adding markers.
    """
    # add pytest markers for tests that are not explicitly marked but include some keywords
    # NOTE(review): "milvus" also matches node ids containing "milvus1", and both the milvus and milvus1
    # markers are added. pytest's add_marker also records the marker name in item.keywords, so for such
    # items one of the two Milvus skip branches below always fires. Confirm this is intended.
    name_to_markers = {
        "generator": [pytest.mark.generator],
        "summarizer": [pytest.mark.summarizer],
        "tika": [pytest.mark.tika, pytest.mark.integration],
        "parsr": [pytest.mark.parsr, pytest.mark.integration],
        "ocr": [pytest.mark.ocr, pytest.mark.integration],
        "elasticsearch": [pytest.mark.elasticsearch],
        "faiss": [pytest.mark.faiss],
        "milvus": [pytest.mark.milvus, pytest.mark.milvus1],
        "weaviate": [pytest.mark.weaviate],
        "pinecone": [pytest.mark.pinecone],
        # FIXME GraphDB can't be treated as a regular docstore, it fails most of their tests
        "graphdb": [pytest.mark.integration],
    }

    # if the cli argument "--document_store_type" is used, we want to skip all tests that have markers of other docstores
    # Example: pytest -v test_document_store.py --document_store_type="memory" => skip all tests marked with "elasticsearch"
    # Neither the CLI option nor the environment changes between items, so read both once
    # instead of once per collected test.
    document_store_types_to_run = [
        docstore.strip() for docstore in config.getoption("--document_store_type").split(",")
    ]
    all_doc_stores = ("elasticsearch", "faiss", "sql", "memory", "milvus1", "milvus", "weaviate", "pinecone")
    pinecone_key_missing = not os.environ.get("PINECONE_API_KEY", False)

    for item in items:
        node_id = item.nodeid.lower()
        for name, markers in name_to_markers.items():
            if name in node_id:
                for marker in markers:
                    item.add_marker(marker)

        # Keywords containing "-" are split into their parts; a keyword without "-" splits to itself.
        keywords = []
        for keyword in item.keywords:
            keywords.extend(keyword.split("-"))

        for cur_doc_store in all_doc_stores:
            if cur_doc_store in keywords and cur_doc_store not in document_store_types_to_run:
                skip_docstore = pytest.mark.skip(
                    reason=f'{cur_doc_store} is disabled. Enable via pytest --document_store_type="{cur_doc_store}"'
                )
                item.add_marker(skip_docstore)

        if "milvus1" in keywords and not milvus1:
            skip_milvus1 = pytest.mark.skip(reason="Skipping Tests for 'milvus1', as Milvus2 seems to be installed.")
            item.add_marker(skip_milvus1)
        elif "milvus" in keywords and milvus1:
            skip_milvus = pytest.mark.skip(reason="Skipping Tests for 'milvus', as Milvus1 seems to be installed.")
            item.add_marker(skip_milvus)

        # Skip PineconeDocumentStore if PINECONE_API_KEY not in environment variables
        if pinecone_key_missing and "pinecone" in keywords:
            skip_pinecone = pytest.mark.skip(reason="PINECONE_API_KEY not in environment variables.")
            item.add_marker(skip_pinecone)
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
#
|
|
|
|
# Empty mocks, as a base for unit tests.
|
|
|
|
#
|
|
|
|
# Monkeypatch the methods you need with either a mock implementation
|
|
|
|
# or a unittest.mock.MagicMock object (https://docs.python.org/3/library/unittest.mock.html)
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
|
|
class MockNode(BaseComponent):
    """Do-nothing pipeline node used as a base for unit tests.

    Both entry points accept arbitrary arguments and return None; monkeypatch them with a real
    implementation or a MagicMock where a test needs behaviour.
    """

    outgoing_edges = 1

    def run(self, *a, **k):
        return None

    def run_batch(self, *a, **k):
        return None
|
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
|
|
|
|
class MockDocumentStore(BaseDocumentStore):
    """Empty document store for unit tests.

    Every method below accepts arbitrary arguments and returns None. Monkeypatch only the
    methods a test actually exercises.
    """

    outgoing_edges = 1

    def _create_document_field_map(self, *a, **k):
        return None

    # --- deletion ---------------------------------------------------------
    def delete_documents(self, *a, **k):
        return None

    def delete_labels(self, *a, **k):
        return None

    # --- reads ------------------------------------------------------------
    def get_all_documents(self, *a, **k):
        return None

    def get_all_documents_generator(self, *a, **k):
        return None

    def get_all_labels(self, *a, **k):
        return None

    def get_document_by_id(self, *a, **k):
        return None

    def get_document_count(self, *a, **k):
        return None

    def get_documents_by_id(self, *a, **k):
        return None

    def get_label_count(self, *a, **k):
        return None

    def query_by_embedding(self, *a, **k):
        return None

    # --- writes -----------------------------------------------------------
    def write_documents(self, *a, **k):
        return None

    def write_labels(self, *a, **k):
        return None

    def delete_index(self, *a, **k):
        return None
|
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
|
|
|
|
class MockRetriever(BaseRetriever):
    """Retriever stub for unit tests: both retrieval methods accept their arguments and return None."""

    outgoing_edges = 1

    def retrieve(self, query: str, top_k: int):
        return None

    def retrieve_batch(self, queries: List[str], top_k: int):
        return None
|
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
|
2022-05-04 17:39:06 +02:00
|
|
|
class MockDenseRetriever(MockRetriever):
    """Dense-retriever stub that produces random embeddings of a configurable length.

    :param document_store: document store to attach to (only stored on the instance here).
    :param embedding_dim: length of every vector returned by ``embed_queries`` / ``embed_documents``.
    """

    def __init__(self, document_store: BaseDocumentStore, embedding_dim: int = 768):
        self.embedding_dim = embedding_dim
        self.document_store = document_store

    def embed_queries(self, texts):
        """Return one independent random vector of length ``embedding_dim`` per query text."""
        # A fresh array per text. The previous `[vec] * len(texts)` returned n references to the
        # SAME ndarray: every text got an identical embedding, and mutating one mutated them all.
        return [np.random.rand(self.embedding_dim) for _ in texts]

    def embed_documents(self, docs):
        """Return one independent random vector of length ``embedding_dim`` per document."""
        return [np.random.rand(self.embedding_dim) for _ in docs]
|
|
|
|
|
|
|
|
|
2022-03-15 11:17:26 +01:00
|
|
|
class MockReader(BaseReader):
    """Reader stub for unit tests: single and batch prediction both return None."""

    outgoing_edges = 1

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
        return None

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        return None
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
#
|
|
|
|
# Document collections
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]:
    """Five sample documents, mixing the input formats exercised by the tests.

    Order: one dict with meta fields at the top level, one dict with a nested ``meta`` dict,
    then three ``Document`` objects.
    """
    # metafield at the top level for backward compatibility
    top_level_meta = {
        "content": "My name is Paul and I live in New York",
        "meta_field": "test2",
        "name": "filename2",
        "date_field": "2019-10-01",
        "numeric_field": 5.0,
    }
    # "dict" format
    nested_meta = {
        "content": "My name is Carla and I live in Berlin",
        "meta": {"meta_field": "test1", "name": "filename1", "date_field": "2020-03-01", "numeric_field": 5.5},
    }
    # Document objects
    as_documents = [
        Document(
            content="My name is Christelle and I live in Paris",
            meta={"meta_field": "test3", "name": "filename3", "date_field": "2018-10-01", "numeric_field": 4.5},
        ),
        Document(
            content="My name is Camila and I live in Madrid",
            meta={"meta_field": "test4", "name": "filename4", "date_field": "2021-02-01", "numeric_field": 3.0},
        ),
        Document(
            content="My name is Matteo and I live in Rome",
            meta={"meta_field": "test5", "name": "filename5", "date_field": "2019-01-01", "numeric_field": 0.0},
        ),
    ]
    return [top_level_meta, nested_meta, *as_documents]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs(docs_all_formats) -> List[Document]:
    """``docs_all_formats`` with every dict converted via ``Document.from_dict``; order is preserved."""
    normalized = []
    for entry in docs_all_formats:
        if isinstance(entry, dict):
            entry = Document.from_dict(entry)
        normalized.append(entry)
    return normalized
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs_with_ids(docs) -> List[Document]:
    """Give the ``docs`` fixture fixed, ascending UUID-string ids.

    The documents are mutated in place and the same list is returned. Ids are paired with
    documents in order; if there are fewer UUIDs than documents, the extra documents keep their ids.
    """
    # Should be already sorted
    uuids = [
        UUID("190a2421-7e48-4a49-a639-35a86e202dfb"),
        UUID("20ff1706-cb55-4704-8ae8-a3459774c8dc"),
        UUID("5078722f-07ae-412d-8ccb-b77224c4bacb"),
        UUID("81d8ca45-fad1-4d1c-8028-d818ef33d755"),
        UUID("f985789f-1673-4d8f-8d5f-2b8d3a9e8e23"),
    ]
    uuids.sort()
    # Named `doc_uuid` so the loop variable no longer shadows the module-level `uuid` import.
    for doc, doc_uuid in zip(docs, uuids):
        doc.id = str(doc_uuid)
    return docs
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs_with_random_emb(docs) -> List[Document]:
    """Attach a new random 768-dimensional embedding to each of ``docs`` (in place) and return the list."""
    embedding_dim = 768
    for document in docs:
        document.embedding = np.random.random([embedding_dim])
    return docs
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def docs_with_true_emb():
    """Two Berlin-related documents whose embeddings are loaded from ``samples/embeddings``."""
    embeddings_dir = SAMPLES_PATH / "embeddings"
    first = Document(
        content="The capital of Germany is the city state of Berlin.",
        embedding=np.loadtxt(embeddings_dir / "embedding_1.txt"),
    )
    second = Document(
        content="Berlin is the capital and largest city of Germany by both area and population.",
        embedding=np.loadtxt(embeddings_dir / "embedding_2.txt"),
    )
    return [first, second]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def gc_cleanup(request):
    """Force a garbage-collection pass after every test, to keep CI memory usage down.

    Applied to all tests automatically (autouse). Nothing is set up before the test runs.
    """
    yield  # the test body executes here
    gc.collect()
|
|
|
|
|
|
|
|
|
2020-05-04 18:00:07 +02:00
|
|
|
@pytest.fixture(scope="session")
def elasticsearch_fixture():
    """Ensure Elasticsearch answers on localhost:9200, starting a docker container if it does not.

    :raises Exception: if ``docker run`` for the Elasticsearch container exits with a non-zero status.
    """
    # test if a ES cluster is already running. If not, download and start an ES instance locally.
    try:
        client = Elasticsearch(hosts=[{"host": "localhost", "port": "9200"}])
        client.info()
    except Exception:  # previously a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
        print("Starting Elasticsearch ...")
        # Remove a leftover container with the same name; its exit status is deliberately not checked.
        subprocess.run(["docker rm haystack_test_elastic"], shell=True)
        status = subprocess.run(
            [
                'docker run -d --name haystack_test_elastic -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.9.2'
            ],
            shell=True,
        )
        if status.returncode:
            raise Exception("Failed to launch Elasticsearch. Please check docker container logs.")
        # Fixed grace period for the container to come up.
        time.sleep(30)
|
|
|
|
|
|
|
|
|
2021-01-29 13:29:12 +01:00
|
|
|
@pytest.fixture(scope="session")
def milvus_fixture():
    """Ensure a Milvus 1 server answers on localhost:19530, starting a docker container if it does not.

    If the Milvus 1 client failed to import at the top of this module, ``Milvus`` is undefined.
    The resulting NameError is also handled here and leads to a container start.

    :raises Exception: if ``docker run`` for the Milvus container exits with a non-zero status.
    """
    # test if a Milvus server is already running. If not, start Milvus docker container locally.
    # Make sure you have given > 6GB memory to docker engine
    try:
        milvus_server = Milvus(uri="tcp://localhost:19530", timeout=5, wait_timeout=5)
        milvus_server.server_status(timeout=5)
    except Exception:  # previously a bare `except:`
        print("Starting Milvus ...")
        # Like the other service fixtures: drop a stale container of the same name first (status ignored)...
        subprocess.run(["docker rm milvus_cpu_0.10.5"], shell=True)
        status = subprocess.run(
            [
                "docker run -d --name milvus_cpu_0.10.5 -p 19530:19530 -p 19121:19121 "
                "milvusdb/milvus:0.10.5-cpu-d010621-4eda95"
            ],
            shell=True,
        )
        # ...and fail fast instead of sleeping and letting tests fail later.
        if status.returncode:
            raise Exception("Failed to launch Milvus. Please check docker container logs.")
        time.sleep(40)
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2021-06-10 13:13:53 +05:30
|
|
|
@pytest.fixture(scope="session")
def weaviate_fixture():
    """Ensure a Weaviate server is reachable on localhost:8080, starting a docker container if not.

    Only an exception triggers the container start: the return value of ``is_ready()`` is not checked.

    :raises Exception: if ``docker run`` for the Weaviate container exits with a non-zero status.
    """
    # test if a Weaviate server is already running. If not, start Weaviate docker container locally.
    # Make sure you have given > 6GB memory to docker engine
    try:
        weaviate_server = weaviate.Client(url="http://localhost:8080", timeout_config=(5, 15))
        weaviate_server.is_ready()
    except Exception:  # previously a bare `except:`
        print("Starting Weaviate servers ...")
        # Remove a leftover container with the same name; its exit status is deliberately not checked.
        subprocess.run(["docker rm haystack_test_weaviate"], shell=True)
        status = subprocess.run(
            ["docker run -d --name haystack_test_weaviate -p 8080:8080 semitechnologies/weaviate:1.11.0"], shell=True
        )
        if status.returncode:
            raise Exception("Failed to launch Weaviate. Please check docker container logs.")
        time.sleep(60)
|
2021-01-29 13:29:12 +01:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2021-04-08 14:05:33 +02:00
|
|
|
@pytest.fixture(scope="session")
def graphdb_fixture():
    """Ensure a GraphDB instance is usable, starting a docker container on port 7200 if it is not.

    :raises Exception: if ``docker run`` for the GraphDB container exits with a non-zero status.
    """
    # test if a GraphDB instance is already running. If not, download and start a GraphDB instance locally.
    try:
        kg = GraphDBKnowledgeGraph()
        # fail if not running GraphDB
        kg.delete_index()
    except Exception:  # previously a bare `except:`
        print("Starting GraphDB ...")
        # Remove a leftover container with the same name; its exit status is deliberately not checked.
        subprocess.run(["docker rm haystack_test_graphdb"], shell=True)
        status = subprocess.run(
            [
                "docker run -d -p 7200:7200 --name haystack_test_graphdb docker-registry.ontotext.com/graphdb-free:9.4.1-adoptopenjdk11"
            ],
            shell=True,
        )
        if status.returncode:
            raise Exception("Failed to launch GraphDB. Please check docker container logs.")
        time.sleep(30)
|
|
|
|
|
|
|
|
|
2020-09-15 15:04:46 +02:00
|
|
|
@pytest.fixture(scope="session")
def tika_fixture():
    """Ensure Apache Tika answers on localhost:9998, starting a docker container if it does not.

    :raises Exception: if ``docker run`` for the Tika container exits with a non-zero status.
    """
    try:
        tika_url = "http://localhost:9998/tika"
        # Bounded wait: a port that accepts connections but never answers must not hang the session.
        ping = requests.get(tika_url, timeout=10)
        if ping.status_code != 200:
            # Deliberately raised inside the try so the handler below starts a container.
            raise Exception(f"Unable to connect Tika. Please check tika endpoint {tika_url}.")
    except Exception:  # previously a bare `except:`
        print("Starting Tika ...")
        status = subprocess.run(["docker run -d --name tika -p 9998:9998 apache/tika:1.24.1"], shell=True)
        if status.returncode:
            raise Exception("Failed to launch Tika. Please check docker container logs.")
        time.sleep(30)
|
2020-06-08 11:07:19 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
def xpdf_fixture():
    """Make sure the ``pdftotext`` command is available, trying an automatic xpdf install if it is missing.

    Only Linux and macOS are handled; any other platform raises. After the install attempt the
    command is probed again, and an Exception is raised if the shell still reports it as not found.
    """
    # 127 is the POSIX shell exit status for "command not found".
    verify_installation = run(["pdftotext"], shell=True)
    if verify_installation.returncode == 127:
        if platform.startswith("linux"):
            platform_id = "linux"
            sudo_prefix = "sudo"
        elif platform.startswith("darwin"):
            platform_id = "mac"
            # On Mac, sudo generally needs a password in an interactive console,
            # but in most cases the current user already has permission to copy to /usr/local/bin.
            # Hence the sudo requirement is dropped for Mac.
            sudo_prefix = ""
        else:
            raise Exception(
                """Currently auto installation of pdftotext is not supported on {0} platform """.format(platform)
            )
        # Download the xpdf 4.03 tools tarball for this platform, unpack it in the current working
        # directory, and copy the bin64 pdftotext binary to /usr/local/bin.
        # NOTE(review): `--no-check-certificate` disables TLS verification for a binary that is then
        # placed on PATH (with sudo on Linux). Consider dropping the flag.
        commands = """ wget --no-check-certificate https://dl.xpdfreader.com/xpdf-tools-{0}-4.03.tar.gz &&
                    tar -xvf xpdf-tools-{0}-4.03.tar.gz &&
                    {1} cp xpdf-tools-{0}-4.03/bin64/pdftotext /usr/local/bin""".format(
            platform_id, sudo_prefix
        )
        # The exit status of the install is not checked; the probe below decides success.
        run([commands], shell=True)

        verify_installation = run(["pdftotext -v"], shell=True)
        if verify_installation.returncode == 127:
            raise Exception(
                """pdftotext is not installed. It is part of xpdf or poppler-utils software suite.
                 You can download for your OS from here: https://www.xpdfreader.com/download.html."""
            )
|
2020-07-10 10:54:56 +02:00
|
|
|
|
2020-08-17 11:21:09 +02:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def deepset_cloud_fixture():
    """Prepare the ``responses`` library for deepset Cloud API calls.

    When ``MOCK_DC`` is False, requests to ``DC_API_ENDPOINT`` are allowed through to the real API.
    Otherwise canned GET responses are registered for the test index and for the pipelines listing.
    Both require the ``Bearer DC_API_KEY`` authorization header.
    """
    if not MOCK_DC:
        responses.add_passthru(DC_API_ENDPOINT)
        return

    auth_matcher = responses.matchers.header_matcher({"authorization": f"Bearer {DC_API_KEY}"})
    workspace_url = f"{DC_API_ENDPOINT}/workspaces/default"

    def indexing_status():
        # A fresh dict per response, so the two payloads never share state.
        return {"status": "INDEXED", "pending_file_count": 0, "total_file_count": 31}

    responses.add(
        method=responses.GET,
        url=f"{workspace_url}/indexes/{DC_TEST_INDEX}",
        match=[auth_matcher],
        json={"indexing": indexing_status()},
        status=200,
    )
    pipeline_entry = {"name": DC_TEST_INDEX, "status": "DEPLOYED", "indexing": indexing_status()}
    responses.add(
        method=responses.GET,
        url=f"{workspace_url}/pipelines",
        match=[auth_matcher],
        json={"data": [pipeline_entry], "has_more": False, "total": 1},
    )
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
@responses.activate
def deepset_cloud_document_store(deepset_cloud_fixture):
    """DeepsetCloudDocumentStore for ``DC_TEST_INDEX``, constructed while ``responses`` mocking is active."""
    store = DeepsetCloudDocumentStore(api_endpoint=DC_API_ENDPOINT, api_key=DC_API_KEY, index=DC_TEST_INDEX)
    return store
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def rag_generator():
    """RAGenerator on facebook/rag-token-nq in "token" mode with max_length=20."""
    settings = {"model_name_or_path": "facebook/rag-token-nq", "generator_type": "token", "max_length": 20}
    return RAGenerator(**settings)
|
2020-10-30 18:06:02 +01:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def question_generator():
    """QuestionGenerator backed by the valhalla/t5-small-e2e-qg model."""
    model_name = "valhalla/t5-small-e2e-qg"
    return QuestionGenerator(model_name_or_path=model_name)
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def lfqa_generator(request):
    """Seq2SeqGenerator (min_length=100, max_length=200) whose model name is ``request.param``.

    Expects to be parametrized indirectly by the test that requests it.
    """
    model_name = request.param
    return Seq2SeqGenerator(model_name_or_path=model_name, min_length=100, max_length=200)
|
2021-06-14 17:53:43 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def summarizer():
    """TransformersSummarizer on google/pegasus-xsum with use_gpu=-1."""
    model_name = "google/pegasus-xsum"
    return TransformersSummarizer(model_name_or_path=model_name, use_gpu=-1)
|
2021-01-08 14:29:46 +01:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def en_to_de_translator():
    """English -> German TransformersTranslator (Helsinki-NLP opus-mt)."""
    model_name = "Helsinki-NLP/opus-mt-en-de"
    return TransformersTranslator(model_name_or_path=model_name)
|
2021-02-12 15:58:26 +01:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def de_to_en_translator():
    """German -> English TransformersTranslator (Helsinki-NLP opus-mt)."""
    model_name = "Helsinki-NLP/opus-mt-de-en"
    return TransformersTranslator(model_name_or_path=model_name)
|
2021-02-12 15:58:26 +01:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def reader_without_normalized_scores():
    """FARMReader (distilbert SQuAD, no GPU, no worker processes) with use_confidence_scores=False."""
    reader_kwargs = {
        "model_name_or_path": "distilbert-base-uncased-distilled-squad",
        "use_gpu": False,
        "top_k_per_sample": 5,
        "num_processes": 0,
        "use_confidence_scores": False,
    }
    return FARMReader(**reader_kwargs)
|
|
|
|
|
2021-09-27 10:52:07 +02:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture(params=["farm", "transformers"])
def reader(request):
    """Extractive QA reader on distilbert-base-uncased-distilled-squad, parametrized over implementations.

    ``"farm"`` builds a FARMReader and ``"transformers"`` a TransformersReader; any other
    parameter yields None. Only the selected reader is constructed.
    """
    model_name = "distilbert-base-uncased-distilled-squad"
    builders = {
        "farm": lambda: FARMReader(
            model_name_or_path=model_name, use_gpu=False, top_k_per_sample=5, num_processes=0
        ),
        "transformers": lambda: TransformersReader(
            model_name_or_path=model_name, tokenizer="distilbert-base-uncased", use_gpu=-1
        ),
    }
    build = builders.get(request.param)
    return None if build is None else build()
|
2020-07-10 10:54:56 +02:00
|
|
|
|
2021-10-15 16:34:48 +02:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture(params=["tapas", "rci"])
def table_reader(request):
    """Table QA reader: a TAPAS TableReader (``"tapas"``) or an RCIReader (``"rci"``); otherwise None."""
    builders = {
        "tapas": lambda: TableReader(model_name_or_path="google/tapas-base-finetuned-wtq"),
        "rci": lambda: RCIReader(
            row_model_name_or_path="michaelrglass/albert-base-rci-wikisql-row",
            column_model_name_or_path="michaelrglass/albert-base-rci-wikisql-col",
        ),
    }
    build = builders.get(request.param)
    return None if build is None else build()
|
2021-10-15 16:34:48 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def ranker_two_logits():
    """SentenceTransformersRanker on the German DPR re-ranking model deepset/gbert-base-germandpr-reranking."""
    model_name = "deepset/gbert-base-germandpr-reranking"
    return SentenceTransformersRanker(model_name_or_path=model_name)
|
2021-12-06 17:13:57 +01:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def ranker():
    """SentenceTransformersRanker on cross-encoder/ms-marco-MiniLM-L-12-v2."""
    model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    return SentenceTransformersRanker(model_name_or_path=model_name)
|
2021-07-13 21:44:26 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def document_classifier():
    """Emotion TransformersDocumentClassifier (bhadresh-savani/distilbert-base-uncased-emotion) without GPU."""
    classifier_kwargs = {
        "model_name_or_path": "bhadresh-savani/distilbert-base-uncased-emotion",
        "use_gpu": False,
    }
    return TransformersDocumentClassifier(**classifier_kwargs)
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def zero_shot_document_classifier():
    """Zero-shot TransformersDocumentClassifier (NLI cross-encoder) with labels "negative"/"positive"."""
    candidate_labels = ["negative", "positive"]
    return TransformersDocumentClassifier(
        task="zero-shot-classification",
        labels=candidate_labels,
        model_name_or_path="cross-encoder/nli-distilroberta-base",
        use_gpu=False,
    )
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def batched_document_classifier():
    """Emotion TransformersDocumentClassifier without GPU, classifying with batch_size=16."""
    classifier_kwargs = {
        "model_name_or_path": "bhadresh-savani/distilbert-base-uncased-emotion",
        "use_gpu": False,
        "batch_size": 16,
    }
    return TransformersDocumentClassifier(**classifier_kwargs)
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def indexing_document_classifier():
    """Batched emotion classifier with classification_field="class_field"."""
    classifier_kwargs = {
        "model_name_or_path": "bhadresh-savani/distilbert-base-uncased-emotion",
        "use_gpu": False,
        "batch_size": 16,
        "classification_field": "class_field",
    }
    return TransformersDocumentClassifier(**classifier_kwargs)
|
2021-10-01 11:22:56 +02:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2020-07-14 18:53:15 +02:00
|
|
|
# TODO Fix bug in test_no_answer_output when using
# @pytest.fixture(params=["farm", "transformers"])
@pytest.fixture(params=["farm"])
def no_answer_reader(request):
    """Reader on deepset/roberta-base-squad2 configured to allow "no answer" predictions.

    ``"farm"`` builds a FARMReader with return_no_answer=True and no_ans_boost=0. ``"transformers"``
    builds a TransformersReader (currently not among the params). Any other value yields None.
    """
    model_name = "deepset/roberta-base-squad2"
    if request.param == "farm":
        return FARMReader(
            model_name_or_path=model_name,
            use_gpu=False,
            top_k_per_sample=5,
            no_ans_boost=0,
            return_no_answer=True,
            num_processes=0,
        )
    if request.param == "transformers":
        return TransformersReader(
            model_name_or_path=model_name,
            tokenizer=model_name,
            use_gpu=-1,
            top_k_per_candidate=5,
        )
    return None
|
2020-07-14 18:53:15 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def prediction(reader, docs):
    """Output of ``reader.predict`` for "Who lives in Berlin?" over the sample docs, top_k=5."""
    return reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5)
|
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def no_answer_prediction(no_answer_reader, docs):
    """Output of the no-answer reader for an unanswerable question over the sample docs, top_k=5."""
    return no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5)
|
|
|
|
|
|
|
|
|
2021-10-25 12:27:02 +02:00
|
|
|
@pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf", "table_text_retriever"])
def retriever(request, document_store):
    """Retriever of the parametrized type, attached to the ``document_store`` fixture via ``get_retriever``."""
    retriever_type = request.param
    return get_retriever(retriever_type, document_store)
|
|
|
|
|
|
|
|
|
2021-10-13 14:23:23 +02:00
|
|
|
# @pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf"])
@pytest.fixture(params=["tfidf"])
def retriever_with_docs(request, document_store_with_docs):
    """Retriever of the parametrized type (currently only "tfidf") over ``document_store_with_docs``."""
    retriever_type = request.param
    return get_retriever(retriever_type, document_store_with_docs)
|
|
|
|
|
|
|
|
|
|
|
|
def get_retriever(retriever_type, document_store):
    """Instantiate a test retriever of the given type on top of ``document_store``.

    Model-based retrievers are created with GPU usage disabled.

    :param retriever_type: one of "dpr", "mdr", "tfidf", "embedding", "embedding_sbert",
        "retribert", "dpr_lfqa", "elasticsearch", "es_filter_only" or "table_text_retriever".
    :param document_store: the document store the retriever reads from.
    :return: the constructed retriever.
    :raises ValueError: if ``retriever_type`` is not one of the supported values
        (a subclass of ``Exception``, so callers catching ``Exception`` keep working).
    """
    if retriever_type == "dpr":
        retriever = DensePassageRetriever(
            document_store=document_store,
            query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
            passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
            use_gpu=False,
            embed_title=True,
        )
    elif retriever_type == "mdr":
        retriever = MultihopEmbeddingRetriever(
            document_store=document_store,
            embedding_model="deutschmann/mdr_roberta_q_encoder",  # or "facebook/dpr-ctx_encoder-single-nq-base"
            use_gpu=False,
        )
    elif retriever_type == "tfidf":
        retriever = TfidfRetriever(document_store=document_store)
        # The TF-IDF model has to be fitted before it can retrieve anything.
        retriever.fit()
    elif retriever_type == "embedding":
        retriever = EmbeddingRetriever(
            document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False
        )
    elif retriever_type == "embedding_sbert":
        retriever = EmbeddingRetriever(
            document_store=document_store,
            embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
            model_format="sentence_transformers",
            use_gpu=False,
        )
    elif retriever_type == "retribert":
        retriever = EmbeddingRetriever(
            document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
        )
    elif retriever_type == "dpr_lfqa":
        retriever = DensePassageRetriever(
            document_store=document_store,
            query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
            passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
            use_gpu=False,
            embed_title=True,
        )
    elif retriever_type == "elasticsearch":
        retriever = BM25Retriever(document_store=document_store)
    elif retriever_type == "es_filter_only":
        retriever = FilterRetriever(document_store=document_store)
    elif retriever_type == "table_text_retriever":
        retriever = TableTextRetriever(
            document_store=document_store,
            query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
            passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
            table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
            use_gpu=False,
        )
    else:
        raise ValueError(f"No retriever fixture for '{retriever_type}'")

    return retriever
|
2020-12-14 18:15:44 +01:00
|
|
|
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
def ensure_ids_are_correct_uuids(docs: list, document_store: object) -> None:
    """Give every document a random UUID4 string id when the store is a Weaviate store.

    Weaviate currently only supports UUIDs as document ids. ``docs`` is modified in place
    (each element's ``"id"`` item is overwritten). For any other document store this is a no-op.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses of the Weaviate store.
    if not isinstance(document_store, WeaviateDocumentStore):
        return
    for doc in docs:
        doc["id"] = str(uuid.uuid4())
|
|
|
|
|
|
|
|
|
2022-03-21 22:24:09 +07:00
|
|
|
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"])
def document_store_with_docs(request, docs, tmp_path):
    """Yield a document store of each parametrized type, already populated with ``docs``.

    The embedding dimension is read from the closest ``embedding_dim`` marker (768 when absent).
    The store's index is deleted once the test is done.
    """
    default_marker = pytest.mark.embedding_dim(768)
    marker = request.node.get_closest_marker("embedding_dim", default_marker)
    store = get_document_store(document_store_type=request.param, embedding_dim=marker.args[0], tmp_path=tmp_path)
    store.write_documents(docs)
    yield store
    # teardown
    store.delete_index(store.index)
|
2020-12-14 18:15:44 +01:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2021-06-22 16:08:23 +02:00
|
|
|
@pytest.fixture
def document_store(request, tmp_path):
    """Yield an empty document store whose type is given by ``request.param``.

    The fixture declares no params itself, so the type is presumably supplied by the test
    through indirect parametrization. The embedding dimension comes from the closest
    ``embedding_dim`` marker (768 when absent). The store's index is deleted after the test.
    """
    default_marker = pytest.mark.embedding_dim(768)
    marker = request.node.get_closest_marker("embedding_dim", default_marker)
    store = get_document_store(document_store_type=request.param, embedding_dim=marker.args[0], tmp_path=tmp_path)
    yield store
    # teardown
    store.delete_index(store.index)
|
2022-03-21 22:24:09 +07:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-03-21 22:24:09 +07:00
|
|
|
@pytest.fixture(params=["memory", "faiss", "milvus1", "milvus", "elasticsearch", "pinecone"])
def document_store_dot_product(request, tmp_path):
    """Yield an empty dot-product-similarity document store for each parametrized type.

    The embedding dimension comes from the closest ``embedding_dim`` marker (768 when absent).
    The store's index is deleted after the test.
    """
    default_marker = pytest.mark.embedding_dim(768)
    marker = request.node.get_closest_marker("embedding_dim", default_marker)
    dim = marker.args[0]
    store = get_document_store(
        document_store_type=request.param, embedding_dim=dim, similarity="dot_product", tmp_path=tmp_path
    )
    yield store
    # teardown
    store.delete_index(store.index)
|
2022-01-12 19:28:20 +01:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-03-21 22:24:09 +07:00
|
|
|
@pytest.fixture(params=["memory", "faiss", "milvus1", "milvus", "elasticsearch", "pinecone"])
def document_store_dot_product_with_docs(request, docs, tmp_path):
    """Yield a dot-product-similarity document store for each parametrized type, populated with ``docs``.

    The embedding dimension comes from the closest ``embedding_dim`` marker (768 when absent).
    The store's index is deleted after the test.
    """
    default_marker = pytest.mark.embedding_dim(768)
    marker = request.node.get_closest_marker("embedding_dim", default_marker)
    dim = marker.args[0]
    store = get_document_store(
        document_store_type=request.param, embedding_dim=dim, similarity="dot_product", tmp_path=tmp_path
    )
    store.write_documents(docs)
    yield store
    # teardown
    store.delete_index(store.index)
|
2022-01-12 19:28:20 +01:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-03-21 22:24:09 +07:00
|
|
|
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus1", "pinecone"])
def document_store_dot_product_small(request, tmp_path):
    """Yield an empty dot-product-similarity document store with small embeddings.

    The embedding dimension comes from the closest ``embedding_dim`` marker (3 when absent).
    The store's index is deleted after the test.
    """
    default_marker = pytest.mark.embedding_dim(3)
    marker = request.node.get_closest_marker("embedding_dim", default_marker)
    dim = marker.args[0]
    store = get_document_store(
        document_store_type=request.param, embedding_dim=dim, similarity="dot_product", tmp_path=tmp_path
    )
    yield store
    # teardown
    store.delete_index(store.index)
|
2021-11-01 15:42:32 +03:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-03-21 22:24:09 +07:00
|
|
|
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"])
def document_store_small(request, tmp_path):
    """Yield an empty cosine-similarity document store with small embeddings.

    The embedding dimension comes from the closest ``embedding_dim`` marker (3 when absent).
    The store's index is deleted after the test.
    """
    default_marker = pytest.mark.embedding_dim(3)
    marker = request.node.get_closest_marker("embedding_dim", default_marker)
    dim = marker.args[0]
    store = get_document_store(
        document_store_type=request.param, embedding_dim=dim, similarity="cosine", tmp_path=tmp_path
    )
    yield store
    # teardown
    store.delete_index(store.index)
|
2020-12-14 18:15:44 +01:00
|
|
|
|
2022-01-14 13:48:58 +01:00
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture(autouse=True)
def postgres_fixture():
    """Wrap every test in ``setup_postgres()`` / ``teardown_postgres()`` when ``SQL_TYPE`` is "postgres".

    For any other ``SQL_TYPE`` this autouse fixture does nothing.
    """
    if SQL_TYPE != "postgres":
        yield
        return
    setup_postgres()
    yield
    teardown_postgres()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def sql_url(tmp_path):
    """Expose ``get_sql_url`` for the current test's ``tmp_path`` as a fixture."""
    url = get_sql_url(tmp_path)
    return url
|
2022-01-14 13:48:58 +01:00
|
|
|
|
|
|
|
|
|
|
|
def get_sql_url(tmp_path):
    """Return the SQL connection URL used by SQL-backed document stores in the tests.

    When ``SQL_TYPE`` is "postgres" this points at the local PostgreSQL server. Otherwise it
    is a SQLite file named ``haystack_test.db`` inside ``tmp_path``.
    """
    if SQL_TYPE != "postgres":
        return f"sqlite:///{tmp_path}/haystack_test.db"
    return "postgresql://postgres:postgres@127.0.0.1/postgres"
|
|
|
|
|
|
|
|
|
|
|
|
def setup_postgres():
    """Drop and recreate the ``public`` schema of the local PostgreSQL test database.

    Expects a PostgreSQL server reachable at 127.0.0.1 with user/password ``postgres``.
    One can be started with:
    ``docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres``
    """
    engine = create_engine("postgresql://postgres:postgres@127.0.0.1/postgres", isolation_level="AUTOCOMMIT")
    try:
        with engine.connect() as connection:
            try:
                connection.execute(text("DROP SCHEMA public CASCADE"))
            except Exception as e:
                # Presumably the schema does not exist yet; log it and recreate it below anyway.
                logging.error(e)
            connection.execute(text("CREATE SCHEMA public;"))
            # NOTE(review): SET SESSION only changes the session of this connection. Connections
            # opened later (e.g. by the document stores) are not affected. Confirm whether a
            # database-wide setting was intended.
            connection.execute(text('SET SESSION idle_in_transaction_session_timeout = "1s";'))
    finally:
        # Close the pooled DBAPI connections deterministically instead of leaving them
        # open until the engine is garbage collected.
        engine.dispose()
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2022-01-14 13:48:58 +01:00
|
|
|
def teardown_postgres():
    """Drop the ``public`` schema (with CASCADE) of the local PostgreSQL test database."""
    engine = create_engine("postgresql://postgres:postgres@127.0.0.1/postgres", isolation_level="AUTOCOMMIT")
    try:
        # Leaving the ``with`` block releases the connection, so no explicit close() is needed.
        with engine.connect() as connection:
            connection.execute(text("DROP SCHEMA public CASCADE"))
    finally:
        # Also close the pooled DBAPI connection(s) held by the engine.
        engine.dispose()
|
|
|
|
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
def get_document_store(
    document_store_type,
    tmp_path,
    embedding_dim=768,
    embedding_field="embedding",
    index="haystack_test",
    similarity: str = "cosine",
):
    """Create a document store of the requested type, configured for the test suite.

    :param document_store_type: one of "sql", "memory", "elasticsearch", "faiss", "milvus1",
        "milvus", "weaviate" or "pinecone".
    :param tmp_path: passed to ``get_sql_url`` for the SQL-backed stores ("sql", "faiss",
        "milvus1", "milvus").
    :param embedding_dim: dimensionality of the stored embeddings.
    :param embedding_field: name of the embedding field (not used by the "sql" and "weaviate" stores).
    :param index: name of the index the store works on.
    :param similarity: similarity function. Cosine is the default because dot product is not
        supported by Weaviate.
    :return: the constructed document store.
    :raises ValueError: for an unknown ``document_store_type`` (a subclass of ``Exception``,
        so callers catching ``Exception`` keep working).
    :raises KeyError: for "pinecone" when the ``PINECONE_API_KEY`` environment variable is unset.
    """
    if document_store_type == "sql":
        document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT")

    elif document_store_type == "memory":
        document_store = InMemoryDocumentStore(
            return_embedding=True,
            embedding_dim=embedding_dim,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
        )

    elif document_store_type == "elasticsearch":
        document_store = ElasticsearchDocumentStore(
            index=index,
            return_embedding=True,
            embedding_dim=embedding_dim,
            embedding_field=embedding_field,
            similarity=similarity,
            recreate_index=True,  # make sure we start from a fresh index
        )

    elif document_store_type == "faiss":
        document_store = FAISSDocumentStore(
            embedding_dim=embedding_dim,
            sql_url=get_sql_url(tmp_path),
            return_embedding=True,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
            isolation_level="AUTOCOMMIT",
        )

    elif document_store_type == "milvus1":
        document_store = MilvusDocumentStore(
            embedding_dim=embedding_dim,
            sql_url=get_sql_url(tmp_path),
            return_embedding=True,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
            isolation_level="AUTOCOMMIT",
        )

    elif document_store_type == "milvus":
        document_store = MilvusDocumentStore(
            embedding_dim=embedding_dim,
            sql_url=get_sql_url(tmp_path),
            return_embedding=True,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
            isolation_level="AUTOCOMMIT",
            recreate_index=True,
        )

    elif document_store_type == "weaviate":
        document_store = WeaviateDocumentStore(
            index=index, similarity=similarity, embedding_dim=embedding_dim, recreate_index=True
        )

    elif document_store_type == "pinecone":
        document_store = PineconeDocumentStore(
            api_key=os.environ["PINECONE_API_KEY"],
            embedding_dim=embedding_dim,
            embedding_field=embedding_field,
            index=index,
            similarity=similarity,
            recreate_index=True,
        )

    else:
        raise ValueError(f"No document store fixture for '{document_store_type}'")

    return document_store
|
2021-09-22 16:56:51 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def adaptive_model_qa(num_processes):
    """
    PyTest Fixture for a Question Answering Inferencer based on PyTorch.

    Loads ``deepset/bert-base-cased-squad2`` on CPU with ``num_processes`` worker processes.
    After the test it closes the multiprocessing pool (when ``num_processes`` is non-zero) and
    logs an error if the current process still has child processes.
    """
    # Load before entering try/finally: if loading raises, ``model`` is never bound, and
    # referencing it in ``finally`` would raise UnboundLocalError that hides the real error.
    model = Inferencer.load(
        "deepset/bert-base-cased-squad2",
        task_type="question_answering",
        batch_size=16,
        num_processes=num_processes,
        gpu=False,
    )
    try:
        yield model
    finally:
        if num_processes != 0:
            # close the pool
            # we pass join=True to wait for all sub processes to close
            # this is because below we want to test if all sub-processes
            # have exited
            model.close_multiprocessing_pool(join=True)

    # check if all workers (sub processes) are closed
    children = psutil.Process().children()
    if children:
        logging.error(f"Not all the subprocesses are closed! {len(children)} are still running.")
|
2021-09-22 16:56:51 +02:00
|
|
|
|
|
|
|
|
2022-06-10 18:22:48 +02:00
|
|
|
@pytest.fixture
def bert_base_squad2(request):
    """Load a single-process question-answering ``QAInferencer`` with a fast tokenizer.

    Despite the fixture's name, the model loaded is ``deepset/minilm-uncased-squad2``.
    ``request`` is not used but is kept in the signature.
    """
    load_options = dict(
        task_type="question_answering",
        batch_size=4,
        num_processes=0,
        multithreading_rust=False,
        use_fast=True,  # TODO parametrize this to test slow as well
    )
    return QAInferencer.load("deepset/minilm-uncased-squad2", **load_options)
|