fix: Despite return_embedding=False SearchEngineDocumentStore.query retrieves embedding_field (#3662)

* fix: Despite return_embedding=False SearchEngineDocumentStore.query retrieves embedding_field

* fix pylint

* add tests

* fix mypy

* fix merge

* format

* fix pylint

* move tests to SearchEngineDocumentStoreTestAbstract

* move missed constants

* add mocked_document_store fixture to TestElasticsearchDocumentStore

* fix mocked_document_store

* fix get_all_documents tests for elasticsearch>=7.16

* fix tests

* fix tests try 2
Author: tstadel, 2023-01-09 11:58:23 +01:00 (committed by GitHub)
parent 5b0b338175
commit 6ca88bfd23
6 changed files with 172 additions and 78 deletions
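
For context, a minimal sketch of what this change guarantees from the caller's side. The snippet is not part of the diff; the constructor arguments are assumptions based on the Haystack v1 API of the time:

from haystack.document_stores import ElasticsearchDocumentStore

# Illustration only: with return_embedding=False the store should no longer ask
# Elasticsearch for the embedding field at all, i.e. the request body built for
# query() now contains "_source": {"excludes": ["embedding"]} instead of fetching
# the (potentially large) vectors and discarding them client-side.
document_store = ElasticsearchDocumentStore(embedding_field="embedding", return_embedding=False)
docs = document_store.query(query="test")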

View File

@@ -1,6 +1,5 @@
import logging
from typing import Dict, List, Optional, Type, Union
from copy import deepcopy
import numpy as np
@@ -377,9 +376,10 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
if not self.embedding_field:
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
body = self._construct_dense_query_body(query_emb, filters, top_k, return_embedding)
body = self._construct_dense_query_body(
query_emb=query_emb, filters=filters, top_k=top_k, return_embedding=return_embedding
)
logger.debug("Retriever query: %s", body)
try:
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
if len(result) == 0:
@@ -399,19 +399,13 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
raise e
documents = [
self._convert_es_hit_to_document(
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
)
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
for hit in result
]
return documents
def _construct_dense_query_body(
self,
query_emb: np.ndarray,
filters: Optional[FilterType] = None,
top_k: int = 10,
return_embedding: Optional[bool] = None,
self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10
):
body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
if filters:
@@ -421,20 +415,10 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
else:
body["query"]["script_score"]["query"]["bool"]["filter"]["bool"]["must"].append(filter_)
excluded_meta_data: Optional[list] = None
excluded_fields = self._get_excluded_fields(return_embedding=return_embedding)
if excluded_fields:
body["_source"] = {"excludes": excluded_fields}
if self.excluded_meta_data:
excluded_meta_data = deepcopy(self.excluded_meta_data)
if return_embedding is True and self.embedding_field in excluded_meta_data:
excluded_meta_data.remove(self.embedding_field)
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
excluded_meta_data.append(self.embedding_field)
elif return_embedding is False:
excluded_meta_data = [self.embedding_field]
if excluded_meta_data:
body["_source"] = {"excludes": excluded_meta_data}
return body
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
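
For illustration (not part of the diff), the shape of the request body that _construct_dense_query_body now produces when return_embedding is False and no excluded_meta_data is configured; the vector-similarity clause is abbreviated with a placeholder:

body = {
    "size": 10,                               # top_k
    "query": {"script_score": "..."},         # placeholder for self._get_vector_similarity_query(query_emb, top_k)
    "_source": {"excludes": ["embedding"]},   # added because _get_excluded_fields(return_embedding=False) returns ["embedding"]
}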

View File

@@ -1,7 +1,6 @@
from typing import List, Optional, Union, Dict, Any
import logging
from copy import deepcopy
import numpy as np
from tqdm.auto import tqdm
@@ -445,19 +444,13 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
documents = [
self._convert_es_hit_to_document(
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
)
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
for hit in result
]
return documents
def _construct_dense_query_body(
self,
query_emb: np.ndarray,
filters: Optional[FilterType] = None,
top_k: int = 10,
return_embedding: Optional[bool] = None,
self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10
):
body: Dict[str, Any] = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
if filters:
@@ -468,19 +461,10 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
else:
body["query"]["bool"]["filter"] = filter_
excluded_meta_data: Optional[list] = None
if self.excluded_meta_data:
excluded_meta_data = deepcopy(self.excluded_meta_data)
excluded_fields = self._get_excluded_fields(return_embedding=return_embedding)
if excluded_fields:
body["_source"] = {"excludes": excluded_fields}
if return_embedding is True and self.embedding_field in excluded_meta_data:
excluded_meta_data.remove(self.embedding_field)
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
excluded_meta_data.append(self.embedding_field)
elif return_embedding is False:
excluded_meta_data = [self.embedding_field]
if excluded_meta_data:
body["_source"] = {"excludes": excluded_meta_data}
return body
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
@@ -842,9 +826,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
opensearch_logger.setLevel(logging.CRITICAL)
with tqdm(total=document_count, position=0, unit=" Docs", desc="Cloning embeddings") as progress_bar:
for result_batch in get_batches_from_generator(result, batch_size):
document_batch = [
self._convert_es_hit_to_document(hit, return_embedding=True) for hit in result_batch
]
document_batch = [self._convert_es_hit_to_document(hit) for hit in result_batch]
doc_updates = []
for doc in document_batch:
if doc.embedding is not None:

View File

@@ -1,5 +1,7 @@
# pylint: disable=too-many-public-methods
from copy import deepcopy
from typing import List, Optional, Union, Dict, Any, Generator
from abc import abstractmethod
import json
@@ -279,10 +281,10 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
for i in range(0, len(ids), batch_size):
ids_for_batch = ids[i : i + batch_size]
query = {"size": len(ids_for_batch), "query": {"ids": {"values": ids_for_batch}}}
if not self.return_embedding and self.embedding_field:
query["_source"] = {"excludes": [self.embedding_field]}
result = self.client.search(index=index, body=query, headers=headers)["hits"]["hits"]
documents.extend(
[self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding) for hit in result]
)
documents.extend([self._convert_es_hit_to_document(hit) for hit in result])
return documents
def get_metadata_values_by_key(
@@ -675,9 +677,15 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
if return_embedding is None:
return_embedding = self.return_embedding
result = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size, headers=headers)
excludes = None
if not return_embedding and self.embedding_field:
excludes = [self.embedding_field]
result = self._get_all_documents_in_index(
index=index, filters=filters, batch_size=batch_size, headers=headers, excludes=excludes
)
for hit in result:
document = self._convert_es_hit_to_document(hit, return_embedding=return_embedding)
document = self._convert_es_hit_to_document(hit)
yield document
def get_all_labels(
@@ -709,6 +717,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
batch_size: int = 10_000,
only_documents_without_embedding: bool = False,
headers: Optional[Dict[str, str]] = None,
excludes: Optional[List[str]] = None,
) -> Generator[dict, None, None]:
"""
Return all documents in a specific index in the document store
@@ -721,6 +730,9 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
if only_documents_without_embedding:
body["query"]["bool"]["must_not"] = [{"exists": {"field": self.embedding_field}}]
if excludes:
body["_source"] = {"excludes": excludes}
result = self._do_scan(
self.client, query=body, index=index, size=batch_size, scroll=self.scroll, headers=headers
)
@@ -899,10 +911,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
result = self.client.search(index=index, body=body, headers=headers)["hits"]["hits"]
documents = [
self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding, scale_score=scale_score)
for hit in result
]
documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in result]
return documents
def query_batch(
@@ -1036,10 +1045,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
cur_documents = []
for response in responses["responses"]:
cur_result = response["hits"]["hits"]
cur_documents = [
self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding, scale_score=scale_score)
for hit in cur_result
]
cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in cur_result]
all_documents.append(cur_documents)
return all_documents
@@ -1105,13 +1111,28 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
if filters:
body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
if self.excluded_meta_data:
body["_source"] = {"excludes": self.excluded_meta_data}
excluded_fields = self._get_excluded_fields(return_embedding=self.return_embedding)
if excluded_fields:
body["_source"] = {"excludes": excluded_fields}
return body
def _get_excluded_fields(self, return_embedding: bool) -> Optional[List[str]]:
excluded_meta_data: Optional[list] = None
if self.excluded_meta_data:
excluded_meta_data = deepcopy(self.excluded_meta_data)
if return_embedding is True and self.embedding_field in excluded_meta_data:
excluded_meta_data.remove(self.embedding_field)
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
excluded_meta_data.append(self.embedding_field)
elif return_embedding is False:
excluded_meta_data = [self.embedding_field]
return excluded_meta_data
def _convert_es_hit_to_document(
self, hit: dict, return_embedding: bool, adapt_score_for_embedding: bool = False, scale_score: bool = True
self, hit: dict, adapt_score_for_embedding: bool = False, scale_score: bool = True
) -> Document:
# We put all additional data of the doc into meta_data and return it in the API
try:
@@ -1139,10 +1160,9 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
score = float(expit(np.asarray(score / 8))) # scaling probability from TFIDF/BM25
embedding = None
if return_embedding:
embedding_list = hit["_source"].get(self.embedding_field)
if embedding_list:
embedding = np.asarray(embedding_list, dtype=np.float32)
embedding_list = hit["_source"].get(self.embedding_field)
if embedding_list:
embedding = np.asarray(embedding_list, dtype=np.float32)
doc_dict = {
"id": hit["_id"],
@@ -1284,9 +1304,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
for response in responses["responses"]:
cur_result = response["hits"]["hits"]
cur_documents = [
self._convert_es_hit_to_document(
hit, adapt_score_for_embedding=True, return_embedding=self.return_embedding, scale_score=scale_score
)
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
for hit in cur_result
]
all_documents.append(cur_documents)
@@ -1295,11 +1313,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
@abstractmethod
def _construct_dense_query_body(
self,
query_emb: np.ndarray,
filters: Optional[FilterType] = None,
top_k: int = 10,
return_embedding: Optional[bool] = None,
self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10
):
pass
@@ -1381,13 +1395,14 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
batch_size=batch_size,
only_documents_without_embedding=not update_existing_embeddings,
headers=headers,
excludes=[self.embedding_field],
)
logging.getLogger(__name__).setLevel(logging.CRITICAL)
with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
for result_batch in get_batches_from_generator(result, batch_size):
document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
document_batch = [self._convert_es_hit_to_document(hit) for hit in result_batch]
embeddings = self._embed_documents(document_batch, retriever)
doc_updates = []
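
To make the behavior of the new _get_excluded_fields helper easier to follow, here is a standalone re-implementation of its logic, for illustration only (assuming embedding_field="embedding"); the real method lives on SearchEngineDocumentStore and reads self.excluded_meta_data:

from copy import deepcopy
from typing import List, Optional

def get_excluded_fields(excluded_meta_data: Optional[list], return_embedding: bool,
                        embedding_field: str = "embedding") -> Optional[List[str]]:
    # Mirrors SearchEngineDocumentStore._get_excluded_fields from the diff above
    excluded = deepcopy(excluded_meta_data) if excluded_meta_data else None
    if excluded:
        if return_embedding and embedding_field in excluded:
            excluded.remove(embedding_field)      # allow the embedding to be returned
        elif not return_embedding and embedding_field not in excluded:
            excluded.append(embedding_field)      # make sure the embedding is not fetched
    elif not return_embedding:
        excluded = [embedding_field]
    return excluded

assert get_excluded_fields(None, True) is None
assert get_excluded_fields(None, False) == ["embedding"]
assert get_excluded_fields(["foo", "embedding"], True) == ["foo"]
assert get_excluded_fields(["foo"], False) == ["foo", "embedding"]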

View File

@@ -1,4 +1,5 @@
import os
from unittest.mock import MagicMock
import pytest
import numpy as np
@@ -31,6 +32,21 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
ds.delete_index(self.index_name)
ds.delete_index(labels_index_name)
@pytest.fixture
def mocked_document_store(self):
"""
The fixture provides an instance of a slightly customized
ElasticsearchDocumentStore equipped with a mocked client
"""
class DSMock(ElasticsearchDocumentStore):
# We mock a subclass to avoid messing up the actual class object
pass
DSMock._init_elastic_client = MagicMock()
DSMock.client = MagicMock()
return DSMock()
@pytest.mark.integration
def test___init__(self):
# defaults

View File

@@ -26,7 +26,6 @@ from .test_search_engine import SearchEngineDocumentStoreTestAbstract
class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDocumentStoreTestAbstract):
# Constants
query_emb = np.random.random_sample(size=(2, 2))
index_name = __name__

View File

@@ -14,6 +14,9 @@ class SearchEngineDocumentStoreTestAbstract:
because we want to run its methods only in subclasses.
"""
# Constants
query = "test"
@pytest.mark.integration
def test___do_bulk(self):
pass
@@ -46,6 +49,101 @@
result = ds.get_metadata_values_by_key(key="year", query="Bar")
assert result == [{"count": 3, "value": "2021"}]
@pytest.mark.unit
def test_query_return_embedding_true(self, mocked_document_store):
mocked_document_store.return_embedding = True
mocked_document_store.query(self.query)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert "_source" not in kwargs["body"]
@pytest.mark.unit
def test_query_return_embedding_false(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.query(self.query)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
@pytest.mark.unit
def test_query_excluded_meta_data_return_embedding_true(self, mocked_document_store):
mocked_document_store.return_embedding = True
mocked_document_store.excluded_meta_data = ["foo", "embedding"]
mocked_document_store.query(self.query)
_, kwargs = mocked_document_store.client.search.call_args
# we expect "embedding" was removed from the final query
assert kwargs["body"]["_source"] == {"excludes": ["foo"]}
@pytest.mark.unit
def test_query_excluded_meta_data_return_embedding_false(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.excluded_meta_data = ["foo"]
mocked_document_store.query(self.query)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]}
@pytest.mark.unit
def test_get_all_documents_return_embedding_true(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.client.search.return_value = {}
mocked_document_store.get_all_documents(return_embedding=True)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
# starting with elasticsearch client 7.16, scan() uses the query parameter instead of body,
# see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2
body = kwargs.get("body", kwargs)
assert "_source" not in body
@pytest.mark.unit
def test_get_all_documents_return_embedding_false(self, mocked_document_store):
mocked_document_store.return_embedding = True
mocked_document_store.client.search.return_value = {}
mocked_document_store.get_all_documents(return_embedding=False)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
# starting with elasticsearch client 7.16, scan() uses the query parameter instead of body,
# see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2
body = kwargs.get("body", kwargs)
assert body["_source"] == {"excludes": ["embedding"]}
@pytest.mark.unit
def test_get_all_documents_excluded_meta_data_has_no_influence(self, mocked_document_store):
mocked_document_store.excluded_meta_data = ["foo"]
mocked_document_store.client.search.return_value = {}
mocked_document_store.get_all_documents(return_embedding=False)
# assert the resulting body is not affected by the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
# starting with elasticsearch client 7.16, scan() uses the query parameter instead of body,
# see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2
body = kwargs.get("body", kwargs)
assert body["_source"] == {"excludes": ["embedding"]}
@pytest.mark.unit
def test_get_document_by_id_return_embedding_true(self, mocked_document_store):
mocked_document_store.return_embedding = True
mocked_document_store.get_document_by_id("123")
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert "_source" not in kwargs["body"]
@pytest.mark.unit
def test_get_document_by_id_return_embedding_false(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.get_document_by_id("123")
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
@pytest.mark.unit
def test_get_document_by_id_excluded_meta_data_has_no_influence(self, mocked_document_store):
mocked_document_store.excluded_meta_data = ["foo"]
mocked_document_store.return_embedding = False
mocked_document_store.get_document_by_id("123")
# assert the resulting body is not affected by the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
@pytest.mark.document_store
class TestSearchEngineDocumentStore:
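
For readers unfamiliar with the assertion pattern used in the tests above, a minimal standalone sketch of how unittest.mock records a search call and how the request body is then inspected; the index and body used here are illustrative, not taken from the diff:

from unittest.mock import MagicMock

client = MagicMock()
client.search(
    index="documents",
    body={"query": {"match": {"content": "test"}}, "_source": {"excludes": ["embedding"]}},
)

# call_args holds the arguments of the most recent call and unpacks into (args, kwargs)
args, kwargs = client.search.call_args
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}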