2020-09-15 15:04:46 +02:00
|
|
|
import subprocess
|
2020-05-04 18:00:07 +02:00
|
|
|
import time
|
2020-09-15 15:04:46 +02:00
|
|
|
from subprocess import run
|
|
|
|
from sys import platform
|
2020-05-04 18:00:07 +02:00
|
|
|
|
|
|
|
import pytest
|
2020-09-15 15:04:46 +02:00
|
|
|
import requests
|
2020-06-09 12:46:15 +02:00
|
|
|
from elasticsearch import Elasticsearch
|
2021-06-14 17:53:43 +02:00
|
|
|
|
2021-07-13 21:44:26 +02:00
|
|
|
from haystack.classifier import FARMClassifier
|
2021-06-14 17:53:43 +02:00
|
|
|
from haystack.generator.transformers import Seq2SeqGenerator
|
2021-04-08 14:05:33 +02:00
|
|
|
from haystack.knowledge_graph.graphdb import GraphDBKnowledgeGraph
|
2021-01-29 13:29:12 +01:00
|
|
|
from milvus import Milvus
|
|
|
|
|
2021-06-10 13:13:53 +05:30
|
|
|
import weaviate
|
|
|
|
from haystack.document_store.weaviate import WeaviateDocumentStore
|
|
|
|
|
2021-01-29 13:29:12 +01:00
|
|
|
from haystack.document_store.milvus import MilvusDocumentStore
|
2020-10-30 18:06:02 +01:00
|
|
|
from haystack.generator.transformers import RAGenerator, RAGeneratorType
|
2021-07-07 17:31:45 +02:00
|
|
|
from haystack.ranker import FARMRanker, SentenceTransformersRanker
|
2020-10-30 18:06:02 +01:00
|
|
|
|
2020-10-14 16:15:04 +02:00
|
|
|
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
|
|
|
|
|
|
|
|
from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
|
2020-05-04 18:00:07 +02:00
|
|
|
|
2020-09-16 18:33:23 +02:00
|
|
|
from haystack import Document
|
|
|
|
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
|
|
|
from haystack.document_store.faiss import FAISSDocumentStore
|
|
|
|
from haystack.document_store.memory import InMemoryDocumentStore
|
|
|
|
from haystack.document_store.sql import SQLDocumentStore
|
2020-07-10 10:54:56 +02:00
|
|
|
from haystack.reader.farm import FARMReader
|
|
|
|
from haystack.reader.transformers import TransformersReader
|
2021-01-08 14:29:46 +01:00
|
|
|
from haystack.summarizer.transformers import TransformersSummarizer
|
2021-02-12 15:58:26 +01:00
|
|
|
from haystack.translator import TransformersTranslator
|
2021-07-26 17:20:43 +02:00
|
|
|
from haystack.question_generator import QuestionGenerator
|
2020-07-10 10:54:56 +02:00
|
|
|
|
2020-05-04 18:00:07 +02:00
|
|
|
|
2021-06-22 16:08:23 +02:00
|
|
|
def pytest_addoption(parser):
    """
    Register the ``--document_store_type`` command line option.

    The value is stored as given on the command line and defaults to ``"all"``.

    :param parser: option parser handed in by pytest; only its ``addoption`` method is used here.
    """
    option_settings = {"action": "store", "default": "all"}
    parser.addoption("--document_store_type", **option_settings)
|
|
|
|
|
|
|
|
|
|
|
|
def pytest_generate_tests(metafunc):
    """
    Parametrize the ``document_store`` fixture from the ``--document_store_type`` option.

    Nothing happens when the test function does not request ``document_store`` or when it already
    carries an explicit parametrize marker mentioning it, e.g.
    ``@pytest.mark.parametrize("document_store", ["memory"], indirect=False)``.

    :param metafunc: object provided by pytest; its ``definition``, ``fixturenames``, ``config`` and
                     ``parametrize`` members are used.
    """
    # An explicit parametrize annotation on the test always wins over the command line option.
    has_explicit_parametrization = any(
        'document_store' in mark.args[0]
        for mark in metafunc.definition.iter_markers('parametrize')
    )
    if has_explicit_parametrization or 'document_store' not in metafunc.fixturenames:
        return

    requested = metafunc.config.option.document_store_type
    if "all" in requested:
        # "all" is expanded to this fixed list of document stores
        requested = "elasticsearch, faiss, memory, milvus"

    store_types = [name.strip() for name in requested.split(",")]
    metafunc.parametrize("document_store", store_types, indirect=True)
|
|
|
|
|
|
|
|
|
2021-01-22 14:39:24 +01:00
|
|
|
def _sql_session_rollback(self, attr):
|
|
|
|
"""
|
|
|
|
Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
|
|
|
|
errors where an intended operation is still in a transaction, but not committed to the database.
|
|
|
|
"""
|
|
|
|
method = object.__getattribute__(self, attr)
|
|
|
|
if callable(method):
|
|
|
|
try:
|
|
|
|
self.session.rollback()
|
|
|
|
except AttributeError:
|
|
|
|
pass
|
|
|
|
|
|
|
|
return method
|
|
|
|
|
|
|
|
|
|
|
|
# Monkeypatch attribute access on SQLDocumentStore so that every lookup goes through _sql_session_rollback.
SQLDocumentStore.__getattribute__ = _sql_session_rollback
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
def pytest_collection_modifyitems(items):
    """
    Attach a pytest marker to each collected item based on keywords found in its node id.

    The keywords are checked in a fixed order and only the first one contained in ``item.nodeid``
    is applied; the marker has the same name as the keyword.

    :param items: iterable of collected test items (each offering ``nodeid`` and ``add_marker``).
    """
    # Order matters: the first keyword found in the node id decides the marker.
    keywords_by_priority = (
        "generator",
        "summarizer",
        "tika",
        "elasticsearch",
        "graphdb",
        "pipeline",
        "slow",
        "weaviate",
    )
    for item in items:
        first_match = next((keyword for keyword in keywords_by_priority if keyword in item.nodeid), None)
        if first_match is not None:
            item.add_marker(getattr(pytest.mark, first_match))
|
2020-10-30 18:06:02 +01:00
|
|
|
|
|
|
|
|
2020-05-04 18:00:07 +02:00
|
|
|
@pytest.fixture(scope="session")
def elasticsearch_fixture():
    """
    Session-scoped fixture making sure Elasticsearch answers on localhost:9200.

    If ``client.info()`` fails, a possibly left-over ``haystack_test_elastic`` container is removed, a new
    single-node ``elasticsearch:7.9.2`` container is started through docker and the fixture sleeps 30 seconds.

    :raises Exception: if the ``docker run`` command finishes with a non-zero return code.
    """
    # test if a ES cluster is already running. If not, download and start an ES instance locally.
    try:
        client = Elasticsearch(hosts=[{"host": "localhost", "port": "9200"}])
        client.info()
    except Exception:
        # `except Exception` instead of a bare `except`, so KeyboardInterrupt / SystemExit still abort the run
        print("Starting Elasticsearch ...")
        # Removing a stale container is best effort: the outcome is ignored on purpose.
        # With shell=True the command is passed as a plain string rather than a one-element list.
        subprocess.run('docker rm haystack_test_elastic', shell=True)
        status = subprocess.run(
            'docker run -d --name haystack_test_elastic -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.9.2',
            shell=True
        )
        if status.returncode:
            raise Exception(
                "Failed to launch Elasticsearch. Please check docker container logs.")
        # presumably gives the freshly started container time to come up
        time.sleep(30)
|
|
|
|
|
|
|
|
|
2021-01-29 13:29:12 +01:00
|
|
|
@pytest.fixture(scope="session")
def milvus_fixture():
    """
    Session-scoped fixture making sure a Milvus server is reachable on tcp://localhost:19530.

    If the status check fails, a ``milvusdb/milvus:0.10.5-cpu-d010621-4eda95`` container is started through
    docker and the fixture sleeps 40 seconds. A failing ``docker run`` is reported on stdout but does not
    raise.
    """
    # test if a Milvus server is already running. If not, start Milvus docker container locally.
    # Make sure you have given > 6GB memory to docker engine
    try:
        milvus_server = Milvus(uri="tcp://localhost:19530", timeout=5, wait_timeout=5)
        milvus_server.server_status(timeout=5)
    except Exception:
        # `except Exception` instead of a bare `except`, so KeyboardInterrupt / SystemExit still abort the run
        print("Starting Milvus ...")
        # With shell=True the command is passed as a plain string rather than a one-element list.
        status = subprocess.run(
            'docker run -d --name milvus_cpu_0.10.5 -p 19530:19530 -p 19121:19121 '
            'milvusdb/milvus:0.10.5-cpu-d010621-4eda95',
            shell=True
        )
        if status.returncode:
            # Keep going (as before), but do not hide the failed launch from whoever reads the test output.
            print("Failed to launch Milvus. Please check docker container logs.")
        # presumably gives the freshly started container time to come up
        time.sleep(40)
|
|
|
|
|
2021-06-10 13:13:53 +05:30
|
|
|
@pytest.fixture(scope="session")
def weaviate_fixture():
    """
    Session-scoped fixture making sure a Weaviate server is reachable on http://localhost:8080.

    If the readiness check raises, a possibly left-over ``haystack_test_weaviate`` container is removed, a new
    ``semitechnologies/weaviate:1.4.0`` container is started through docker and the fixture sleeps 60 seconds.

    :raises Exception: if the ``docker run`` command finishes with a non-zero return code.
    """
    # test if a Weaviate server is already running. If not, start Weaviate docker container locally.
    # Make sure you have given > 6GB memory to docker engine
    try:
        weaviate_server = weaviate.Client(url='http://localhost:8080', timeout_config=(5, 15))
        # NOTE(review): the return value of is_ready() is ignored - confirm that an unreachable
        # server makes one of these two calls raise, otherwise the container is never started.
        weaviate_server.is_ready()
    except Exception:
        # `except Exception` instead of a bare `except`, so KeyboardInterrupt / SystemExit still abort the run
        print("Starting Weaviate servers ...")
        # Removing a stale container is best effort: the outcome is ignored on purpose.
        # With shell=True the command is passed as a plain string rather than a one-element list.
        subprocess.run('docker rm haystack_test_weaviate', shell=True)
        status = subprocess.run(
            'docker run -d --name haystack_test_weaviate -p 8080:8080 semitechnologies/weaviate:1.4.0',
            shell=True
        )
        if status.returncode:
            raise Exception(
                "Failed to launch Weaviate. Please check docker container logs.")
        # presumably gives the freshly started container time to come up
        time.sleep(60)
|
2021-01-29 13:29:12 +01:00
|
|
|
|
2021-04-08 14:05:33 +02:00
|
|
|
@pytest.fixture(scope="session")
def graphdb_fixture():
    """
    Session-scoped fixture making sure a GraphDB instance is available.

    Availability is probed by creating a ``GraphDBKnowledgeGraph`` and calling ``delete_index()``. If that
    raises, a possibly left-over ``haystack_test_graphdb`` container is removed, a new
    ``graphdb-free:9.4.1-adoptopenjdk11`` container is started through docker (port 7200) and the fixture
    sleeps 30 seconds.

    :raises Exception: if the ``docker run`` command finishes with a non-zero return code.
    """
    # test if a GraphDB instance is already running. If not, download and start a GraphDB instance locally.
    try:
        kg = GraphDBKnowledgeGraph()
        # fail if not running GraphDB
        kg.delete_index()
    except Exception:
        # `except Exception` instead of a bare `except`, so KeyboardInterrupt / SystemExit still abort the run
        print("Starting GraphDB ...")
        # Removing a stale container is best effort: the outcome is ignored on purpose.
        # With shell=True the command is passed as a plain string rather than a one-element list.
        subprocess.run('docker rm haystack_test_graphdb', shell=True)
        status = subprocess.run(
            'docker run -d -p 7200:7200 --name haystack_test_graphdb docker-registry.ontotext.com/graphdb-free:9.4.1-adoptopenjdk11',
            shell=True
        )
        if status.returncode:
            raise Exception(
                "Failed to launch GraphDB. Please check docker container logs.")
        # presumably gives the freshly started container time to come up
        time.sleep(30)
|
|
|
|
|
|
|
|
|
2020-09-15 15:04:46 +02:00
|
|
|
@pytest.fixture(scope="session")
def tika_fixture():
    """
    Session-scoped fixture making sure a Tika server answers on http://localhost:9998/tika.

    If the GET request fails or does not return status 200, an ``apache/tika:1.24.1`` container is started
    through docker and the fixture sleeps 30 seconds.

    :raises Exception: if the ``docker run`` command finishes with a non-zero return code.
    """
    try:
        tika_url = "http://localhost:9998/tika"
        ping = requests.get(tika_url)
        if ping.status_code != 200:
            # Raised only to reach the handler below, which starts the container.
            raise Exception(
                "Unable to connect Tika. Please check tika endpoint {0}.".format(tika_url))
    except Exception:
        # `except Exception` instead of a bare `except`, so KeyboardInterrupt / SystemExit still abort the run
        print("Starting Tika ...")
        # With shell=True the command is passed as a plain string rather than a one-element list.
        status = subprocess.run(
            'docker run -d --name tika -p 9998:9998 apache/tika:1.24.1',
            shell=True
        )
        if status.returncode:
            raise Exception(
                "Failed to launch Tika. Please check docker container logs.")
        # presumably gives the freshly started container time to come up
        time.sleep(30)
|
2020-06-08 11:07:19 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
def xpdf_fixture(tika_fixture):
    """
    Session-scoped fixture making sure the ``pdftotext`` command line tool is available.

    A return code of 127 from the shell is treated as "pdftotext is not installed". In that case the
    xpdf-tools 4.03 archive is downloaded with wget, unpacked, and its ``pdftotext`` binary is copied to
    /usr/local/bin (with ``sudo`` on Linux, without on macOS). Afterwards the installation is verified again.

    :param tika_fixture: requested so that the Tika fixture is set up as well; its value is not used here.
    :raises Exception: if auto installation is not supported on the current platform, or if ``pdftotext``
                       is still not found after the installation attempt.
    """
    # With shell=True the commands are passed as plain strings rather than one-element lists.
    verify_installation = run("pdftotext", shell=True)
    if verify_installation.returncode == 127:
        if platform.startswith("linux"):
            platform_id = "linux"
            sudo_prefix = "sudo"
        elif platform.startswith("darwin"):
            platform_id = "mac"
            # For Mac, sudo generally needs a password in an interactive console.
            # But in most cases the current user already has permission to copy to /usr/local/bin.
            # Hence removing the sudo requirement for Mac.
            sudo_prefix = ""
        else:
            raise Exception(
                """Currently auto installation of pdftotext is not supported on {0} platform """.format(platform)
            )
        commands = """ wget --no-check-certificate https://dl.xpdfreader.com/xpdf-tools-{0}-4.03.tar.gz &&
                       tar -xvf xpdf-tools-{0}-4.03.tar.gz &&
                       {1} cp xpdf-tools-{0}-4.03/bin64/pdftotext /usr/local/bin""".format(platform_id, sudo_prefix)
        run(commands, shell=True)

    # Check again, so that a failed installation surfaces here with a helpful message.
    verify_installation = run("pdftotext -v", shell=True)
    if verify_installation.returncode == 127:
        raise Exception(
            """pdftotext is not installed. It is part of xpdf or poppler-utils software suite.
             You can download for your OS from here: https://www.xpdfreader.com/download.html."""
        )
|
2020-07-10 10:54:56 +02:00
|
|
|
|
2020-08-17 11:21:09 +02:00
|
|
|
|
2020-12-03 10:27:06 +01:00
|
|
|
@pytest.fixture(scope="module")
def rag_generator():
    """Module-scoped fixture returning a token-type RAGenerator built from "facebook/rag-token-nq"."""
    generator_settings = {
        "model_name_or_path": "facebook/rag-token-nq",
        "generator_type": RAGeneratorType.TOKEN,
    }
    return RAGenerator(**generator_settings)
|
|
|
|
|
|
|
|
|
2021-07-26 17:20:43 +02:00
|
|
|
@pytest.fixture(scope="module")
def question_generator():
    """Module-scoped fixture returning a QuestionGenerator built from "valhalla/t5-small-e2e-qg"."""
    model_name = "valhalla/t5-small-e2e-qg"
    return QuestionGenerator(model_name_or_path=model_name)
|
|
|
|
|
|
|
|
|
2021-06-14 17:53:43 +02:00
|
|
|
@pytest.fixture(scope="module")
def eli5_generator():
    """Module-scoped fixture returning a Seq2SeqGenerator built from "yjernite/bart_eli5"."""
    model_name = "yjernite/bart_eli5"
    return Seq2SeqGenerator(model_name_or_path=model_name)
|
|
|
|
|
|
|
|
|
2021-01-08 14:29:46 +01:00
|
|
|
@pytest.fixture(scope="module")
def summarizer():
    """Module-scoped fixture returning a TransformersSummarizer for "google/pegasus-xsum" (``use_gpu=-1``)."""
    summarizer_settings = {
        "model_name_or_path": "google/pegasus-xsum",
        "use_gpu": -1,
    }
    return TransformersSummarizer(**summarizer_settings)
|
|
|
|
|
|
|
|
|
2021-02-12 15:58:26 +01:00
|
|
|
@pytest.fixture(scope="module")
def en_to_de_translator():
    """Module-scoped fixture returning a TransformersTranslator built from "Helsinki-NLP/opus-mt-en-de"."""
    model_name = "Helsinki-NLP/opus-mt-en-de"
    return TransformersTranslator(model_name_or_path=model_name)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def de_to_en_translator():
    """Module-scoped fixture returning a TransformersTranslator built from "Helsinki-NLP/opus-mt-de-en"."""
    model_name = "Helsinki-NLP/opus-mt-de-en"
    return TransformersTranslator(model_name_or_path=model_name)
|
|
|
|
|
|
|
|
|
2020-12-03 10:27:06 +01:00
|
|
|
@pytest.fixture(scope="module")
def test_docs_xs():
    """
    Module-scoped fixture returning three tiny test documents in three different input formats.

    :return: list containing two dicts and one ``Document`` object.
    """
    # current "dict" format for a document
    carla = {"text": "My name is Carla and I live in Berlin", "meta": {"meta_field": "test1", "name": "filename1"}}
    # meta_field at the top level for backward compatibility
    paul = {"text": "My name is Paul and I live in New York", "meta_field": "test2", "name": "filename2"}
    # Document object for a doc
    christelle = Document(
        text="My name is Christelle and I live in Paris",
        meta={"meta_field": "test3", "name": "filename3"},
    )
    return [carla, paul, christelle]
|
|
|
|
|
|
|
|
|
2021-08-17 10:27:11 +02:00
|
|
|
@pytest.fixture(scope="module")
def reader_without_normalized_scores():
    """
    Module-scoped fixture returning a FARMReader created with ``use_confidence_scores=False``.

    The reader uses "distilbert-base-uncased-distilled-squad", no GPU, ``top_k_per_sample=5`` and
    ``num_processes=0``.
    """
    reader_settings = {
        "model_name_or_path": "distilbert-base-uncased-distilled-squad",
        "use_gpu": False,
        "top_k_per_sample": 5,
        "num_processes": 0,
        "use_confidence_scores": False,
    }
    return FARMReader(**reader_settings)
|
|
|
|
|
2020-12-03 10:27:06 +01:00
|
|
|
@pytest.fixture(params=["farm", "transformers"], scope="module")
def reader(request):
    """
    Module-scoped fixture, parametrized over "farm" and "transformers", returning the matching reader.

    Both readers are built from "distilbert-base-uncased-distilled-squad".

    :param request: pytest request object; ``request.param`` selects the reader type.
    :return: a ``FARMReader`` or a ``TransformersReader``; ``None`` for any other parameter value.
    """
    reader_type = request.param
    model_name = "distilbert-base-uncased-distilled-squad"

    if reader_type == "farm":
        return FARMReader(model_name_or_path=model_name, use_gpu=False, top_k_per_sample=5, num_processes=0)

    if reader_type == "transformers":
        return TransformersReader(model_name_or_path=model_name, tokenizer="distilbert-base-uncased", use_gpu=-1)

    return None
|
2020-07-10 10:54:56 +02:00
|
|
|
|
2021-07-07 17:31:45 +02:00
|
|
|
@pytest.fixture(params=["farm", "sentencetransformers"], scope="module")
def ranker(request):
    """
    Module-scoped fixture, parametrized over "farm" and "sentencetransformers", returning the matching ranker.

    :param request: pytest request object; ``request.param`` selects the ranker type.
    :return: a ``FARMRanker`` or a ``SentenceTransformersRanker``; ``None`` for any other parameter value.
    """
    ranker_type = request.param

    if ranker_type == "farm":
        farm_model = "deepset/gbert-base-germandpr-reranking"
        return FARMRanker(model_name_or_path=farm_model)

    if ranker_type == "sentencetransformers":
        cross_encoder_model = "cross-encoder/ms-marco-MiniLM-L-12-v2"
        return SentenceTransformersRanker(model_name_or_path=cross_encoder_model)

    return None
|
|
|
|
|
2020-07-10 10:54:56 +02:00
|
|
|
|
2021-07-13 21:44:26 +02:00
|
|
|
@pytest.fixture(params=["farm"], scope="module")
def classifier(request):
    """
    Module-scoped fixture returning a FARMClassifier for the "farm" parameter.

    :param request: pytest request object; ``request.param`` selects the classifier type.
    :return: a ``FARMClassifier`` built from "deepset/bert-base-german-cased-sentiment-Germeval17";
             ``None`` for any other parameter value.
    """
    if request.param != "farm":
        return None

    model_name = "deepset/bert-base-german-cased-sentiment-Germeval17"
    return FARMClassifier(model_name_or_path=model_name)
|
|
|
|
|
|
|
|
|
2020-07-14 18:53:15 +02:00
|
|
|
# TODO Fix bug in test_no_answer_output when using
|
|
|
|
# @pytest.fixture(params=["farm", "transformers"])
|
2020-12-03 10:27:06 +01:00
|
|
|
@pytest.fixture(params=["farm"], scope="module")
def no_answer_reader(request):
    """
    Module-scoped fixture returning a reader built from "deepset/roberta-base-squad2".

    Only "farm" is listed in ``params``; the "transformers" branch is kept for when that parameter
    is enabled again.

    :param request: pytest request object; ``request.param`` selects the reader type.
    :return: a ``FARMReader`` (created with ``return_no_answer=True``) or a ``TransformersReader``;
             ``None`` for any other parameter value.
    """
    reader_type = request.param
    model_name = "deepset/roberta-base-squad2"

    if reader_type == "farm":
        farm_settings = {
            "model_name_or_path": model_name,
            "use_gpu": False,
            "top_k_per_sample": 5,
            "no_ans_boost": 0,
            "return_no_answer": True,
            "num_processes": 0,
        }
        return FARMReader(**farm_settings)

    if reader_type == "transformers":
        transformers_settings = {
            "model_name_or_path": model_name,
            "tokenizer": "deepset/roberta-base-squad2",
            "use_gpu": -1,
            "top_k_per_candidate": 5,
        }
        return TransformersReader(**transformers_settings)

    return None
|
2020-07-14 18:53:15 +02:00
|
|
|
|
|
|
|
|
2020-12-03 10:27:06 +01:00
|
|
|
@pytest.fixture(scope="module")
def prediction(reader, test_docs_xs):
    """
    Module-scoped fixture returning the reader's prediction for "Who lives in Berlin?" on the test docs.

    :param reader: reader fixture whose ``predict`` method is called.
    :param test_docs_xs: test documents; dict entries are converted with ``Document.from_dict`` first.
    :return: whatever ``reader.predict`` returns for ``top_k=5``.
    """
    documents = []
    for entry in test_docs_xs:
        if isinstance(entry, dict):
            entry = Document.from_dict(entry)
        documents.append(entry)

    return reader.predict(query="Who lives in Berlin?", documents=documents, top_k=5)
|
|
|
|
|
|
|
|
|
2020-12-03 10:27:06 +01:00
|
|
|
@pytest.fixture(scope="module")
def no_answer_prediction(no_answer_reader, test_docs_xs):
    """
    Module-scoped fixture returning the prediction for "What is the meaning of life?" on the test docs.

    :param no_answer_reader: reader fixture whose ``predict`` method is called.
    :param test_docs_xs: test documents; dict entries are converted with ``Document.from_dict`` first.
    :return: whatever ``no_answer_reader.predict`` returns for ``top_k=5``.
    """
    documents = []
    for entry in test_docs_xs:
        if isinstance(entry, dict):
            entry = Document.from_dict(entry)
        documents.append(entry)

    return no_answer_reader.predict(query="What is the meaning of life?", documents=documents, top_k=5)
|
|
|
|
|
|
|
|
|
2020-10-23 17:50:49 +02:00
|
|
|
@pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf"])
def retriever(request, document_store):
    """
    Parametrized fixture returning a retriever of the requested type on top of ``document_store``.

    :param request: pytest request object; ``request.param`` is the retriever type.
    :param document_store: document store fixture handed to the retriever.
    :return: the retriever created by ``get_retriever``.
    """
    retriever_type = request.param
    return get_retriever(retriever_type, document_store)
|
|
|
|
|
|
|
|
|
2020-10-23 17:50:49 +02:00
|
|
|
@pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf"])
def retriever_with_docs(request, document_store_with_docs):
    """
    Parametrized fixture returning a retriever of the requested type on top of ``document_store_with_docs``.

    :param request: pytest request object; ``request.param`` is the retriever type.
    :param document_store_with_docs: document store fixture handed to the retriever.
    :return: the retriever created by ``get_retriever``.
    """
    retriever_type = request.param
    return get_retriever(retriever_type, document_store_with_docs)
|
|
|
|
|
|
|
|
|
|
|
|
def get_retriever(retriever_type, document_store):
    """
    Create a retriever of the given type on top of ``document_store``.

    :param retriever_type: one of "dpr", "tfidf", "embedding", "retribert", "elasticsearch", "es_filter_only".
    :param document_store: document store passed to the retriever's constructor.
    :return: the newly created retriever (the tf-idf retriever is fitted before it is returned).
    :raises Exception: if ``retriever_type`` is not one of the supported values.
    """
    if retriever_type == "dpr":
        return DensePassageRetriever(
            document_store=document_store,
            query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
            passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
            use_gpu=False,
            embed_title=True,
        )

    if retriever_type == "tfidf":
        tfidf_retriever = TfidfRetriever(document_store=document_store)
        tfidf_retriever.fit()
        return tfidf_retriever

    if retriever_type == "embedding":
        return EmbeddingRetriever(
            document_store=document_store,
            embedding_model="deepset/sentence_bert",
            use_gpu=False,
        )

    if retriever_type == "retribert":
        return EmbeddingRetriever(
            document_store=document_store,
            embedding_model="yjernite/retribert-base-uncased",
            model_format="retribert",
            use_gpu=False,
        )

    if retriever_type == "elasticsearch":
        return ElasticsearchRetriever(document_store=document_store)

    if retriever_type == "es_filter_only":
        return ElasticsearchFilterOnlyRetriever(document_store=document_store)

    raise Exception(f"No retriever fixture for '{retriever_type}'")
|
2020-12-14 18:15:44 +01:00
|
|
|
|
|
|
|
|
2021-01-29 13:29:12 +01:00
|
|
|
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
def document_store_with_docs(request, test_docs_xs):
    """
    Parametrized fixture yielding a document store that already contains the ``test_docs_xs`` documents.

    After the test has finished, ``delete_documents()`` is called on the store.

    :param request: pytest request object; ``request.param`` is the document store type.
    :param test_docs_xs: documents written to the store before it is yielded.
    """
    store = get_document_store(request.param)
    store.write_documents(test_docs_xs)

    yield store

    # teardown
    store.delete_documents()
|
2020-12-14 18:15:44 +01:00
|
|
|
|
2021-06-22 16:08:23 +02:00
|
|
|
|
|
|
|
@pytest.fixture
def document_store(request, test_docs_xs):
    """
    Fixture yielding an empty document store of the type given by ``request.param``.

    The embedding dimension is taken from the closest ``vector_dim`` marker of the test and falls back
    to 768. After the test has finished, ``delete_documents()`` is called on the store.

    :param request: pytest request object providing ``param`` and ``node``.
    :param test_docs_xs: requested as a fixture but not used in the body.
    """
    fallback_marker = pytest.mark.vector_dim(768)
    dim_marker = request.node.get_closest_marker("vector_dim", fallback_marker)
    store = get_document_store(request.param, dim_marker.args[0])

    yield store

    # teardown
    store.delete_documents()
|
2020-12-14 18:15:44 +01:00
|
|
|
|
|
|
|
|
2021-06-14 17:53:43 +02:00
|
|
|
def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding"):
    """
    Create a document store of the given type for the tests, using the "haystack_test" index
    ("Haystacktest" for Weaviate).

    :param document_store_type: one of "sql", "memory", "elasticsearch", "faiss", "milvus", "weaviate".
    :param embedding_dim: embedding dimension handed to the store (not passed on in the "sql" and
                          "weaviate" branches).
    :param embedding_field: name of the embedding field handed to the store (not passed on in the "sql"
                            and "weaviate" branches).
    :return: the newly created document store.
    :raises Exception: if ``document_store_type`` is not one of the supported values.
    """
    if document_store_type == "sql":
        return SQLDocumentStore(url="sqlite://", index="haystack_test")

    if document_store_type == "memory":
        return InMemoryDocumentStore(
            return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index="haystack_test"
        )

    if document_store_type == "elasticsearch":
        # make sure we start from a fresh index
        es_client = Elasticsearch()
        es_client.indices.delete(index='haystack_test*', ignore=[404])
        return ElasticsearchDocumentStore(
            index="haystack_test", return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field
        )

    if document_store_type == "faiss":
        return FAISSDocumentStore(
            vector_dim=embedding_dim,
            sql_url="sqlite://",
            return_embedding=True,
            embedding_field=embedding_field,
            index="haystack_test",
        )

    if document_store_type == "milvus":
        milvus_store = MilvusDocumentStore(
            vector_dim=embedding_dim,
            sql_url="sqlite://",
            return_embedding=True,
            embedding_field=embedding_field,
            index="haystack_test",
        )
        # drop every collection whose name starts with the test prefix
        _, collection_names = milvus_store.milvus_server.list_collections()
        for collection_name in collection_names:
            if collection_name.startswith("haystack_test"):
                milvus_store.milvus_server.drop_collection(collection_name)
        return milvus_store

    if document_store_type == "weaviate":
        weaviate_store = WeaviateDocumentStore(
            weaviate_url="http://localhost:8080",
            index="Haystacktest"
        )
        # wipe the whole schema, then recreate schema and index for the test run
        weaviate_store.weaviate_client.schema.delete_all()
        weaviate_store._create_schema_and_index_if_not_exist()
        return weaviate_store

    raise Exception(f"No document store fixture for '{document_store_type}'")
|