mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-29 11:50:34 +00:00

* Python performance improvements with ruff C4 and PERF * pre-commit fixes * Revert changes to examples/basic_qa_pipeline.py * Revert changes to haystack/preview/testing/document_store.py * revert releasenotes * Upgrade to ruff v0.0.290
530 lines
25 KiB
Python
530 lines
25 KiB
Python
import logging
|
|
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from elasticsearch import Elasticsearch
|
|
|
|
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore, VERSION
|
|
from haystack.document_stores.es_converter import elasticsearch_index_to_document_store
|
|
from haystack.document_stores.memory import InMemoryDocumentStore
|
|
from haystack.nodes import PreProcessor
|
|
from haystack.testing import DocumentStoreBaseTestAbstract
|
|
|
|
from .test_search_engine import SearchEngineDocumentStoreTestAbstract
|
|
|
|
|
|
class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDocumentStoreTestAbstract):
|
|
# Constants
|
|
|
|
index_name = __name__
|
|
|
|
@pytest.fixture
|
|
def ds(self):
|
|
"""
|
|
This fixture provides a working document store and takes care of keeping clean
|
|
the ES cluster used in the tests.
|
|
"""
|
|
labels_index_name = f"{self.index_name}_labels"
|
|
ds = ElasticsearchDocumentStore(
|
|
index=self.index_name,
|
|
label_index=labels_index_name,
|
|
host=os.environ.get("ELASTICSEARCH_HOST", "localhost"),
|
|
create_index=True,
|
|
recreate_index=True,
|
|
)
|
|
|
|
yield ds
|
|
|
|
@pytest.fixture
|
|
def mocked_elastic_search_init(self, monkeypatch):
|
|
mocked_init = MagicMock(return_value=None)
|
|
monkeypatch.setattr(Elasticsearch, "__init__", mocked_init)
|
|
return mocked_init
|
|
|
|
@pytest.fixture
|
|
def mocked_elastic_search_ping(self, monkeypatch):
|
|
mocked_ping = MagicMock(return_value=True)
|
|
monkeypatch.setattr(Elasticsearch, "ping", mocked_ping)
|
|
return mocked_ping
|
|
|
|
@pytest.fixture
|
|
def mocked_document_store(self):
|
|
"""
|
|
The fixture provides an instance of a slightly customized
|
|
ElasticsearchDocumentStore equipped with a mocked client
|
|
"""
|
|
|
|
with patch(
|
|
f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"
|
|
) as mocked_init_client:
|
|
if VERSION[0] == 7:
|
|
mocked_init_client().info.return_value = {"version": {"number": "7.17.6"}}
|
|
else:
|
|
mocked_init_client().info.return_value = {"version": {"number": "8.8.0"}}
|
|
|
|
class DSMock(ElasticsearchDocumentStore):
|
|
# We mock a subclass to avoid messing up the actual class object
|
|
pass
|
|
|
|
DSMock.client = MagicMock()
|
|
yield DSMock()
|
|
|
|
@pytest.mark.integration
|
|
def test___init__(self):
|
|
# defaults
|
|
_ = ElasticsearchDocumentStore()
|
|
|
|
# list of hosts + single port
|
|
_ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=9200)
|
|
|
|
# list of hosts + list of ports (wrong)
|
|
with pytest.raises(Exception):
|
|
_ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=[9200])
|
|
|
|
# list of hosts + list
|
|
_ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=[9200, 9200])
|
|
|
|
# only api_key
|
|
with pytest.raises(Exception):
|
|
_ = ElasticsearchDocumentStore(host=["localhost"], port=[9200], api_key="test")
|
|
|
|
# api_key + id
|
|
_ = ElasticsearchDocumentStore(host=["localhost"], port=[9200], api_key="test", api_key_id="test")
|
|
|
|
@pytest.mark.integration
|
|
def test_recreate_index(self, ds, documents, labels):
|
|
ds.write_documents(documents)
|
|
ds.write_labels(labels)
|
|
|
|
# Create another document store on top of the previous one
|
|
ds = ElasticsearchDocumentStore(index=ds.index, label_index=ds.label_index, recreate_index=True)
|
|
assert len(ds.get_all_documents(index=ds.index)) == 0
|
|
assert len(ds.get_all_labels(index=ds.label_index)) == 0
|
|
|
|
@pytest.mark.integration
|
|
def test_eq_filter(self, ds, documents):
|
|
ds.write_documents(documents)
|
|
|
|
filter = {"name": {"$eq": ["name_0"]}}
|
|
filtered_docs = ds.get_all_documents(filters=filter)
|
|
assert len(filtered_docs) == 3
|
|
for doc in filtered_docs:
|
|
assert doc.meta["name"] == "name_0"
|
|
|
|
filter = {"numbers": {"$eq": [2, 4]}}
|
|
filtered_docs = ds.query(query=None, filters=filter)
|
|
assert len(filtered_docs) == 3
|
|
for doc in filtered_docs:
|
|
assert doc.meta["month"] == "01"
|
|
assert doc.meta["numbers"] == [2, 4]
|
|
|
|
@pytest.mark.integration
|
|
def test_custom_fields(self, ds):
|
|
index = "haystack_test_custom"
|
|
document_store = ElasticsearchDocumentStore(
|
|
index=index,
|
|
content_field="custom_text_field",
|
|
embedding_field="custom_embedding_field",
|
|
recreate_index=True,
|
|
)
|
|
doc_to_write = {"custom_text_field": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
|
|
document_store.write_documents([doc_to_write])
|
|
documents = document_store.get_all_documents(return_embedding=True)
|
|
assert len(documents) == 1
|
|
assert documents[0].content == "test"
|
|
np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)
|
|
document_store.delete_index(index)
|
|
|
|
@pytest.mark.integration
|
|
def test_query_with_filters_and_missing_embeddings(self, ds, documents):
|
|
ds.write_documents(documents)
|
|
filters = {"month": {"$in": ["01", "03"]}}
|
|
ds.skip_missing_embeddings = False
|
|
with pytest.raises(ds._RequestError):
|
|
ds.query_by_embedding(np.random.rand(768), filters=filters)
|
|
|
|
ds.skip_missing_embeddings = True
|
|
documents = ds.query_by_embedding(np.random.rand(768), filters=filters)
|
|
assert len(documents) == 3
|
|
|
|
@pytest.mark.integration
|
|
def test_synonyms(self, ds):
|
|
synonyms = ["i-pod, i pod, ipod", "sea biscuit, sea biscit, seabiscuit", "foo, foo bar, baz"]
|
|
synonym_type = "synonym_graph"
|
|
|
|
client = ds.client
|
|
index = "haystack_synonym_arg"
|
|
client.indices.delete(index=index, ignore=[404])
|
|
ElasticsearchDocumentStore(index=index, synonyms=synonyms, synonym_type=synonym_type)
|
|
indexed_settings = client.indices.get_settings(index=index)
|
|
|
|
assert synonym_type == indexed_settings[index]["settings"]["index"]["analysis"]["filter"]["synonym"]["type"]
|
|
assert synonyms == indexed_settings[index]["settings"]["index"]["analysis"]["filter"]["synonym"]["synonyms"]
|
|
|
|
@pytest.mark.integration
|
|
def test_search_field_mapping(self):
|
|
index = "haystack_search_field_mapping"
|
|
document_store = ElasticsearchDocumentStore(
|
|
index=index, search_fields=["content", "sub_content"], content_field="title"
|
|
)
|
|
|
|
document_store.write_documents(
|
|
[
|
|
{
|
|
"title": "Green tea components",
|
|
"meta": {
|
|
"content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
|
|
"sub_content": "Drink tip",
|
|
},
|
|
"id": "1",
|
|
},
|
|
{
|
|
"title": "Green tea catechin",
|
|
"meta": {
|
|
"content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
|
|
"sub_content": "Ingredients tip",
|
|
},
|
|
"id": "2",
|
|
},
|
|
{
|
|
"title": "Minerals in Green tea",
|
|
"meta": {
|
|
"content": "Green tea also has small amounts of minerals that can benefit your health.",
|
|
"sub_content": "Minerals tip",
|
|
},
|
|
"id": "3",
|
|
},
|
|
{
|
|
"title": "Green tea Benefits",
|
|
"meta": {
|
|
"content": "Green tea does more than just keep you alert, it may also help boost brain function.",
|
|
"sub_content": "Health tip",
|
|
},
|
|
"id": "4",
|
|
},
|
|
]
|
|
)
|
|
|
|
indexed_settings = document_store.client.indices.get_mapping(index=index)
|
|
|
|
assert indexed_settings[index]["mappings"]["properties"]["content"]["type"] == "text"
|
|
assert indexed_settings[index]["mappings"]["properties"]["sub_content"]["type"] == "text"
|
|
document_store.delete_index(index)
|
|
|
|
@pytest.mark.integration
|
|
def test_existing_alias(self, ds):
|
|
client = ds.client
|
|
client.indices.delete(index="haystack_existing_alias_1", ignore=[404])
|
|
client.indices.delete(index="haystack_existing_alias_2", ignore=[404])
|
|
client.indices.delete_alias(index="_all", name="haystack_existing_alias", ignore=[404])
|
|
|
|
settings = {"mappings": {"properties": {"content": {"type": "text"}}}}
|
|
|
|
client.indices.create(index="haystack_existing_alias_1", **settings)
|
|
client.indices.create(index="haystack_existing_alias_2", **settings)
|
|
|
|
client.indices.put_alias(
|
|
index="haystack_existing_alias_1,haystack_existing_alias_2", name="haystack_existing_alias"
|
|
)
|
|
|
|
# To be valid, all indices related to the alias must have content field of type text
|
|
ElasticsearchDocumentStore(index="haystack_existing_alias", search_fields=["content"])
|
|
|
|
@pytest.mark.integration
|
|
def test_existing_alias_missing_fields(self, ds):
|
|
client = ds.client
|
|
client.indices.delete(index="haystack_existing_alias_1", ignore=[404])
|
|
client.indices.delete(index="haystack_existing_alias_2", ignore=[404])
|
|
client.indices.delete_alias(index="_all", name="haystack_existing_alias", ignore=[404])
|
|
|
|
right_settings = {"mappings": {"properties": {"content": {"type": "text"}}}}
|
|
wrong_settings = {"mappings": {"properties": {"content": {"type": "histogram"}}}}
|
|
|
|
client.indices.create(index="haystack_existing_alias_1", **right_settings)
|
|
client.indices.create(index="haystack_existing_alias_2", **wrong_settings)
|
|
client.indices.put_alias(
|
|
index="haystack_existing_alias_1,haystack_existing_alias_2", name="haystack_existing_alias"
|
|
)
|
|
|
|
with pytest.raises(Exception):
|
|
# wrong field type for "content" in index "haystack_existing_alias_2"
|
|
ElasticsearchDocumentStore(
|
|
index="haystack_existing_alias", search_fields=["content"], content_field="title"
|
|
)
|
|
|
|
@pytest.mark.integration
|
|
def test_get_document_count_only_documents_without_embedding_arg(self, ds, documents):
|
|
ds.write_documents(documents)
|
|
|
|
assert ds.get_document_count() == 9
|
|
assert ds.get_document_count(only_documents_without_embedding=True) == 3
|
|
assert ds.get_document_count(only_documents_without_embedding=True, filters={"month": ["01"]}) == 0
|
|
assert ds.get_document_count(only_documents_without_embedding=True, filters={"month": ["03"]}) == 3
|
|
|
|
@pytest.mark.integration
|
|
def test_elasticsearch_brownfield_support(self, ds, documents):
|
|
ds.write_documents(documents)
|
|
|
|
new_document_store = elasticsearch_index_to_document_store(
|
|
document_store=InMemoryDocumentStore(),
|
|
original_index_name=ds.index,
|
|
original_content_field="content",
|
|
original_name_field="name",
|
|
included_metadata_fields=["date_field"],
|
|
index="test_brownfield_support",
|
|
id_hash_keys=["content", "meta"],
|
|
)
|
|
|
|
original_documents = ds.get_all_documents()
|
|
transferred_documents = new_document_store.get_all_documents(index="test_brownfield_support")
|
|
assert len(original_documents) == len(transferred_documents)
|
|
assert all("name" in doc.meta for doc in transferred_documents)
|
|
assert all(doc.id == doc._get_id(["content", "meta"]) for doc in transferred_documents)
|
|
|
|
original_content = {doc.content for doc in original_documents}
|
|
transferred_content = {doc.content for doc in transferred_documents}
|
|
assert original_content == transferred_content
|
|
|
|
# Test transferring docs with PreProcessor
|
|
new_document_store = elasticsearch_index_to_document_store(
|
|
document_store=InMemoryDocumentStore(),
|
|
original_index_name=ds.index,
|
|
original_content_field="content",
|
|
excluded_metadata_fields=["date_field"],
|
|
index="test_brownfield_support_2",
|
|
preprocessor=PreProcessor(split_length=1, split_respect_sentence_boundary=False),
|
|
)
|
|
transferred_documents = new_document_store.get_all_documents(index="test_brownfield_support_2")
|
|
assert all("name" in doc.meta for doc in transferred_documents)
|
|
# Check if number of transferred_documents is equal to number of unique words.
|
|
assert len(transferred_documents) == len(set(" ".join(original_content).split()))
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 is not supported")
|
|
@pytest.mark.unit
|
|
def test__init_elastic_client_aws4auth_and_username_raises_warning(
|
|
self, caplog, mocked_elastic_search_init, mocked_elastic_search_ping
|
|
):
|
|
_init_client_remaining_kwargs = {
|
|
"host": "host",
|
|
"port": 443,
|
|
"password": "pass",
|
|
"api_key_id": None,
|
|
"api_key": None,
|
|
"scheme": "https",
|
|
"ca_certs": None,
|
|
"verify_certs": True,
|
|
"timeout": 10,
|
|
"use_system_proxy": False,
|
|
}
|
|
|
|
with caplog.at_level(logging.WARN, logger="haystack.document_stores.elasticsearch"):
|
|
ElasticsearchDocumentStore._init_elastic_client(
|
|
username="admin", aws4auth="foo", **_init_client_remaining_kwargs
|
|
)
|
|
assert len(caplog.records) == 1
|
|
for r in caplog.records:
|
|
assert r.levelname == "WARNING"
|
|
|
|
caplog.clear()
|
|
with caplog.at_level(logging.WARN, logger="haystack.document_stores.elasticsearch"):
|
|
ElasticsearchDocumentStore._init_elastic_client(
|
|
username=None, aws4auth="foo", **_init_client_remaining_kwargs
|
|
)
|
|
ElasticsearchDocumentStore._init_elastic_client(
|
|
username="", aws4auth="foo", **_init_client_remaining_kwargs
|
|
)
|
|
assert len(caplog.records) == 0
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_document_by_id_return_embedding_false_es7(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = False
|
|
mocked_document_store.get_document_by_id("123")
|
|
# assert the resulting body is consistent with the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.search.call_args
|
|
assert kwargs["_source"] == {"excludes": ["embedding"]}
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_document_by_id_return_embedding_false_es8(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = False
|
|
mocked_document_store.get_document_by_id("123")
|
|
# assert the resulting body is consistent with the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
assert kwargs["_source"] == {"excludes": ["embedding"]}
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_document_by_id_excluded_meta_data_has_no_influence_es7(self, mocked_document_store):
|
|
mocked_document_store.excluded_meta_data = ["foo"]
|
|
mocked_document_store.return_embedding = False
|
|
mocked_document_store.get_document_by_id("123")
|
|
# assert the resulting body is not affected by the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.search.call_args
|
|
assert kwargs["_source"] == {"excludes": ["embedding"]}
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_document_by_id_excluded_meta_data_has_no_influence_es8(self, mocked_document_store):
|
|
mocked_document_store.excluded_meta_data = ["foo"]
|
|
mocked_document_store.return_embedding = False
|
|
mocked_document_store.get_document_by_id("123")
|
|
# assert the resulting body is not affected by the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
assert kwargs["_source"] == {"excludes": ["embedding"]}
|
|
|
|
@pytest.mark.unit
|
|
def test_write_documents_req_for_each_batch(self, mocked_document_store, documents):
|
|
mocked_document_store.batch_size = 2
|
|
with patch(f"{ElasticsearchDocumentStore.__module__}.bulk") as mocked_bulk:
|
|
mocked_document_store.write_documents(documents)
|
|
assert mocked_bulk.call_count == 5
|
|
|
|
@pytest.mark.unit
|
|
def test_get_vector_similarity_query(self, mocked_document_store):
|
|
"""
|
|
Test that the source field of the vector similarity query is correctly formatted for ES 7.6 and above.
|
|
We test this to make sure we use the correct syntax for newer ES versions.
|
|
"""
|
|
vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10)
|
|
assert vec_sim_query["script_score"]["script"]["source"] == "dotProduct(params.query_vector,'embedding') + 1000"
|
|
|
|
@pytest.mark.unit
|
|
def test_get_vector_similarity_query_es_7_5_and_below(self, mocked_document_store):
|
|
"""
|
|
Test that the source field of the vector similarity query is correctly formatter for ES 7.5 and below.
|
|
We test this to make sure we use the correct syntax for ES versions older than 7.6, as the syntax changed
|
|
in 7.6.
|
|
"""
|
|
# Patch server version to be 7.5.0
|
|
mocked_document_store.server_version = (7, 5, 0)
|
|
|
|
vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10)
|
|
assert (
|
|
vec_sim_query["script_score"]["script"]["source"]
|
|
== "dotProduct(params.query_vector,doc['embedding']) + 1000"
|
|
)
|
|
|
|
# The following tests are overridden only to be able to skip them depending on ES version
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_all_documents_return_embedding_true(self, mocked_document_store):
|
|
super().test_get_all_documents_return_embedding_true(mocked_document_store)
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_all_documents_return_embedding_true_es8(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = False
|
|
mocked_document_store.client.options().search.return_value = {}
|
|
mocked_document_store.get_all_documents(return_embedding=True)
|
|
# assert the resulting body is consistent with the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
assert "_source" not in kwargs
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_all_documents_return_embedding_false(self, mocked_document_store):
|
|
super().test_get_all_documents_return_embedding_false(mocked_document_store)
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_all_documents_return_embedding_false_es8(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = True
|
|
mocked_document_store.client.options().search.return_value = {}
|
|
mocked_document_store.get_all_documents(return_embedding=False)
|
|
# assert the resulting body is consistent with the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
body = kwargs.get("body", kwargs)
|
|
assert body["_source"] == {"excludes": ["embedding"]}
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_all_documents_excluded_meta_data_has_no_influence(self, mocked_document_store):
|
|
super().test_get_all_documents_excluded_meta_data_has_no_influence(mocked_document_store)
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_all_documents_excluded_meta_data_has_no_influence_es8(self, mocked_document_store):
|
|
mocked_document_store.excluded_meta_data = ["foo"]
|
|
mocked_document_store.client.options().search.return_value = {}
|
|
mocked_document_store.get_all_documents(return_embedding=False)
|
|
# assert the resulting body is not affected by the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
body = kwargs.get("body", kwargs)
|
|
assert body["_source"] == {"excludes": ["embedding"]}
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_query_return_embedding_true(self, mocked_document_store):
|
|
super().test_query_return_embedding_true(mocked_document_store)
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_query_return_embedding_true_es8(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = True
|
|
mocked_document_store.query(self.query)
|
|
# assert the resulting body is consistent with the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
assert "_source" not in kwargs
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_query_return_embedding_false(self, mocked_document_store):
|
|
super().test_query_return_embedding_false(mocked_document_store)
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_query_return_embedding_false_es8(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = False
|
|
mocked_document_store.query(self.query)
|
|
# assert the resulting body is consistent with the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
assert kwargs["_source"] == {"excludes": ["embedding"]}
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_query_excluded_meta_data_return_embedding_true(self, mocked_document_store):
|
|
super().test_query_excluded_meta_data_return_embedding_true(mocked_document_store)
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_query_excluded_meta_data_return_embedding_true_es8(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = True
|
|
mocked_document_store.excluded_meta_data = ["foo", "embedding"]
|
|
mocked_document_store.query(self.query)
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
# we expect "embedding" was removed from the final query
|
|
assert kwargs["_source"] == {"excludes": ["foo"]}
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_query_excluded_meta_data_return_embedding_false(self, mocked_document_store):
|
|
super().test_query_excluded_meta_data_return_embedding_false(mocked_document_store)
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_query_excluded_meta_data_return_embedding_false_es8(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = False
|
|
mocked_document_store.excluded_meta_data = ["foo"]
|
|
mocked_document_store.query(self.query)
|
|
# assert the resulting body is consistent with the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
assert kwargs["_source"] == {"excludes": ["foo", "embedding"]}
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_document_by_id_return_embedding_true(self, mocked_document_store):
|
|
super().test_get_document_by_id_return_embedding_true(mocked_document_store)
|
|
|
|
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
|
|
@pytest.mark.unit
|
|
def test_get_document_by_id_return_embedding_true_es8(self, mocked_document_store):
|
|
mocked_document_store.return_embedding = True
|
|
mocked_document_store.get_document_by_id("123")
|
|
# assert the resulting body is consistent with the `excluded_meta_data` value
|
|
_, kwargs = mocked_document_store.client.options().search.call_args
|
|
assert "_source" not in kwargs
|