Improve document stores unit test parametrization (#1202)

This commit is contained in:
vblagoje 2021-06-22 16:08:23 +02:00 committed by GitHub
parent a8f3601e6a
commit 02fc4c7783
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 8 deletions

View File

@ -32,6 +32,29 @@ from haystack.summarizer.transformers import TransformersSummarizer
from haystack.translator import TransformersTranslator
def pytest_addoption(parser):
    """Register the ``--document_store_type`` command line option.

    The value is read back in ``pytest_generate_tests`` through
    ``metafunc.config.option.document_store_type``.

    :param parser: the pytest option parser handed to this hook.
    """
    parser.addoption(
        "--document_store_type",
        action="store",
        default="all",
        # Without a help string the option is invisible in `pytest --help`.
        help=(
            "Comma separated list of document store types used to parametrize "
            "the 'document_store' fixture, e.g. 'memory, faiss'. "
            "Use 'all' (default) to run with every default document store type."
        ),
    )
def _parametrized_argnames(marker):
    """Return the list of argument names a ``parametrize`` marker targets.

    Argument names may be given positionally or via the ``argnames`` keyword,
    either as a comma separated string or as a list/tuple of names.
    """
    if marker.args:
        argnames = marker.args[0]
    else:
        argnames = marker.kwargs.get("argnames", ())
    if isinstance(argnames, str):
        return [name.strip() for name in argnames.split(",") if name.strip()]
    return list(argnames)


def pytest_generate_tests(metafunc):
    """Parametrize the ``document_store`` fixture from ``--document_store_type``.

    The fixture is parametrized (indirectly) only when it is in the test
    function's fixture names and the test does not carry its own explicit
    annotation, e.g.
    ``@pytest.mark.parametrize("document_store", ["memory"], indirect=True)``.

    The option value is a comma separated list of document store types.
    Empty entries are dropped and duplicates are removed (first occurrence
    wins). If any entry is ``all``, the list is replaced by
    ``elasticsearch, faiss, memory, milvus``.

    :param metafunc: the pytest ``Metafunc`` object for the test being collected.
    """
    if "document_store" not in metafunc.fixturenames:
        return

    # Compare whole argument names: a plain substring test would also match
    # marks that parametrize e.g. "document_store_with_docs".
    for marker in metafunc.definition.iter_markers("parametrize"):
        if "document_store" in _parametrized_argnames(marker):
            return

    raw_option = metafunc.config.option.document_store_type
    requested = [item.strip() for item in raw_option.split(",")]
    # Stray or trailing commas would otherwise yield an empty store type.
    requested = [item for item in requested if item]
    if "all" in requested:
        requested = ["elasticsearch", "faiss", "memory", "milvus"]

    # dict.fromkeys removes duplicates while keeping the requested order.
    document_store_types = list(dict.fromkeys(requested))
    metafunc.parametrize("document_store", document_store_types, indirect=True)
def _sql_session_rollback(self, attr):
"""
Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
@ -348,7 +371,8 @@ def document_store_with_docs(request, test_docs_xs):
yield document_store
document_store.delete_all_documents()
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
@pytest.fixture
def document_store(request, test_docs_xs):
vector_dim = request.node.get_closest_marker("vector_dim", pytest.mark.vector_dim(768))
document_store = get_document_store(request.param, vector_dim.args[0])

View File

@ -32,7 +32,6 @@ def test_init_elastic_client():
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "sql", "milvus"], indirect=True)
def test_write_with_duplicate_doc_ids(document_store):
documents = [
Document(
@ -158,7 +157,6 @@ def test_get_all_documents_generator(document_store):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss", "milvus"], indirect=True)
@pytest.mark.parametrize("update_existing_documents", [True, False])
def test_update_existing_documents(document_store, update_existing_documents):
original_docs = [
@ -221,7 +219,6 @@ def test_write_document_index(document_store):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_document_with_embeddings(document_store):
documents = [
{"text": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)},
@ -240,7 +237,6 @@ def test_document_with_embeddings(document_store):
@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_update_embeddings(document_store, retriever):
documents = []
for i in range(6):

View File

@ -435,7 +435,6 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
@pytest.mark.slow
@pytest.mark.generator
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.vector_dim(128)
def test_lfqa_pipeline(document_store, retriever, eli5_generator):

View File

@ -140,7 +140,6 @@ def test_elasticsearch_custom_query(elasticsearch_fixture):
@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
def test_dpr_embedding(document_store, retriever):
@ -164,7 +163,6 @@ def test_dpr_embedding(document_store, retriever):
@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.vector_dim(128)
def test_retribert_embedding(document_store, retriever):