feat: Check version of Elasticsearch server and add support for Elasticsearch <= 7.5 (#5320)

* Check ES server version + add support for ES <= 7.5

* Adapt comment

* PR feedback
This commit is contained in:
bogdankostic 2023-07-13 14:50:43 +02:00 committed by GitHub
parent 63fd63ff23
commit 237d67dbfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 59 additions and 2 deletions

View File

@ -282,6 +282,18 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore):
mapping["properties"][self.embedding_field] = {"type": "dense_vector", "dims": self.embedding_dim}
self._index_put_mapping(index=index_id, body=mapping, headers=headers)
def _validate_server_version(self, expected_version: int):
    """
    Warn if the major version of the detected Elasticsearch server differs from the expected one.

    Only the first component of `self.server_version` is compared. A mismatch is reported through a
    log warning; no exception is raised.

    :param expected_version: Major Elasticsearch version this document store has been built for.
    """
    detected_major = self.server_version[0]
    if detected_major == expected_version:
        # Major versions match, nothing to report
        return

    detected_version = ".".join(str(part) for part in self.server_version)
    logger.warning(
        "This ElasticsearchDocumentStore has been built for Elasticsearch %s, but the detected version of the "
        "Elasticsearch server is %s. Unexpected behaviors or errors may occur due to version incompatibility.",
        expected_version,
        detected_version,
    )
def _get_vector_similarity_query(self, query_emb: np.ndarray, top_k: int):
"""
Generate Elasticsearch query for vector similarity.
@ -302,12 +314,19 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore):
if self.skip_missing_embeddings:
script_score_query = {"bool": {"filter": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}}}
# Elasticsearch 7.6 introduced a breaking change regarding the vector function signatures:
# https://www.elastic.co/guide/en/elasticsearch/reference/7.6/breaking-changes-7.6.html#_update_to_vector_function_signatures
if self.server_version[0] == 7 and self.server_version[1] < 6:
similarity_script_source = f"{similarity_fn_name}(params.query_vector,doc['{self.embedding_field}']) + 1000"
else:
similarity_script_source = f"{similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1000"
query = {
"script_score": {
"query": script_score_query,
"script": {
# offset score to ensure a positive range as required by Elasticsearch
"source": f"{similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1000",
"source": similarity_script_source,
"params": {"query_vector": query_emb.tolist()},
},
}

View File

@ -178,6 +178,8 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore):
batch_size=batch_size,
)
self._validate_server_version(expected_version=7)
def _do_bulk(self, *args, **kwargs):
    """
    Override the base class method to use the Elasticsearch client.

    All positional and keyword arguments are forwarded unchanged to `bulk`, and its result is returned.
    """
    bulk_result = bulk(*args, **kwargs)
    return bulk_result

View File

@ -186,6 +186,8 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore):
batch_size=batch_size,
)
self._validate_server_version(expected_version=8)
def _do_bulk(self, *args, **kwargs):
    """
    Override the base class method to use the Elasticsearch client.

    All positional and keyword arguments are forwarded unchanged to `bulk`, and its result is returned.
    """
    bulk_result = bulk(*args, **kwargs)
    return bulk_result

View File

@ -106,6 +106,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
raise DocumentStoreError(
f"Invalid value {similarity} for similarity, choose between 'cosine', 'l2' and 'dot_product'"
)
client_info = self.client.info()
self.server_version = tuple(int(num) for num in client_info["version"]["number"].split("."))
self._init_indices(
index=index, label_index=label_index, create_index=create_index, recreate_index=recreate_index

View File

@ -56,7 +56,13 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
ElasticsearchDocumentStore equipped with a mocked client
"""
with patch(f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"):
with patch(
f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"
) as mocked_init_client:
if VERSION[0] == 7:
mocked_init_client().info.return_value = {"version": {"number": "7.17.6"}}
else:
mocked_init_client().info.return_value = {"version": {"number": "8.8.0"}}
class DSMock(ElasticsearchDocumentStore):
# We mock a subclass to avoid messing up the actual class object
@ -376,6 +382,31 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
mocked_document_store.write_documents(documents)
assert mocked_bulk.call_count == 5
@pytest.mark.unit
def test_get_vector_similarity_query(self, mocked_document_store):
    """
    Check the script source of the vector similarity query for ES 7.6 and above.

    Newer ES versions expect the embedding field to be passed by name, so we make sure this syntax is used.
    """
    query_emb = np.random.rand(3).astype(np.float32)

    query = mocked_document_store._get_vector_similarity_query(query_emb, 10)

    script_source = query["script_score"]["script"]["source"]
    assert script_source == "dotProduct(params.query_vector,'embedding') + 1000"
@pytest.mark.unit
def test_get_vector_similarity_query_es_7_5_and_below(self, mocked_document_store):
    """
    Check the script source of the vector similarity query for ES 7.5 and below.

    The syntax of the vector functions changed in ES 7.6, so we make sure that servers older than 7.6
    still get the previous syntax.
    """
    # Make the document store believe it is connected to an ES 7.5.0 server
    mocked_document_store.server_version = (7, 5, 0)
    query_emb = np.random.rand(3).astype(np.float32)

    query = mocked_document_store._get_vector_similarity_query(query_emb, 10)

    expected_source = "dotProduct(params.query_vector,doc['embedding']) + 1000"
    assert query["script_score"]["script"]["source"] == expected_source
# The following tests are overridden only to be able to skip them depending on ES version
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")

View File

@ -60,6 +60,7 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
opensearch_mock = MagicMock()
opensearch_mock.indices.exists.return_value = True
opensearch_mock.indices.get.return_value = {self.index_name: existing_index}
opensearch_mock.info.return_value = {"version": {"number": "1.3.5"}}
DSMock._init_client = MagicMock()
DSMock._init_client.configure_mock(return_value=opensearch_mock)
dsMock = DSMock()