haystack/test/document_stores/test_elasticsearch.py
Christian Clauss 91ab90a256
perf: Python performance improvements with ruff C4 and PERF fixes (#5803)
* Python performance improvements with ruff C4 and PERF

* pre-commit fixes

* Revert changes to examples/basic_qa_pipeline.py

* Revert changes to haystack/preview/testing/document_store.py

* revert releasenotes

* Upgrade to ruff v0.0.290
2023-09-16 16:26:07 +02:00

530 lines
25 KiB
Python

import logging
import os
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from elasticsearch import Elasticsearch
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore, VERSION
from haystack.document_stores.es_converter import elasticsearch_index_to_document_store
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes import PreProcessor
from haystack.testing import DocumentStoreBaseTestAbstract
from .test_search_engine import SearchEngineDocumentStoreTestAbstract
class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDocumentStoreTestAbstract):
    """
    Test suite for ElasticsearchDocumentStore.

    Reuses the generic document store tests (DocumentStoreBaseTestAbstract) and the search-engine
    tests (SearchEngineDocumentStoreTestAbstract), and adds Elasticsearch-specific integration tests
    (marked ``integration``) plus unit tests that run against a mocked client (marked ``unit``).
    """

    # Constants
    # Name of the index created by the `ds` fixture; `__name__` is this test module's name.
    index_name = __name__
@pytest.fixture
def ds(self):
"""
This fixture provides a working document store and takes care of keeping clean
the ES cluster used in the tests.
"""
labels_index_name = f"{self.index_name}_labels"
ds = ElasticsearchDocumentStore(
index=self.index_name,
label_index=labels_index_name,
host=os.environ.get("ELASTICSEARCH_HOST", "localhost"),
create_index=True,
recreate_index=True,
)
yield ds
@pytest.fixture
def mocked_elastic_search_init(self, monkeypatch):
mocked_init = MagicMock(return_value=None)
monkeypatch.setattr(Elasticsearch, "__init__", mocked_init)
return mocked_init
@pytest.fixture
def mocked_elastic_search_ping(self, monkeypatch):
mocked_ping = MagicMock(return_value=True)
monkeypatch.setattr(Elasticsearch, "ping", mocked_ping)
return mocked_ping
@pytest.fixture
def mocked_document_store(self):
"""
The fixture provides an instance of a slightly customized
ElasticsearchDocumentStore equipped with a mocked client
"""
with patch(
f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"
) as mocked_init_client:
if VERSION[0] == 7:
mocked_init_client().info.return_value = {"version": {"number": "7.17.6"}}
else:
mocked_init_client().info.return_value = {"version": {"number": "8.8.0"}}
class DSMock(ElasticsearchDocumentStore):
# We mock a subclass to avoid messing up the actual class object
pass
DSMock.client = MagicMock()
yield DSMock()
@pytest.mark.integration
def test___init__(self):
# defaults
_ = ElasticsearchDocumentStore()
# list of hosts + single port
_ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=9200)
# list of hosts + list of ports (wrong)
with pytest.raises(Exception):
_ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=[9200])
# list of hosts + list
_ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=[9200, 9200])
# only api_key
with pytest.raises(Exception):
_ = ElasticsearchDocumentStore(host=["localhost"], port=[9200], api_key="test")
# api_key + id
_ = ElasticsearchDocumentStore(host=["localhost"], port=[9200], api_key="test", api_key_id="test")
@pytest.mark.integration
def test_recreate_index(self, ds, documents, labels):
ds.write_documents(documents)
ds.write_labels(labels)
# Create another document store on top of the previous one
ds = ElasticsearchDocumentStore(index=ds.index, label_index=ds.label_index, recreate_index=True)
assert len(ds.get_all_documents(index=ds.index)) == 0
assert len(ds.get_all_labels(index=ds.label_index)) == 0
@pytest.mark.integration
def test_eq_filter(self, ds, documents):
ds.write_documents(documents)
filter = {"name": {"$eq": ["name_0"]}}
filtered_docs = ds.get_all_documents(filters=filter)
assert len(filtered_docs) == 3
for doc in filtered_docs:
assert doc.meta["name"] == "name_0"
filter = {"numbers": {"$eq": [2, 4]}}
filtered_docs = ds.query(query=None, filters=filter)
assert len(filtered_docs) == 3
for doc in filtered_docs:
assert doc.meta["month"] == "01"
assert doc.meta["numbers"] == [2, 4]
@pytest.mark.integration
def test_custom_fields(self, ds):
index = "haystack_test_custom"
document_store = ElasticsearchDocumentStore(
index=index,
content_field="custom_text_field",
embedding_field="custom_embedding_field",
recreate_index=True,
)
doc_to_write = {"custom_text_field": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
document_store.write_documents([doc_to_write])
documents = document_store.get_all_documents(return_embedding=True)
assert len(documents) == 1
assert documents[0].content == "test"
np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)
document_store.delete_index(index)
@pytest.mark.integration
def test_query_with_filters_and_missing_embeddings(self, ds, documents):
ds.write_documents(documents)
filters = {"month": {"$in": ["01", "03"]}}
ds.skip_missing_embeddings = False
with pytest.raises(ds._RequestError):
ds.query_by_embedding(np.random.rand(768), filters=filters)
ds.skip_missing_embeddings = True
documents = ds.query_by_embedding(np.random.rand(768), filters=filters)
assert len(documents) == 3
@pytest.mark.integration
def test_synonyms(self, ds):
synonyms = ["i-pod, i pod, ipod", "sea biscuit, sea biscit, seabiscuit", "foo, foo bar, baz"]
synonym_type = "synonym_graph"
client = ds.client
index = "haystack_synonym_arg"
client.indices.delete(index=index, ignore=[404])
ElasticsearchDocumentStore(index=index, synonyms=synonyms, synonym_type=synonym_type)
indexed_settings = client.indices.get_settings(index=index)
assert synonym_type == indexed_settings[index]["settings"]["index"]["analysis"]["filter"]["synonym"]["type"]
assert synonyms == indexed_settings[index]["settings"]["index"]["analysis"]["filter"]["synonym"]["synonyms"]
@pytest.mark.integration
def test_search_field_mapping(self):
index = "haystack_search_field_mapping"
document_store = ElasticsearchDocumentStore(
index=index, search_fields=["content", "sub_content"], content_field="title"
)
document_store.write_documents(
[
{
"title": "Green tea components",
"meta": {
"content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
"sub_content": "Drink tip",
},
"id": "1",
},
{
"title": "Green tea catechin",
"meta": {
"content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
"sub_content": "Ingredients tip",
},
"id": "2",
},
{
"title": "Minerals in Green tea",
"meta": {
"content": "Green tea also has small amounts of minerals that can benefit your health.",
"sub_content": "Minerals tip",
},
"id": "3",
},
{
"title": "Green tea Benefits",
"meta": {
"content": "Green tea does more than just keep you alert, it may also help boost brain function.",
"sub_content": "Health tip",
},
"id": "4",
},
]
)
indexed_settings = document_store.client.indices.get_mapping(index=index)
assert indexed_settings[index]["mappings"]["properties"]["content"]["type"] == "text"
assert indexed_settings[index]["mappings"]["properties"]["sub_content"]["type"] == "text"
document_store.delete_index(index)
@pytest.mark.integration
def test_existing_alias(self, ds):
client = ds.client
client.indices.delete(index="haystack_existing_alias_1", ignore=[404])
client.indices.delete(index="haystack_existing_alias_2", ignore=[404])
client.indices.delete_alias(index="_all", name="haystack_existing_alias", ignore=[404])
settings = {"mappings": {"properties": {"content": {"type": "text"}}}}
client.indices.create(index="haystack_existing_alias_1", **settings)
client.indices.create(index="haystack_existing_alias_2", **settings)
client.indices.put_alias(
index="haystack_existing_alias_1,haystack_existing_alias_2", name="haystack_existing_alias"
)
# To be valid, all indices related to the alias must have content field of type text
ElasticsearchDocumentStore(index="haystack_existing_alias", search_fields=["content"])
@pytest.mark.integration
def test_existing_alias_missing_fields(self, ds):
client = ds.client
client.indices.delete(index="haystack_existing_alias_1", ignore=[404])
client.indices.delete(index="haystack_existing_alias_2", ignore=[404])
client.indices.delete_alias(index="_all", name="haystack_existing_alias", ignore=[404])
right_settings = {"mappings": {"properties": {"content": {"type": "text"}}}}
wrong_settings = {"mappings": {"properties": {"content": {"type": "histogram"}}}}
client.indices.create(index="haystack_existing_alias_1", **right_settings)
client.indices.create(index="haystack_existing_alias_2", **wrong_settings)
client.indices.put_alias(
index="haystack_existing_alias_1,haystack_existing_alias_2", name="haystack_existing_alias"
)
with pytest.raises(Exception):
# wrong field type for "content" in index "haystack_existing_alias_2"
ElasticsearchDocumentStore(
index="haystack_existing_alias", search_fields=["content"], content_field="title"
)
@pytest.mark.integration
def test_get_document_count_only_documents_without_embedding_arg(self, ds, documents):
ds.write_documents(documents)
assert ds.get_document_count() == 9
assert ds.get_document_count(only_documents_without_embedding=True) == 3
assert ds.get_document_count(only_documents_without_embedding=True, filters={"month": ["01"]}) == 0
assert ds.get_document_count(only_documents_without_embedding=True, filters={"month": ["03"]}) == 3
@pytest.mark.integration
def test_elasticsearch_brownfield_support(self, ds, documents):
ds.write_documents(documents)
new_document_store = elasticsearch_index_to_document_store(
document_store=InMemoryDocumentStore(),
original_index_name=ds.index,
original_content_field="content",
original_name_field="name",
included_metadata_fields=["date_field"],
index="test_brownfield_support",
id_hash_keys=["content", "meta"],
)
original_documents = ds.get_all_documents()
transferred_documents = new_document_store.get_all_documents(index="test_brownfield_support")
assert len(original_documents) == len(transferred_documents)
assert all("name" in doc.meta for doc in transferred_documents)
assert all(doc.id == doc._get_id(["content", "meta"]) for doc in transferred_documents)
original_content = {doc.content for doc in original_documents}
transferred_content = {doc.content for doc in transferred_documents}
assert original_content == transferred_content
# Test transferring docs with PreProcessor
new_document_store = elasticsearch_index_to_document_store(
document_store=InMemoryDocumentStore(),
original_index_name=ds.index,
original_content_field="content",
excluded_metadata_fields=["date_field"],
index="test_brownfield_support_2",
preprocessor=PreProcessor(split_length=1, split_respect_sentence_boundary=False),
)
transferred_documents = new_document_store.get_all_documents(index="test_brownfield_support_2")
assert all("name" in doc.meta for doc in transferred_documents)
# Check if number of transferred_documents is equal to number of unique words.
assert len(transferred_documents) == len(set(" ".join(original_content).split()))
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 is not supported")
@pytest.mark.unit
def test__init_elastic_client_aws4auth_and_username_raises_warning(
self, caplog, mocked_elastic_search_init, mocked_elastic_search_ping
):
_init_client_remaining_kwargs = {
"host": "host",
"port": 443,
"password": "pass",
"api_key_id": None,
"api_key": None,
"scheme": "https",
"ca_certs": None,
"verify_certs": True,
"timeout": 10,
"use_system_proxy": False,
}
with caplog.at_level(logging.WARN, logger="haystack.document_stores.elasticsearch"):
ElasticsearchDocumentStore._init_elastic_client(
username="admin", aws4auth="foo", **_init_client_remaining_kwargs
)
assert len(caplog.records) == 1
for r in caplog.records:
assert r.levelname == "WARNING"
caplog.clear()
with caplog.at_level(logging.WARN, logger="haystack.document_stores.elasticsearch"):
ElasticsearchDocumentStore._init_elastic_client(
username=None, aws4auth="foo", **_init_client_remaining_kwargs
)
ElasticsearchDocumentStore._init_elastic_client(
username="", aws4auth="foo", **_init_client_remaining_kwargs
)
assert len(caplog.records) == 0
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_get_document_by_id_return_embedding_false_es7(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.get_document_by_id("123")
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["_source"] == {"excludes": ["embedding"]}
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_get_document_by_id_return_embedding_false_es8(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.get_document_by_id("123")
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
assert kwargs["_source"] == {"excludes": ["embedding"]}
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_get_document_by_id_excluded_meta_data_has_no_influence_es7(self, mocked_document_store):
mocked_document_store.excluded_meta_data = ["foo"]
mocked_document_store.return_embedding = False
mocked_document_store.get_document_by_id("123")
# assert the resulting body is not affected by the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["_source"] == {"excludes": ["embedding"]}
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_get_document_by_id_excluded_meta_data_has_no_influence_es8(self, mocked_document_store):
mocked_document_store.excluded_meta_data = ["foo"]
mocked_document_store.return_embedding = False
mocked_document_store.get_document_by_id("123")
# assert the resulting body is not affected by the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
assert kwargs["_source"] == {"excludes": ["embedding"]}
@pytest.mark.unit
def test_write_documents_req_for_each_batch(self, mocked_document_store, documents):
mocked_document_store.batch_size = 2
with patch(f"{ElasticsearchDocumentStore.__module__}.bulk") as mocked_bulk:
mocked_document_store.write_documents(documents)
assert mocked_bulk.call_count == 5
@pytest.mark.unit
def test_get_vector_similarity_query(self, mocked_document_store):
"""
Test that the source field of the vector similarity query is correctly formatted for ES 7.6 and above.
We test this to make sure we use the correct syntax for newer ES versions.
"""
vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10)
assert vec_sim_query["script_score"]["script"]["source"] == "dotProduct(params.query_vector,'embedding') + 1000"
@pytest.mark.unit
def test_get_vector_similarity_query_es_7_5_and_below(self, mocked_document_store):
"""
Test that the source field of the vector similarity query is correctly formatter for ES 7.5 and below.
We test this to make sure we use the correct syntax for ES versions older than 7.6, as the syntax changed
in 7.6.
"""
# Patch server version to be 7.5.0
mocked_document_store.server_version = (7, 5, 0)
vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10)
assert (
vec_sim_query["script_score"]["script"]["source"]
== "dotProduct(params.query_vector,doc['embedding']) + 1000"
)
# The following tests are overridden only to be able to skip them depending on ES version
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_get_all_documents_return_embedding_true(self, mocked_document_store):
super().test_get_all_documents_return_embedding_true(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_get_all_documents_return_embedding_true_es8(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.client.options().search.return_value = {}
mocked_document_store.get_all_documents(return_embedding=True)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
assert "_source" not in kwargs
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_get_all_documents_return_embedding_false(self, mocked_document_store):
super().test_get_all_documents_return_embedding_false(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_get_all_documents_return_embedding_false_es8(self, mocked_document_store):
mocked_document_store.return_embedding = True
mocked_document_store.client.options().search.return_value = {}
mocked_document_store.get_all_documents(return_embedding=False)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
body = kwargs.get("body", kwargs)
assert body["_source"] == {"excludes": ["embedding"]}
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_get_all_documents_excluded_meta_data_has_no_influence(self, mocked_document_store):
super().test_get_all_documents_excluded_meta_data_has_no_influence(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_get_all_documents_excluded_meta_data_has_no_influence_es8(self, mocked_document_store):
mocked_document_store.excluded_meta_data = ["foo"]
mocked_document_store.client.options().search.return_value = {}
mocked_document_store.get_all_documents(return_embedding=False)
# assert the resulting body is not affected by the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
body = kwargs.get("body", kwargs)
assert body["_source"] == {"excludes": ["embedding"]}
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_query_return_embedding_true(self, mocked_document_store):
super().test_query_return_embedding_true(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_query_return_embedding_true_es8(self, mocked_document_store):
mocked_document_store.return_embedding = True
mocked_document_store.query(self.query)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
assert "_source" not in kwargs
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_query_return_embedding_false(self, mocked_document_store):
super().test_query_return_embedding_false(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_query_return_embedding_false_es8(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.query(self.query)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
assert kwargs["_source"] == {"excludes": ["embedding"]}
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_query_excluded_meta_data_return_embedding_true(self, mocked_document_store):
super().test_query_excluded_meta_data_return_embedding_true(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_query_excluded_meta_data_return_embedding_true_es8(self, mocked_document_store):
mocked_document_store.return_embedding = True
mocked_document_store.excluded_meta_data = ["foo", "embedding"]
mocked_document_store.query(self.query)
_, kwargs = mocked_document_store.client.options().search.call_args
# we expect "embedding" was removed from the final query
assert kwargs["_source"] == {"excludes": ["foo"]}
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_query_excluded_meta_data_return_embedding_false(self, mocked_document_store):
super().test_query_excluded_meta_data_return_embedding_false(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_query_excluded_meta_data_return_embedding_false_es8(self, mocked_document_store):
mocked_document_store.return_embedding = False
mocked_document_store.excluded_meta_data = ["foo"]
mocked_document_store.query(self.query)
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
assert kwargs["_source"] == {"excludes": ["foo", "embedding"]}
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
@pytest.mark.unit
def test_get_document_by_id_return_embedding_true(self, mocked_document_store):
super().test_get_document_by_id_return_embedding_true(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 7, reason="Elasticsearch 7 uses a different client call")
@pytest.mark.unit
def test_get_document_by_id_return_embedding_true_es8(self, mocked_document_store):
mocked_document_store.return_embedding = True
mocked_document_store.get_document_by_id("123")
# assert the resulting body is consistent with the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.options().search.call_args
assert "_source" not in kwargs