From 3e4dbbb32cfa63a9e0be0f87b0a6ad59cefa0c8f Mon Sep 17 00:00:00 2001 From: MichelBartels Date: Wed, 12 Jan 2022 19:28:20 +0100 Subject: [PATCH] Align similarity scores across document stores (#1967) * align document store similarity functions * remove unnecessary imports * undone accidental change * stopped weaviate from pretending to support dot product similarity * stopped weaviate from pretending to support dot product similarity * Add latest docstring and tutorial changes * fix fixture params for document stores * use cosine similarity for most tests * fix cosine similarity test * fix faiss test * fix weaviate test * fix accidental deletion * fix document_store fixture * test fix; shouldn't be merged * fix test_normalize_embeddings_diff_shapes * probably a better fix * fix for parameter combinations * revert new pytest_generate_tests functionality * simplify pytest_generate_tests * normalize embeddings for test_dpr_embedding * add to faiss doc that embeddings are normalized * Add latest docstring and tutorial changes * remove unnecessary parameters and add comments * simplify two lines of memory.py into one * test similarity scores with smaller language model * fix test_similarity_score Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/_src/api/api/document_store.md | 11 +++++---- haystack/document_stores/faiss.py | 6 ++--- haystack/document_stores/memory.py | 6 ++--- haystack/document_stores/weaviate.py | 7 ++++-- test/conftest.py | 34 +++++++++++++++++++------- test/test_document_store.py | 24 +++++++++++++++++-- test/test_faiss_and_milvus.py | 36 +++++++++++++++------------- test/test_retriever.py | 28 +++++++++++++--------- test/test_standard_pipelines.py | 10 ++++---- 9 files changed, 104 insertions(+), 58 deletions(-) diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md index e8daaf6d3..7ad468855 100644 --- a/docs/_src/api/api/document_store.md +++ b/docs/_src/api/api/document_store.md @@ -1228,7 +1228,7 @@ the vector embeddings are indexed in a FAISS Index. Benchmarks: XXX - `faiss_index`: Pass an existing FAISS Index, i.e. an empty one that you configured manually or one with docs that you used in Haystack before and want to load again. -- `return_embedding`: To return document embedding +- `return_embedding`: To return document embedding. Unlike other document stores, FAISS will return normalized embeddings - `index`: Name of index in document store to use. - `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default since it is more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model. @@ -1322,7 +1322,7 @@ a large number of documents without having to load all documents in memory. DocumentStore's default index (self.index) will be used. - `filters`: Optional filters to narrow down the documents to return. Example: {"name": ["some", "more"], "category": ["only_one"]} -- `return_embedding`: Whether to return the document embeddings. +- `return_embedding`: Whether to return the document embeddings. Unlike other document stores, FAISS will return normalized embeddings - `batch_size`: When working with large number of documents, batching can help reduce memory footprint. @@ -1404,7 +1404,7 @@ Find the document that is most similar to the provided `query_emb` by using a ve Example: {"name": ["some", "more"], "category": ["only_one"]} - `top_k`: How many documents to return - `index`: Index name to query the document from. -- `return_embedding`: To return document embedding +- `return_embedding`: To return document embedding. Unlike other document stores, FAISS will return normalized embeddings **Returns**: @@ -1761,6 +1761,7 @@ Some of the key differences in contrast to FAISS & Milvus: 2. Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset 3. Has less variety of ANN algorithms, as of now only HNSW. 4. Requires document ids to be in uuid-format. If wrongly formatted ids are provided at indexing time they will be replaced with uuids automatically. +5. Only support cosine similarity. Weaviate python client is used to connect to the server, more details are here https://weaviate-python-client.readthedocs.io/en/docs/weaviate.html @@ -1776,7 +1777,7 @@ The current implementation is not supporting the storage of labels, so you canno #### \_\_init\_\_ ```python - | __init__(host: Union[str, List[str]] = "http://localhost", port: Union[int, List[int]] = 8080, timeout_config: tuple = (5, 15), username: str = None, password: str = None, index: str = "Document", embedding_dim: int = 768, content_field: str = "content", name_field: str = "name", similarity: str = "dot_product", index_type: str = "hnsw", custom_schema: Optional[dict] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,) + | __init__(host: Union[str, List[str]] = "http://localhost", port: Union[int, List[int]] = 8080, timeout_config: tuple = (5, 15), username: str = None, password: str = None, index: str = "Document", embedding_dim: int = 768, content_field: str = "content", name_field: str = "name", similarity: str = "cosine", index_type: str = "hnsw", custom_schema: Optional[dict] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,) ``` **Arguments**: @@ -1792,7 +1793,7 @@ The current implementation is not supporting the storage of labels, so you canno - `content_field`: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text"). If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned. - `name_field`: Name of field that contains the title of the the doc -- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default. +- `similarity`: The similarity function used to compare document vectors. 'cosine' is the only currently supported option and default. 'cosine' is recommended for Sentence Transformers. - `index_type`: Index type of any vector object defined in weaviate schema. The vector index type is pluggable. Currently, HSNW is only supported. diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index a462e4adc..b4fa1e80d 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -74,7 +74,7 @@ class FAISSDocumentStore(SQLDocumentStore): Benchmarks: XXX :param faiss_index: Pass an existing FAISS Index, i.e. an empty one that you configured manually or one with docs that you used in Haystack before and want to load again. - :param return_embedding: To return document embedding + :param return_embedding: To return document embedding. Unlike other document stores, FAISS will return normalized embeddings :param index: Name of index in document store to use. :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model. @@ -379,7 +379,7 @@ class FAISSDocumentStore(SQLDocumentStore): DocumentStore's default index (self.index) will be used. :param filters: Optional filters to narrow down the documents to return. Example: {"name": ["some", "more"], "category": ["only_one"]} - :param return_embedding: Whether to return the document embeddings. + :param return_embedding: Whether to return the document embeddings. Unlike other document stores, FAISS will return normalized embeddings :param batch_size: When working with large number of documents, batching can help reduce memory footprint. """ if headers: @@ -510,7 +510,7 @@ class FAISSDocumentStore(SQLDocumentStore): Example: {"name": ["some", "more"], "category": ["only_one"]} :param top_k: How many documents to return :param index: Index name to query the document from. - :param return_embedding: To return document embedding + :param return_embedding: To return document embedding. Unlike other document stores, FAISS will return normalized embeddings :return: """ if headers: diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py index e92f29064..bb9a4bf1d 100644 --- a/haystack/document_stores/memory.py +++ b/haystack/document_stores/memory.py @@ -210,13 +210,11 @@ class InMemoryDocumentStore(BaseDocumentStore): new_document.embedding = doc.embedding if return_embedding is True else None if self.similarity == "dot_product": - score = np.dot(query_emb, doc.embedding) / ( - np.linalg.norm(query_emb) * np.linalg.norm(doc.embedding) - ) + score = np.dot(query_emb, doc.embedding) elif self.similarity == "cosine": # cosine similarity score = 1 - cosine distance score = 1 - cosine(query_emb, doc.embedding) - new_document.score = (score + 1) / 2 + new_document.score = self.finalize_raw_score(score, self.similarity) candidate_docs.append(new_document) return sorted(candidate_docs, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)[0:top_k] diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py index e17ff98fb..b163d5c25 100644 --- a/haystack/document_stores/weaviate.py +++ b/haystack/document_stores/weaviate.py @@ -31,6 +31,7 @@ class WeaviateDocumentStore(BaseDocumentStore): 2. Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset 3. Has less variety of ANN algorithms, as of now only HNSW. 4. Requires document ids to be in uuid-format. If wrongly formatted ids are provided at indexing time they will be replaced with uuids automatically. + 5. Only support cosine similarity. Weaviate python client is used to connect to the server, more details are here https://weaviate-python-client.readthedocs.io/en/docs/weaviate.html @@ -53,7 +54,7 @@ class WeaviateDocumentStore(BaseDocumentStore): embedding_dim: int = 768, content_field: str = "content", name_field: str = "name", - similarity: str = "dot_product", + similarity: str = "cosine", index_type: str = "hnsw", custom_schema: Optional[dict] = None, return_embedding: bool = False, @@ -74,7 +75,7 @@ class WeaviateDocumentStore(BaseDocumentStore): :param content_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text"). If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned. :param name_field: Name of field that contains the title of the the doc - :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default. + :param similarity: The similarity function used to compare document vectors. 'cosine' is the only currently supported option and default. 'cosine' is recommended for Sentence Transformers. :param index_type: Index type of any vector object defined in weaviate schema. The vector index type is pluggable. Currently, HSNW is only supported. @@ -93,6 +94,8 @@ class WeaviateDocumentStore(BaseDocumentStore): overwrite: Update any existing documents with the same ID when adding documents. fail: an error is raised if the document ID of the document being added already exists. """ + if similarity != "cosine": + raise ValueError(f"Weaviate only supports cosine similarity, but you provided {similarity}") # save init parameters to enable export of component config as YAML self.set_config( host=host, port=port, timeout_config=timeout_config, username=username, password=password, diff --git a/test/conftest.py b/test/conftest.py index b95c5cae7..67e9c2b35 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -54,7 +54,7 @@ def pytest_generate_tests(metafunc): # @pytest.mark.parametrize("document_store", ["memory"], indirect=False) found_mark_parametrize_document_store = False for marker in metafunc.definition.iter_markers('parametrize'): - if 'document_store' in marker.args[0] or 'document_store_with_docs' in marker.args[0] or 'document_store_type' in marker.args[0]: + if 'document_store' in marker.args[0]: found_mark_parametrize_document_store = True break # for all others that don't have explicit parametrization, we add the ones from the CLI arg @@ -479,34 +479,50 @@ def get_retriever(retriever_type, document_store): @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"]) def document_store_with_docs(request, test_docs_xs): - document_store = get_document_store(request.param) + embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768)) + document_store = get_document_store(request.param, embedding_dim.args[0]) document_store.write_documents(test_docs_xs) yield document_store document_store.delete_documents() @pytest.fixture -def document_store(request, test_docs_xs): +def document_store(request): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768)) document_store = get_document_store(request.param, embedding_dim.args[0]) yield document_store document_store.delete_documents() -@pytest.fixture(params=["faiss", "milvus", "weaviate"]) -def document_store_cosine(request, test_docs_xs): +@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"]) +def document_store_dot_product(request): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768)) - document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine") + document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product") + yield document_store + document_store.delete_documents() + +@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"]) +def document_store_dot_product_with_docs(request, test_docs_xs): + embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768)) + document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product") + document_store.write_documents(test_docs_xs) + yield document_store + document_store.delete_documents() + +@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus"]) +def document_store_dot_product_small(request): + embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3)) + document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product") yield document_store document_store.delete_documents() @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"]) -def document_store_cosine_small(request, test_docs_xs): +def document_store_small(request): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3)) document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine") yield document_store - document_store.delete_documents() + document_store.delete_documents() -def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="dot_product"): +def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate if document_store_type == "sql": document_store = SQLDocumentStore(url="sqlite://", index=index) elif document_store_type == "memory": diff --git a/test/test_document_store.py b/test/test_document_store.py index 39bc5caf5..a39859e10 100644 --- a/test/test_document_store.py +++ b/test/test_document_store.py @@ -14,7 +14,8 @@ from haystack.errors import DuplicateDocumentError from haystack.schema import Document, Label, Answer, Span from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore from haystack.document_stores.faiss import FAISSDocumentStore - +from haystack.nodes import EmbeddingRetriever +from haystack.pipelines import DocumentSearchPipeline @pytest.mark.elasticsearch def test_init_elastic_client(): @@ -954,8 +955,27 @@ def test_elasticsearch_synonyms(): indexed_settings = client.indices.get_settings(index="haystack_synonym_arg") assert synonym_type == indexed_settings['haystack_synonym_arg']['settings']['index']['analysis']['filter']['synonym']['type'] - assert synonyms == indexed_settings['haystack_synonym_arg']['settings']['index']['analysis']['filter']['synonym']['synonyms'] + assert synonyms == indexed_settings['haystack_synonym_arg']['settings']['index']['analysis']['filter']['synonym']['synonyms'] +@pytest.mark.parametrize("document_store_with_docs", ["memory", "faiss", "milvus", "weaviate", "elasticsearch"], indirect=True) +@pytest.mark.embedding_dim(384) +def test_similarity_score(document_store_with_docs): + retriever = EmbeddingRetriever(document_store=document_store_with_docs, embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2") + document_store_with_docs.update_embeddings(retriever) + pipeline = DocumentSearchPipeline(retriever) + prediction = pipeline.run("Paul lives in New York") + scores = [document.score for document in prediction["documents"]] + assert scores == pytest.approx([0.9102500000000191, 0.6491700000000264, 0.6321699999999737], abs=1e-3) + +@pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory", "faiss", "milvus", "elasticsearch"], indirect=True) +@pytest.mark.embedding_dim(384) +def test_similarity_score_dot_product(document_store_dot_product_with_docs): + retriever = EmbeddingRetriever(document_store=document_store_dot_product_with_docs, embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2") + document_store_dot_product_with_docs.update_embeddings(retriever) + pipeline = DocumentSearchPipeline(retriever) + prediction = pipeline.run("Paul lives in New York") + scores = [document.score for document in prediction["documents"]] + assert scores == pytest.approx([0.5526493562767626, 0.5189836204008691, 0.5179697571274173], abs=1e-3) def test_custom_headers(document_store_with_docs: BaseDocumentStore): mock_client = None diff --git a/test/test_faiss_and_milvus.py b/test/test_faiss_and_milvus.py index b2ca1c7c2..8e19dcfde 100644 --- a/test/test_faiss_and_milvus.py +++ b/test/test_faiss_and_milvus.py @@ -157,7 +157,8 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size): original_doc = [d for d in DOCUMENTS if d["content"] == doc.content][0] stored_emb = document_store.faiss_indexes[document_store.index].reconstruct(int(doc.meta["vector_id"])) # compare original input vec with stored one (ignore extra dim added by hnsw) - assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01) + # original input vec is normalized as faiss only stores normalized vectors + assert np.allclose(original_doc["embedding"] / np.linalg.norm(original_doc["embedding"]), stored_emb, rtol=0.01) @pytest.mark.slow @@ -178,7 +179,8 @@ def test_update_docs(document_store, retriever, batch_size): updated_embedding = retriever.embed_documents([Document.from_dict(original_doc)]) stored_doc = document_store.get_all_documents(filters={"name": [doc.meta["name"]]})[0] # compare original input vec with stored one (ignore extra dim added by hnsw) - assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01) + # original input vec is normalized as faiss only stores normalized vectors + assert np.allclose(updated_embedding / np.linalg.norm(updated_embedding), stored_doc.embedding, rtol=0.01) @pytest.mark.slow @@ -441,16 +443,16 @@ def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None: for d in docs: d["id"] = str(uuid.uuid4()) -def test_cosine_similarity(document_store_cosine): +def test_cosine_similarity(document_store): # below we will write documents to the store and then query it to see if vectors were normalized - ensure_ids_are_correct_uuids(docs=DOCUMENTS,document_store=document_store_cosine) - document_store_cosine.write_documents(documents=DOCUMENTS) + ensure_ids_are_correct_uuids(docs=DOCUMENTS,document_store=document_store) + document_store.write_documents(documents=DOCUMENTS) # note that the same query will be used later when querying after updating the embeddings query = np.random.rand(768).astype(np.float32) - query_results = document_store_cosine.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True) + query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True) # check if search with cosine similarity returns the correct number of results assert len(query_results) == len(DOCUMENTS) @@ -461,7 +463,7 @@ def test_cosine_similarity(document_store_cosine): for doc in query_results: result_emb = doc.embedding original_emb = np.array([indexed_docs[doc.content]], dtype="float32") - document_store_cosine.normalize_embedding(original_emb[0]) + document_store.normalize_embedding(original_emb[0]) # check if the stored embedding was normalized assert np.allclose(original_emb[0], result_emb, rtol=0.01) @@ -475,27 +477,27 @@ def test_cosine_similarity(document_store_cosine): return [np.random.rand(768).astype(np.float32) for doc in docs] retriever = MockRetriever() - document_store_cosine.update_embeddings(retriever=retriever) - query_results = document_store_cosine.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True) + document_store.update_embeddings(retriever=retriever) + query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True) for doc in query_results: original_emb = np.array([indexed_docs[doc.content]], dtype="float32") - document_store_cosine.normalize_embedding(original_emb[0]) + document_store.normalize_embedding(original_emb[0]) # check if the original embedding has changed after updating the embeddings assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01) -def test_normalize_embeddings_diff_shapes(document_store_cosine_small): +def test_normalize_embeddings_diff_shapes(document_store_dot_product_small): VEC_1 = np.array([.1, .2, .3], dtype="float32") - document_store_cosine_small.normalize_embedding(VEC_1) + document_store_dot_product_small.normalize_embedding(VEC_1) assert np.linalg.norm(VEC_1) - 1 < 0.01 VEC_1 = np.array([.1, .2, .3], dtype="float32").reshape(1, -1) - document_store_cosine_small.normalize_embedding(VEC_1) + document_store_dot_product_small.normalize_embedding(VEC_1) assert np.linalg.norm(VEC_1) - 1 < 0.01 -def test_cosine_sanity_check(document_store_cosine_small): +def test_cosine_sanity_check(document_store_small): VEC_1 = np.array([.1, .2, .3], dtype="float32") VEC_2 = np.array([.4, .5, .6], dtype="float32") @@ -504,10 +506,10 @@ def test_cosine_sanity_check(document_store_cosine_small): KNOWN_COSINE = (0.9746317 + 1) / 2 docs = [{"name": "vec_1", "text": "vec_1", "content": "vec_1", "embedding": VEC_1}] - ensure_ids_are_correct_uuids(docs=docs,document_store=document_store_cosine_small) - document_store_cosine_small.write_documents(documents=docs) + ensure_ids_are_correct_uuids(docs=docs,document_store=document_store_small) + document_store_small.write_documents(documents=docs) - query_results = document_store_cosine_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True) + query_results = document_store_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True) # check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318 assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002) diff --git a/test/test_retriever.py b/test/test_retriever.py index 24eba04b2..8b479e65d 100644 --- a/test/test_retriever.py +++ b/test/test_retriever.py @@ -154,17 +154,23 @@ def test_dpr_embedding(document_store, retriever, docs): document_store.update_embeddings(retriever=retriever) time.sleep(1) - doc_1 = document_store.get_document_by_id("1") - assert len(doc_1.embedding) == 768 - assert abs(doc_1.embedding[0] - (-0.3063)) < 0.001 - doc_2 = document_store.get_document_by_id("2") - assert abs(doc_2.embedding[0] - (-0.3914)) < 0.001 - doc_3 = document_store.get_document_by_id("3") - assert abs(doc_3.embedding[0] - (-0.2470)) < 0.001 - doc_4 = document_store.get_document_by_id("4") - assert abs(doc_4.embedding[0] - (-0.0802)) < 0.001 - doc_5 = document_store.get_document_by_id("5") - assert abs(doc_5.embedding[0] - (-0.0551)) < 0.001 + # always normalize vector as faiss returns normalized vectors and other document stores do not + doc_1 = document_store.get_document_by_id("1").embedding + doc_1 /= np.linalg.norm(doc_1) + assert len(doc_1) == 768 + assert abs(doc_1[0] - (-0.0250)) < 0.001 + doc_2 = document_store.get_document_by_id("2").embedding + doc_2 /= np.linalg.norm(doc_2) + assert abs(doc_2[0] - (-0.0314)) < 0.001 + doc_3 = document_store.get_document_by_id("3").embedding + doc_3 /= np.linalg.norm(doc_3) + assert abs(doc_3[0] - (-0.0200)) < 0.001 + doc_4 = document_store.get_document_by_id("4").embedding + doc_4 /= np.linalg.norm(doc_4) + assert abs(doc_4[0] - (-0.0070)) < 0.001 + doc_5 = document_store.get_document_by_id("5").embedding + doc_5 /= np.linalg.norm(doc_5) + assert abs(doc_5[0] - (-0.0049)) < 0.001 @pytest.mark.slow diff --git a/test/test_standard_pipelines.py b/test/test_standard_pipelines.py index 75f91cf52..3d05059a6 100644 --- a/test/test_standard_pipelines.py +++ b/test/test_standard_pipelines.py @@ -125,17 +125,17 @@ def test_most_similar_documents_pipeline(retriever, document_store): @pytest.mark.elasticsearch -@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) +@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True) @pytest.mark.parametrize("reader", ["farm"], indirect=True) -def test_join_document_pipeline(document_store_with_docs, reader): - es = ElasticsearchRetriever(document_store=document_store_with_docs) +def test_join_document_pipeline(document_store_dot_product_with_docs, reader): + es = ElasticsearchRetriever(document_store=document_store_dot_product_with_docs) dpr = DensePassageRetriever( - document_store=document_store_with_docs, + document_store=document_store_dot_product_with_docs, query_embedding_model="facebook/dpr-question_encoder-single-nq-base", passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base", use_gpu=False, ) - document_store_with_docs.update_embeddings(dpr) + document_store_dot_product_with_docs.update_embeddings(dpr) query = "Where does Carla live?"