From 6ca88bfd23b2114fa2748843e6d4e27e7d8e4db0 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Mon, 9 Jan 2023 11:58:23 +0100 Subject: [PATCH] fix: Despite return_embedding=False SearchEngineDocumentStore.query retrieves embedding_field (#3662) * fix: Despite return_embedding=False SearchEngineDocumentStore.query retrieves embedding_field * fix pylint * add tests * fix mypy * fix merge * format * fix pylint * move tests to SearchEngineDocumentStoreTestAbstract * move missed constants * add mocked_document_store fixture to TestElasticsearchDocumentStore * fix mocked_document_store * fix get_all_documents tests for elasticsearch>=7.16 * fix tests * fix tests try 2 --- haystack/document_stores/elasticsearch.py | 32 ++----- haystack/document_stores/opensearch.py | 30 ++----- haystack/document_stores/search_engine.py | 73 +++++++++------- test/document_stores/test_elasticsearch.py | 16 ++++ test/document_stores/test_opensearch.py | 1 - test/document_stores/test_search_engine.py | 98 ++++++++++++++++++++++ 6 files changed, 172 insertions(+), 78 deletions(-) diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py index c716cb186..fd76c443e 100644 --- a/haystack/document_stores/elasticsearch.py +++ b/haystack/document_stores/elasticsearch.py @@ -1,6 +1,5 @@ import logging from typing import Dict, List, Optional, Type, Union -from copy import deepcopy import numpy as np @@ -377,9 +376,10 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore): if not self.embedding_field: raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()") - body = self._construct_dense_query_body(query_emb, filters, top_k, return_embedding) + body = self._construct_dense_query_body( + query_emb=query_emb, filters=filters, top_k=top_k, return_embedding=return_embedding + ) - logger.debug("Retriever query: %s", body) try: result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"] if len(result) == 0: @@ -399,19 +399,13 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore): raise e documents = [ - self._convert_es_hit_to_document( - hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score - ) + self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score) for hit in result ] return documents def _construct_dense_query_body( - self, - query_emb: np.ndarray, - filters: Optional[FilterType] = None, - top_k: int = 10, - return_embedding: Optional[bool] = None, + self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10 ): body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)} if filters: @@ -421,20 +415,10 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore): else: body["query"]["script_score"]["query"]["bool"]["filter"]["bool"]["must"].append(filter_) - excluded_meta_data: Optional[list] = None + excluded_fields = self._get_excluded_fields(return_embedding=return_embedding) + if excluded_fields: + body["_source"] = {"excludes": excluded_fields} - if self.excluded_meta_data: - excluded_meta_data = deepcopy(self.excluded_meta_data) - - if return_embedding is True and self.embedding_field in excluded_meta_data: - excluded_meta_data.remove(self.embedding_field) - elif return_embedding is False and self.embedding_field not in excluded_meta_data: - excluded_meta_data.append(self.embedding_field) - elif return_embedding is False: - excluded_meta_data = [self.embedding_field] - - if excluded_meta_data: - body["_source"] = {"excludes": excluded_meta_data} return body def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None): diff --git a/haystack/document_stores/opensearch.py b/haystack/document_stores/opensearch.py index 555c1533f..01e9e0a5f 100644 --- a/haystack/document_stores/opensearch.py +++ b/haystack/document_stores/opensearch.py @@ -1,7 +1,6 @@ from typing import List, Optional, Union, Dict, Any import logging -from copy import deepcopy import numpy as np from tqdm.auto import tqdm @@ -445,19 +444,13 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"] documents = [ - self._convert_es_hit_to_document( - hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score - ) + self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score) for hit in result ] return documents def _construct_dense_query_body( - self, - query_emb: np.ndarray, - filters: Optional[FilterType] = None, - top_k: int = 10, - return_embedding: Optional[bool] = None, + self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10 ): body: Dict[str, Any] = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)} if filters: @@ -468,19 +461,10 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): else: body["query"]["bool"]["filter"] = filter_ - excluded_meta_data: Optional[list] = None - if self.excluded_meta_data: - excluded_meta_data = deepcopy(self.excluded_meta_data) + excluded_fields = self._get_excluded_fields(return_embedding=return_embedding) + if excluded_fields: + body["_source"] = {"excludes": excluded_fields} - if return_embedding is True and self.embedding_field in excluded_meta_data: - excluded_meta_data.remove(self.embedding_field) - elif return_embedding is False and self.embedding_field not in excluded_meta_data: - excluded_meta_data.append(self.embedding_field) - elif return_embedding is False: - excluded_meta_data = [self.embedding_field] - - if excluded_meta_data: - body["_source"] = {"excludes": excluded_meta_data} return body def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None): @@ -842,9 +826,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): opensearch_logger.setLevel(logging.CRITICAL) with tqdm(total=document_count, position=0, unit=" Docs", desc="Cloning embeddings") as progress_bar: for result_batch in get_batches_from_generator(result, batch_size): - document_batch = [ - self._convert_es_hit_to_document(hit, return_embedding=True) for hit in result_batch - ] + document_batch = [self._convert_es_hit_to_document(hit) for hit in result_batch] doc_updates = [] for doc in document_batch: if doc.embedding is not None: diff --git a/haystack/document_stores/search_engine.py b/haystack/document_stores/search_engine.py index 10bd1638b..e8866275e 100644 --- a/haystack/document_stores/search_engine.py +++ b/haystack/document_stores/search_engine.py @@ -1,5 +1,7 @@ # pylint: disable=too-many-public-methods + +from copy import deepcopy from typing import List, Optional, Union, Dict, Any, Generator from abc import abstractmethod import json @@ -279,10 +281,10 @@ class SearchEngineDocumentStore(KeywordDocumentStore): for i in range(0, len(ids), batch_size): ids_for_batch = ids[i : i + batch_size] query = {"size": len(ids_for_batch), "query": {"ids": {"values": ids_for_batch}}} + if not self.return_embedding and self.embedding_field: + query["_source"] = {"excludes": [self.embedding_field]} result = self.client.search(index=index, body=query, headers=headers)["hits"]["hits"] - documents.extend( - [self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding) for hit in result] - ) + documents.extend([self._convert_es_hit_to_document(hit) for hit in result]) return documents def get_metadata_values_by_key( @@ -675,9 +677,15 @@ class SearchEngineDocumentStore(KeywordDocumentStore): if return_embedding is None: return_embedding = self.return_embedding - result = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size, headers=headers) + excludes = None + if not return_embedding and self.embedding_field: + excludes = [self.embedding_field] + + result = self._get_all_documents_in_index( + index=index, filters=filters, batch_size=batch_size, headers=headers, excludes=excludes + ) for hit in result: - document = self._convert_es_hit_to_document(hit, return_embedding=return_embedding) + document = self._convert_es_hit_to_document(hit) yield document def get_all_labels( @@ -709,6 +717,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): batch_size: int = 10_000, only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None, + excludes: Optional[List[str]] = None, ) -> Generator[dict, None, None]: """ Return all documents in a specific index in the document store @@ -721,6 +730,9 @@ class SearchEngineDocumentStore(KeywordDocumentStore): if only_documents_without_embedding: body["query"]["bool"]["must_not"] = [{"exists": {"field": self.embedding_field}}] + if excludes: + body["_source"] = {"excludes": excludes} + result = self._do_scan( self.client, query=body, index=index, size=batch_size, scroll=self.scroll, headers=headers ) @@ -899,10 +911,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): result = self.client.search(index=index, body=body, headers=headers)["hits"]["hits"] - documents = [ - self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding, scale_score=scale_score) - for hit in result - ] + documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in result] return documents def query_batch( @@ -1036,10 +1045,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): cur_documents = [] for response in responses["responses"]: cur_result = response["hits"]["hits"] - cur_documents = [ - self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding, scale_score=scale_score) - for hit in cur_result - ] + cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in cur_result] all_documents.append(cur_documents) return all_documents @@ -1105,13 +1111,28 @@ class SearchEngineDocumentStore(KeywordDocumentStore): if filters: body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch() - if self.excluded_meta_data: - body["_source"] = {"excludes": self.excluded_meta_data} + excluded_fields = self._get_excluded_fields(return_embedding=self.return_embedding) + if excluded_fields: + body["_source"] = {"excludes": excluded_fields} return body + def _get_excluded_fields(self, return_embedding: bool) -> Optional[List[str]]: + excluded_meta_data: Optional[list] = None + + if self.excluded_meta_data: + excluded_meta_data = deepcopy(self.excluded_meta_data) + + if return_embedding is True and self.embedding_field in excluded_meta_data: + excluded_meta_data.remove(self.embedding_field) + elif return_embedding is False and self.embedding_field not in excluded_meta_data: + excluded_meta_data.append(self.embedding_field) + elif return_embedding is False: + excluded_meta_data = [self.embedding_field] + return excluded_meta_data + def _convert_es_hit_to_document( - self, hit: dict, return_embedding: bool, adapt_score_for_embedding: bool = False, scale_score: bool = True + self, hit: dict, adapt_score_for_embedding: bool = False, scale_score: bool = True ) -> Document: # We put all additional data of the doc into meta_data and return it in the API try: @@ -1139,10 +1160,9 @@ class SearchEngineDocumentStore(KeywordDocumentStore): score = float(expit(np.asarray(score / 8))) # scaling probability from TFIDF/BM25 embedding = None - if return_embedding: - embedding_list = hit["_source"].get(self.embedding_field) - if embedding_list: - embedding = np.asarray(embedding_list, dtype=np.float32) + embedding_list = hit["_source"].get(self.embedding_field) + if embedding_list: + embedding = np.asarray(embedding_list, dtype=np.float32) doc_dict = { "id": hit["_id"], @@ -1284,9 +1304,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): for response in responses["responses"]: cur_result = response["hits"]["hits"] cur_documents = [ - self._convert_es_hit_to_document( - hit, adapt_score_for_embedding=True, return_embedding=self.return_embedding, scale_score=scale_score - ) + self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score) for hit in cur_result ] all_documents.append(cur_documents) @@ -1295,11 +1313,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): @abstractmethod def _construct_dense_query_body( - self, - query_emb: np.ndarray, - filters: Optional[FilterType] = None, - top_k: int = 10, - return_embedding: Optional[bool] = None, + self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10 ): pass @@ -1381,13 +1395,14 @@ class SearchEngineDocumentStore(KeywordDocumentStore): batch_size=batch_size, only_documents_without_embedding=not update_existing_embeddings, headers=headers, + excludes=[self.embedding_field], ) logging.getLogger(__name__).setLevel(logging.CRITICAL) with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar: for result_batch in get_batches_from_generator(result, batch_size): - document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch] + document_batch = [self._convert_es_hit_to_document(hit) for hit in result_batch] embeddings = self._embed_documents(document_batch, retriever) doc_updates = [] diff --git a/test/document_stores/test_elasticsearch.py b/test/document_stores/test_elasticsearch.py index 3fa7d29b0..cca979eff 100644 --- a/test/document_stores/test_elasticsearch.py +++ b/test/document_stores/test_elasticsearch.py @@ -1,4 +1,5 @@ import os +from unittest.mock import MagicMock import pytest import numpy as np @@ -31,6 +32,21 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine ds.delete_index(self.index_name) ds.delete_index(labels_index_name) + @pytest.fixture + def mocked_document_store(self): + """ + The fixture provides an instance of a slightly customized + ElasticsearchDocumentStore equipped with a mocked client + """ + + class DSMock(ElasticsearchDocumentStore): + # We mock a subclass to avoid messing up the actual class object + pass + + DSMock._init_elastic_client = MagicMock() + DSMock.client = MagicMock() + return DSMock() + @pytest.mark.integration def test___init__(self): # defaults diff --git a/test/document_stores/test_opensearch.py b/test/document_stores/test_opensearch.py index a389fb052..53a1d97a5 100644 --- a/test/document_stores/test_opensearch.py +++ b/test/document_stores/test_opensearch.py @@ -26,7 +26,6 @@ from .test_search_engine import SearchEngineDocumentStoreTestAbstract class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDocumentStoreTestAbstract): # Constants - query_emb = np.random.random_sample(size=(2, 2)) index_name = __name__ diff --git a/test/document_stores/test_search_engine.py b/test/document_stores/test_search_engine.py index e8a89e684..a7a498e85 100644 --- a/test/document_stores/test_search_engine.py +++ b/test/document_stores/test_search_engine.py @@ -14,6 +14,9 @@ class SearchEngineDocumentStoreTestAbstract: because we want to run its methods only in subclasses. """ + # Constants + query = "test" + @pytest.mark.integration def test___do_bulk(self): pass @@ -46,6 +49,101 @@ class SearchEngineDocumentStoreTestAbstract: result = ds.get_metadata_values_by_key(key="year", query="Bar") assert result == [{"count": 3, "value": "2021"}] + @pytest.mark.unit + def test_query_return_embedding_true(self, mocked_document_store): + mocked_document_store.return_embedding = True + mocked_document_store.query(self.query) + # assert the resulting body is consistent with the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + assert "_source" not in kwargs["body"] + + @pytest.mark.unit + def test_query_return_embedding_false(self, mocked_document_store): + mocked_document_store.return_embedding = False + mocked_document_store.query(self.query) + # assert the resulting body is consistent with the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + assert kwargs["body"]["_source"] == {"excludes": ["embedding"]} + + @pytest.mark.unit + def test_query_excluded_meta_data_return_embedding_true(self, mocked_document_store): + mocked_document_store.return_embedding = True + mocked_document_store.excluded_meta_data = ["foo", "embedding"] + mocked_document_store.query(self.query) + _, kwargs = mocked_document_store.client.search.call_args + # we expect "embedding" was removed from the final query + assert kwargs["body"]["_source"] == {"excludes": ["foo"]} + + @pytest.mark.unit + def test_query_excluded_meta_data_return_embedding_false(self, mocked_document_store): + mocked_document_store.return_embedding = False + mocked_document_store.excluded_meta_data = ["foo"] + mocked_document_store.query(self.query) + # assert the resulting body is consistent with the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]} + + @pytest.mark.unit + def test_get_all_documents_return_embedding_true(self, mocked_document_store): + mocked_document_store.return_embedding = False + mocked_document_store.client.search.return_value = {} + mocked_document_store.get_all_documents(return_embedding=True) + # assert the resulting body is consistent with the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + # starting with elasticsearch client 7.16, scan() uses the query parameter instead of body, + # see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2 + body = kwargs.get("body", kwargs) + assert "_source" not in body + + @pytest.mark.unit + def test_get_all_documents_return_embedding_false(self, mocked_document_store): + mocked_document_store.return_embedding = True + mocked_document_store.client.search.return_value = {} + mocked_document_store.get_all_documents(return_embedding=False) + # assert the resulting body is consistent with the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + # starting with elasticsearch client 7.16, scan() uses the query parameter instead of body, + # see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2 + body = kwargs.get("body", kwargs) + assert body["_source"] == {"excludes": ["embedding"]} + + @pytest.mark.unit + def test_get_all_documents_excluded_meta_data_has_no_influence(self, mocked_document_store): + mocked_document_store.excluded_meta_data = ["foo"] + mocked_document_store.client.search.return_value = {} + mocked_document_store.get_all_documents(return_embedding=False) + # assert the resulting body is not affected by the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + # starting with elasticsearch client 7.16, scan() uses the query parameter instead of body, + # see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2 + body = kwargs.get("body", kwargs) + assert body["_source"] == {"excludes": ["embedding"]} + + @pytest.mark.unit + def test_get_document_by_id_return_embedding_true(self, mocked_document_store): + mocked_document_store.return_embedding = True + mocked_document_store.get_document_by_id("123") + # assert the resulting body is consistent with the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + assert "_source" not in kwargs["body"] + + @pytest.mark.unit + def test_get_document_by_id_return_embedding_false(self, mocked_document_store): + mocked_document_store.return_embedding = False + mocked_document_store.get_document_by_id("123") + # assert the resulting body is consistent with the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + assert kwargs["body"]["_source"] == {"excludes": ["embedding"]} + + @pytest.mark.unit + def test_get_document_by_id_excluded_meta_data_has_no_influence(self, mocked_document_store): + mocked_document_store.excluded_meta_data = ["foo"] + mocked_document_store.return_embedding = False + mocked_document_store.get_document_by_id("123") + # assert the resulting body is not affected by the `excluded_meta_data` value + _, kwargs = mocked_document_store.client.search.call_args + assert kwargs["body"]["_source"] == {"excludes": ["embedding"]} + @pytest.mark.document_store class TestSearchEngineDocumentStore: