mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-27 18:06:17 +00:00
feat: Check version of Elasticsearch server and add support for Elasticsearch <= 7.5 (#5320)
* Check ES server version + add support for ES <= 7.5 * Adapt comment * PR feedback
This commit is contained in:
parent
63fd63ff23
commit
237d67dbfd
@ -282,6 +282,18 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
||||
mapping["properties"][self.embedding_field] = {"type": "dense_vector", "dims": self.embedding_dim}
|
||||
self._index_put_mapping(index=index_id, body=mapping, headers=headers)
|
||||
|
||||
def _validate_server_version(self, expected_version: int):
|
||||
"""
|
||||
Validate that the Elasticsearch server version is compatible with the used ElasticsearchDocumentStore.
|
||||
"""
|
||||
if self.server_version[0] != expected_version:
|
||||
logger.warning(
|
||||
"This ElasticsearchDocumentStore has been built for Elasticsearch %s, but the detected version of the "
|
||||
"Elasticsearch server is %s. Unexpected behaviors or errors may occur due to version incompatibility.",
|
||||
expected_version,
|
||||
".".join(map(str, self.server_version)),
|
||||
)
|
||||
|
||||
def _get_vector_similarity_query(self, query_emb: np.ndarray, top_k: int):
|
||||
"""
|
||||
Generate Elasticsearch query for vector similarity.
|
||||
@ -302,12 +314,19 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
||||
if self.skip_missing_embeddings:
|
||||
script_score_query = {"bool": {"filter": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}}}
|
||||
|
||||
# Elasticsearch 7.6 introduced a breaking change regarding the vector function signatures:
|
||||
# https://www.elastic.co/guide/en/elasticsearch/reference/7.6/breaking-changes-7.6.html#_update_to_vector_function_signatures
|
||||
if self.server_version[0] == 7 and self.server_version[1] < 6:
|
||||
similarity_script_source = f"{similarity_fn_name}(params.query_vector,doc['{self.embedding_field}']) + 1000"
|
||||
else:
|
||||
similarity_script_source = f"{similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1000"
|
||||
|
||||
query = {
|
||||
"script_score": {
|
||||
"query": script_score_query,
|
||||
"script": {
|
||||
# offset score to ensure a positive range as required by Elasticsearch
|
||||
"source": f"{similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1000",
|
||||
"source": similarity_script_source,
|
||||
"params": {"query_vector": query_emb.tolist()},
|
||||
},
|
||||
}
|
||||
|
@ -178,6 +178,8 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore):
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
self._validate_server_version(expected_version=7)
|
||||
|
||||
def _do_bulk(self, *args, **kwargs):
|
||||
"""Override the base class method to use the Elasticsearch client"""
|
||||
return bulk(*args, **kwargs)
|
||||
|
@ -186,6 +186,8 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore):
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
self._validate_server_version(expected_version=8)
|
||||
|
||||
def _do_bulk(self, *args, **kwargs):
|
||||
"""Override the base class method to use the Elasticsearch client"""
|
||||
return bulk(*args, **kwargs)
|
||||
|
@ -106,6 +106,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
raise DocumentStoreError(
|
||||
f"Invalid value {similarity} for similarity, choose between 'cosine', 'l2' and 'dot_product'"
|
||||
)
|
||||
client_info = self.client.info()
|
||||
self.server_version = tuple(int(num) for num in client_info["version"]["number"].split("."))
|
||||
|
||||
self._init_indices(
|
||||
index=index, label_index=label_index, create_index=create_index, recreate_index=recreate_index
|
||||
|
@ -56,7 +56,13 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
|
||||
ElasticsearchDocumentStore equipped with a mocked client
|
||||
"""
|
||||
|
||||
with patch(f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"):
|
||||
with patch(
|
||||
f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"
|
||||
) as mocked_init_client:
|
||||
if VERSION[0] == 7:
|
||||
mocked_init_client().info.return_value = {"version": {"number": "7.17.6"}}
|
||||
else:
|
||||
mocked_init_client().info.return_value = {"version": {"number": "8.8.0"}}
|
||||
|
||||
class DSMock(ElasticsearchDocumentStore):
|
||||
# We mock a subclass to avoid messing up the actual class object
|
||||
@ -376,6 +382,31 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
|
||||
mocked_document_store.write_documents(documents)
|
||||
assert mocked_bulk.call_count == 5
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_vector_similarity_query(self, mocked_document_store):
|
||||
"""
|
||||
Test that the source field of the vector similarity query is correctly formatted for ES 7.6 and above.
|
||||
We test this to make sure we use the correct syntax for newer ES versions.
|
||||
"""
|
||||
vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10)
|
||||
assert vec_sim_query["script_score"]["script"]["source"] == "dotProduct(params.query_vector,'embedding') + 1000"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_vector_similarity_query_es_7_5_and_below(self, mocked_document_store):
|
||||
"""
|
||||
Test that the source field of the vector similarity query is correctly formatter for ES 7.5 and below.
|
||||
We test this to make sure we use the correct syntax for ES versions older than 7.6, as the syntax changed
|
||||
in 7.6.
|
||||
"""
|
||||
# Patch server version to be 7.5.0
|
||||
mocked_document_store.server_version = (7, 5, 0)
|
||||
|
||||
vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10)
|
||||
assert (
|
||||
vec_sim_query["script_score"]["script"]["source"]
|
||||
== "dotProduct(params.query_vector,doc['embedding']) + 1000"
|
||||
)
|
||||
|
||||
# The following tests are overridden only to be able to skip them depending on ES version
|
||||
|
||||
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
|
||||
|
@ -60,6 +60,7 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
|
||||
opensearch_mock = MagicMock()
|
||||
opensearch_mock.indices.exists.return_value = True
|
||||
opensearch_mock.indices.get.return_value = {self.index_name: existing_index}
|
||||
opensearch_mock.info.return_value = {"version": {"number": "1.3.5"}}
|
||||
DSMock._init_client = MagicMock()
|
||||
DSMock._init_client.configure_mock(return_value=opensearch_mock)
|
||||
dsMock = DSMock()
|
||||
|
Loading…
x
Reference in New Issue
Block a user