From 02fc4c7783369f2db76d0ce5fa9fae69a13989eb Mon Sep 17 00:00:00 2001 From: vblagoje Date: Tue, 22 Jun 2021 16:08:23 +0200 Subject: [PATCH] Improve document stores unit test parametrization (#1202) --- test/conftest.py | 26 +++++++++++++++++++++++++- test/test_document_store.py | 4 ---- test/test_generator.py | 1 - test/test_retriever.py | 2 -- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index e4b25d3cb..3d7fa558e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -32,6 +32,29 @@ from haystack.summarizer.transformers import TransformersSummarizer from haystack.translator import TransformersTranslator +def pytest_addoption(parser): + parser.addoption("--document_store_type", action="store", default="all") + + +def pytest_generate_tests(metafunc): + # parametrize document_store fixture if it's in the test function argument list + # but does not have an explicit parametrize annotation e.g + # @pytest.mark.parametrize("document_store", ["memory"], indirect=False) + found_mark_parametrize_document_store = False + for marker in metafunc.definition.iter_markers('parametrize'): + if 'document_store' in marker.args[0]: + found_mark_parametrize_document_store = True + break + + if 'document_store' in metafunc.fixturenames and not found_mark_parametrize_document_store: + document_store_type = metafunc.config.option.document_store_type + if "all" in document_store_type: + document_store_type = "elasticsearch, faiss, memory, milvus" + + document_store_types = [item.strip() for item in document_store_type.split(",")] + metafunc.parametrize("document_store", document_store_types, indirect=True) + + def _sql_session_rollback(self, attr): """ Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch @@ -348,7 +371,8 @@ def document_store_with_docs(request, test_docs_xs): yield document_store document_store.delete_all_documents() -@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"]) + +@pytest.fixture def document_store(request, test_docs_xs): vector_dim = request.node.get_closest_marker("vector_dim", pytest.mark.vector_dim(768)) document_store = get_document_store(request.param, vector_dim.args[0]) diff --git a/test/test_document_store.py b/test/test_document_store.py index b30e3bbc3..dbada17cb 100644 --- a/test/test_document_store.py +++ b/test/test_document_store.py @@ -32,7 +32,6 @@ def test_init_elastic_client(): @pytest.mark.elasticsearch -@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "sql", "milvus"], indirect=True) def test_write_with_duplicate_doc_ids(document_store): documents = [ Document( @@ -158,7 +157,6 @@ def test_get_all_documents_generator(document_store): @pytest.mark.elasticsearch -@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss", "milvus"], indirect=True) @pytest.mark.parametrize("update_existing_documents", [True, False]) def test_update_existing_documents(document_store, update_existing_documents): original_docs = [ @@ -221,7 +219,6 @@ def test_write_document_index(document_store): @pytest.mark.elasticsearch -@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True) def test_document_with_embeddings(document_store): documents = [ {"text": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)}, @@ -240,7 +237,6 @@ def test_document_with_embeddings(document_store): @pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True) -@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True) def test_update_embeddings(document_store, retriever): documents = [] for i in range(6): diff --git a/test/test_generator.py b/test/test_generator.py index 085e52703..a643f0a02 100644 --- a/test/test_generator.py +++ b/test/test_generator.py @@ -435,7 +435,6 @@ def test_generator_pipeline(document_store, retriever, rag_generator): @pytest.mark.slow @pytest.mark.generator @pytest.mark.elasticsearch -@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True) @pytest.mark.parametrize("retriever", ["retribert"], indirect=True) @pytest.mark.vector_dim(128) def test_lfqa_pipeline(document_store, retriever, eli5_generator): diff --git a/test/test_retriever.py b/test/test_retriever.py index 539dad2c7..3c11997bc 100644 --- a/test/test_retriever.py +++ b/test/test_retriever.py @@ -140,7 +140,6 @@ def test_elasticsearch_custom_query(elasticsearch_fixture): @pytest.mark.slow @pytest.mark.elasticsearch -@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True) @pytest.mark.parametrize("retriever", ["dpr"], indirect=True) def test_dpr_embedding(document_store, retriever): @@ -164,7 +163,6 @@ def test_dpr_embedding(document_store, retriever): @pytest.mark.slow @pytest.mark.elasticsearch -@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True) @pytest.mark.parametrize("retriever", ["retribert"], indirect=True) @pytest.mark.vector_dim(128) def test_retribert_embedding(document_store, retriever):