mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-10 14:16:59 +00:00
fix: Despite return_embedding=False SearchEngineDocumentStore.query retrieves embedding_field (#3662)
* fix: Despite return_embedding=False SearchEngineDocumentStore.query retrieves embedding_field
* fix pylint
* add tests
* fix mypy
* fix merge
* format
* fix pylint
* move tests to SearchEngineDocumentStoreTestAbstract
* move missed constants
* add mocked_document_store fixture to TestElasticsearchDocumentStore
* fix mocked_document_store
* fix get_all_documents tests for elasticsearch>=7.16
* fix tests
* fix tests try 2
This commit is contained in:
parent
5b0b338175
commit
6ca88bfd23
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -377,9 +376,10 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
||||
if not self.embedding_field:
|
||||
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
|
||||
|
||||
body = self._construct_dense_query_body(query_emb, filters, top_k, return_embedding)
|
||||
body = self._construct_dense_query_body(
|
||||
query_emb=query_emb, filters=filters, top_k=top_k, return_embedding=return_embedding
|
||||
)
|
||||
|
||||
logger.debug("Retriever query: %s", body)
|
||||
try:
|
||||
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
|
||||
if len(result) == 0:
|
||||
@ -399,19 +399,13 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
||||
raise e
|
||||
|
||||
documents = [
|
||||
self._convert_es_hit_to_document(
|
||||
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
|
||||
)
|
||||
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
|
||||
for hit in result
|
||||
]
|
||||
return documents
|
||||
|
||||
def _construct_dense_query_body(
|
||||
self,
|
||||
query_emb: np.ndarray,
|
||||
filters: Optional[FilterType] = None,
|
||||
top_k: int = 10,
|
||||
return_embedding: Optional[bool] = None,
|
||||
self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10
|
||||
):
|
||||
body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
|
||||
if filters:
|
||||
@ -421,20 +415,10 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
||||
else:
|
||||
body["query"]["script_score"]["query"]["bool"]["filter"]["bool"]["must"].append(filter_)
|
||||
|
||||
excluded_meta_data: Optional[list] = None
|
||||
excluded_fields = self._get_excluded_fields(return_embedding=return_embedding)
|
||||
if excluded_fields:
|
||||
body["_source"] = {"excludes": excluded_fields}
|
||||
|
||||
if self.excluded_meta_data:
|
||||
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
||||
|
||||
if return_embedding is True and self.embedding_field in excluded_meta_data:
|
||||
excluded_meta_data.remove(self.embedding_field)
|
||||
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
|
||||
excluded_meta_data.append(self.embedding_field)
|
||||
elif return_embedding is False:
|
||||
excluded_meta_data = [self.embedding_field]
|
||||
|
||||
if excluded_meta_data:
|
||||
body["_source"] = {"excludes": excluded_meta_data}
|
||||
return body
|
||||
|
||||
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
@ -445,19 +444,13 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
||||
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
|
||||
|
||||
documents = [
|
||||
self._convert_es_hit_to_document(
|
||||
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
|
||||
)
|
||||
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
|
||||
for hit in result
|
||||
]
|
||||
return documents
|
||||
|
||||
def _construct_dense_query_body(
|
||||
self,
|
||||
query_emb: np.ndarray,
|
||||
filters: Optional[FilterType] = None,
|
||||
top_k: int = 10,
|
||||
return_embedding: Optional[bool] = None,
|
||||
self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10
|
||||
):
|
||||
body: Dict[str, Any] = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
|
||||
if filters:
|
||||
@ -468,19 +461,10 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
||||
else:
|
||||
body["query"]["bool"]["filter"] = filter_
|
||||
|
||||
excluded_meta_data: Optional[list] = None
|
||||
if self.excluded_meta_data:
|
||||
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
||||
excluded_fields = self._get_excluded_fields(return_embedding=return_embedding)
|
||||
if excluded_fields:
|
||||
body["_source"] = {"excludes": excluded_fields}
|
||||
|
||||
if return_embedding is True and self.embedding_field in excluded_meta_data:
|
||||
excluded_meta_data.remove(self.embedding_field)
|
||||
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
|
||||
excluded_meta_data.append(self.embedding_field)
|
||||
elif return_embedding is False:
|
||||
excluded_meta_data = [self.embedding_field]
|
||||
|
||||
if excluded_meta_data:
|
||||
body["_source"] = {"excludes": excluded_meta_data}
|
||||
return body
|
||||
|
||||
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
|
||||
@ -842,9 +826,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
||||
opensearch_logger.setLevel(logging.CRITICAL)
|
||||
with tqdm(total=document_count, position=0, unit=" Docs", desc="Cloning embeddings") as progress_bar:
|
||||
for result_batch in get_batches_from_generator(result, batch_size):
|
||||
document_batch = [
|
||||
self._convert_es_hit_to_document(hit, return_embedding=True) for hit in result_batch
|
||||
]
|
||||
document_batch = [self._convert_es_hit_to_document(hit) for hit in result_batch]
|
||||
doc_updates = []
|
||||
for doc in document_batch:
|
||||
if doc.embedding is not None:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
# pylint: disable=too-many-public-methods
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Union, Dict, Any, Generator
|
||||
from abc import abstractmethod
|
||||
import json
|
||||
@ -279,10 +281,10 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
for i in range(0, len(ids), batch_size):
|
||||
ids_for_batch = ids[i : i + batch_size]
|
||||
query = {"size": len(ids_for_batch), "query": {"ids": {"values": ids_for_batch}}}
|
||||
if not self.return_embedding and self.embedding_field:
|
||||
query["_source"] = {"excludes": [self.embedding_field]}
|
||||
result = self.client.search(index=index, body=query, headers=headers)["hits"]["hits"]
|
||||
documents.extend(
|
||||
[self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding) for hit in result]
|
||||
)
|
||||
documents.extend([self._convert_es_hit_to_document(hit) for hit in result])
|
||||
return documents
|
||||
|
||||
def get_metadata_values_by_key(
|
||||
@ -675,9 +677,15 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
if return_embedding is None:
|
||||
return_embedding = self.return_embedding
|
||||
|
||||
result = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size, headers=headers)
|
||||
excludes = None
|
||||
if not return_embedding and self.embedding_field:
|
||||
excludes = [self.embedding_field]
|
||||
|
||||
result = self._get_all_documents_in_index(
|
||||
index=index, filters=filters, batch_size=batch_size, headers=headers, excludes=excludes
|
||||
)
|
||||
for hit in result:
|
||||
document = self._convert_es_hit_to_document(hit, return_embedding=return_embedding)
|
||||
document = self._convert_es_hit_to_document(hit)
|
||||
yield document
|
||||
|
||||
def get_all_labels(
|
||||
@ -709,6 +717,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
batch_size: int = 10_000,
|
||||
only_documents_without_embedding: bool = False,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
excludes: Optional[List[str]] = None,
|
||||
) -> Generator[dict, None, None]:
|
||||
"""
|
||||
Return all documents in a specific index in the document store
|
||||
@ -721,6 +730,9 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
if only_documents_without_embedding:
|
||||
body["query"]["bool"]["must_not"] = [{"exists": {"field": self.embedding_field}}]
|
||||
|
||||
if excludes:
|
||||
body["_source"] = {"excludes": excludes}
|
||||
|
||||
result = self._do_scan(
|
||||
self.client, query=body, index=index, size=batch_size, scroll=self.scroll, headers=headers
|
||||
)
|
||||
@ -899,10 +911,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
|
||||
result = self.client.search(index=index, body=body, headers=headers)["hits"]["hits"]
|
||||
|
||||
documents = [
|
||||
self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding, scale_score=scale_score)
|
||||
for hit in result
|
||||
]
|
||||
documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in result]
|
||||
return documents
|
||||
|
||||
def query_batch(
|
||||
@ -1036,10 +1045,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
cur_documents = []
|
||||
for response in responses["responses"]:
|
||||
cur_result = response["hits"]["hits"]
|
||||
cur_documents = [
|
||||
self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding, scale_score=scale_score)
|
||||
for hit in cur_result
|
||||
]
|
||||
cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in cur_result]
|
||||
all_documents.append(cur_documents)
|
||||
|
||||
return all_documents
|
||||
@ -1105,13 +1111,28 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
if filters:
|
||||
body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
|
||||
|
||||
if self.excluded_meta_data:
|
||||
body["_source"] = {"excludes": self.excluded_meta_data}
|
||||
excluded_fields = self._get_excluded_fields(return_embedding=self.return_embedding)
|
||||
if excluded_fields:
|
||||
body["_source"] = {"excludes": excluded_fields}
|
||||
|
||||
return body
|
||||
|
||||
def _get_excluded_fields(self, return_embedding: bool) -> Optional[List[str]]:
|
||||
excluded_meta_data: Optional[list] = None
|
||||
|
||||
if self.excluded_meta_data:
|
||||
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
||||
|
||||
if return_embedding is True and self.embedding_field in excluded_meta_data:
|
||||
excluded_meta_data.remove(self.embedding_field)
|
||||
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
|
||||
excluded_meta_data.append(self.embedding_field)
|
||||
elif return_embedding is False:
|
||||
excluded_meta_data = [self.embedding_field]
|
||||
return excluded_meta_data
|
||||
|
||||
def _convert_es_hit_to_document(
|
||||
self, hit: dict, return_embedding: bool, adapt_score_for_embedding: bool = False, scale_score: bool = True
|
||||
self, hit: dict, adapt_score_for_embedding: bool = False, scale_score: bool = True
|
||||
) -> Document:
|
||||
# We put all additional data of the doc into meta_data and return it in the API
|
||||
try:
|
||||
@ -1139,10 +1160,9 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
score = float(expit(np.asarray(score / 8))) # scaling probability from TFIDF/BM25
|
||||
|
||||
embedding = None
|
||||
if return_embedding:
|
||||
embedding_list = hit["_source"].get(self.embedding_field)
|
||||
if embedding_list:
|
||||
embedding = np.asarray(embedding_list, dtype=np.float32)
|
||||
embedding_list = hit["_source"].get(self.embedding_field)
|
||||
if embedding_list:
|
||||
embedding = np.asarray(embedding_list, dtype=np.float32)
|
||||
|
||||
doc_dict = {
|
||||
"id": hit["_id"],
|
||||
@ -1284,9 +1304,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
for response in responses["responses"]:
|
||||
cur_result = response["hits"]["hits"]
|
||||
cur_documents = [
|
||||
self._convert_es_hit_to_document(
|
||||
hit, adapt_score_for_embedding=True, return_embedding=self.return_embedding, scale_score=scale_score
|
||||
)
|
||||
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
|
||||
for hit in cur_result
|
||||
]
|
||||
all_documents.append(cur_documents)
|
||||
@ -1295,11 +1313,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
|
||||
@abstractmethod
|
||||
def _construct_dense_query_body(
|
||||
self,
|
||||
query_emb: np.ndarray,
|
||||
filters: Optional[FilterType] = None,
|
||||
top_k: int = 10,
|
||||
return_embedding: Optional[bool] = None,
|
||||
self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10
|
||||
):
|
||||
pass
|
||||
|
||||
@ -1381,13 +1395,14 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
batch_size=batch_size,
|
||||
only_documents_without_embedding=not update_existing_embeddings,
|
||||
headers=headers,
|
||||
excludes=[self.embedding_field],
|
||||
)
|
||||
|
||||
logging.getLogger(__name__).setLevel(logging.CRITICAL)
|
||||
|
||||
with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
|
||||
for result_batch in get_batches_from_generator(result, batch_size):
|
||||
document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
|
||||
document_batch = [self._convert_es_hit_to_document(hit) for hit in result_batch]
|
||||
embeddings = self._embed_documents(document_batch, retriever)
|
||||
|
||||
doc_updates = []
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
@ -31,6 +32,21 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
|
||||
ds.delete_index(self.index_name)
|
||||
ds.delete_index(labels_index_name)
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_document_store(self):
|
||||
"""
|
||||
The fixture provides an instance of a slightly customized
|
||||
ElasticsearchDocumentStore equipped with a mocked client
|
||||
"""
|
||||
|
||||
class DSMock(ElasticsearchDocumentStore):
|
||||
# We mock a subclass to avoid messing up the actual class object
|
||||
pass
|
||||
|
||||
DSMock._init_elastic_client = MagicMock()
|
||||
DSMock.client = MagicMock()
|
||||
return DSMock()
|
||||
|
||||
@pytest.mark.integration
|
||||
def test___init__(self):
|
||||
# defaults
|
||||
|
||||
@ -26,7 +26,6 @@ from .test_search_engine import SearchEngineDocumentStoreTestAbstract
|
||||
class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDocumentStoreTestAbstract):
|
||||
|
||||
# Constants
|
||||
|
||||
query_emb = np.random.random_sample(size=(2, 2))
|
||||
index_name = __name__
|
||||
|
||||
|
||||
@ -14,6 +14,9 @@ class SearchEngineDocumentStoreTestAbstract:
|
||||
because we want to run its methods only in subclasses.
|
||||
"""
|
||||
|
||||
# Constants
|
||||
query = "test"
|
||||
|
||||
@pytest.mark.integration
|
||||
def test___do_bulk(self):
|
||||
pass
|
||||
@ -46,6 +49,101 @@ class SearchEngineDocumentStoreTestAbstract:
|
||||
result = ds.get_metadata_values_by_key(key="year", query="Bar")
|
||||
assert result == [{"count": 3, "value": "2021"}]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_return_embedding_true(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = True
|
||||
mocked_document_store.query(self.query)
|
||||
# assert the resulting body is consistent with the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert "_source" not in kwargs["body"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_return_embedding_false(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = False
|
||||
mocked_document_store.query(self.query)
|
||||
# assert the resulting body is consistent with the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_excluded_meta_data_return_embedding_true(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = True
|
||||
mocked_document_store.excluded_meta_data = ["foo", "embedding"]
|
||||
mocked_document_store.query(self.query)
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
# we expect "embedding" was removed from the final query
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["foo"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_excluded_meta_data_return_embedding_false(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = False
|
||||
mocked_document_store.excluded_meta_data = ["foo"]
|
||||
mocked_document_store.query(self.query)
|
||||
# assert the resulting body is consistent with the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_all_documents_return_embedding_true(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = False
|
||||
mocked_document_store.client.search.return_value = {}
|
||||
mocked_document_store.get_all_documents(return_embedding=True)
|
||||
# assert the resulting body is consistent with the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
# starting with elasticsearch client 7.16, scan() uses the query parameter instead of body,
|
||||
# see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2
|
||||
body = kwargs.get("body", kwargs)
|
||||
assert "_source" not in body
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_all_documents_return_embedding_false(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = True
|
||||
mocked_document_store.client.search.return_value = {}
|
||||
mocked_document_store.get_all_documents(return_embedding=False)
|
||||
# assert the resulting body is consistent with the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
# starting with elasticsearch client 7.16, scan() uses the query parameter instead of body,
|
||||
# see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2
|
||||
body = kwargs.get("body", kwargs)
|
||||
assert body["_source"] == {"excludes": ["embedding"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_all_documents_excluded_meta_data_has_no_influence(self, mocked_document_store):
|
||||
mocked_document_store.excluded_meta_data = ["foo"]
|
||||
mocked_document_store.client.search.return_value = {}
|
||||
mocked_document_store.get_all_documents(return_embedding=False)
|
||||
# assert the resulting body is not affected by the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
# starting with elasticsearch client 7.16, scan() uses the query parameter instead of body,
|
||||
# see https://github.com/elastic/elasticsearch-py/commit/889edc9ad6d728b79fadf790238b79f36449d2e2
|
||||
body = kwargs.get("body", kwargs)
|
||||
assert body["_source"] == {"excludes": ["embedding"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_document_by_id_return_embedding_true(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = True
|
||||
mocked_document_store.get_document_by_id("123")
|
||||
# assert the resulting body is consistent with the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert "_source" not in kwargs["body"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_document_by_id_return_embedding_false(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = False
|
||||
mocked_document_store.get_document_by_id("123")
|
||||
# assert the resulting body is consistent with the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_document_by_id_excluded_meta_data_has_no_influence(self, mocked_document_store):
|
||||
mocked_document_store.excluded_meta_data = ["foo"]
|
||||
mocked_document_store.return_embedding = False
|
||||
mocked_document_store.get_document_by_id("123")
|
||||
# assert the resulting body is not affected by the `excluded_meta_data` value
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
|
||||
|
||||
|
||||
@pytest.mark.document_store
|
||||
class TestSearchEngineDocumentStore:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user