From 237d67dbfd89e4253ece7e3ca2e6c65a293e68c7 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Thu, 13 Jul 2023 14:50:43 +0200 Subject: [PATCH] feat: Check version of Elasticsearch server and add support for Elasticsearch <= 7.5 (#5320) * Check ES server version + add support for ES <= 7.5 * Adapt comment * PR feedback --- .../document_stores/elasticsearch/base.py | 21 +++++++++++- haystack/document_stores/elasticsearch/es7.py | 2 ++ haystack/document_stores/elasticsearch/es8.py | 2 ++ haystack/document_stores/search_engine.py | 2 ++ test/document_stores/test_elasticsearch.py | 33 ++++++++++++++++++- test/document_stores/test_opensearch.py | 1 + 6 files changed, 59 insertions(+), 2 deletions(-) diff --git a/haystack/document_stores/elasticsearch/base.py b/haystack/document_stores/elasticsearch/base.py index 44edc085b..39d07a30c 100644 --- a/haystack/document_stores/elasticsearch/base.py +++ b/haystack/document_stores/elasticsearch/base.py @@ -282,6 +282,18 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore): mapping["properties"][self.embedding_field] = {"type": "dense_vector", "dims": self.embedding_dim} self._index_put_mapping(index=index_id, body=mapping, headers=headers) + def _validate_server_version(self, expected_version: int): + """ + Validate that the Elasticsearch server version is compatible with the used ElasticsearchDocumentStore. + """ + if self.server_version[0] != expected_version: + logger.warning( + "This ElasticsearchDocumentStore has been built for Elasticsearch %s, but the detected version of the " + "Elasticsearch server is %s. Unexpected behaviors or errors may occur due to version incompatibility.", + expected_version, + ".".join(map(str, self.server_version)), + ) + def _get_vector_similarity_query(self, query_emb: np.ndarray, top_k: int): """ Generate Elasticsearch query for vector similarity. 
@@ -302,12 +314,19 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore): if self.skip_missing_embeddings: script_score_query = {"bool": {"filter": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}}} + # Elasticsearch 7.6 introduced a breaking change regarding the vector function signatures: + # https://www.elastic.co/guide/en/elasticsearch/reference/7.6/breaking-changes-7.6.html#_update_to_vector_function_signatures + if self.server_version[0] == 7 and self.server_version[1] < 6: + similarity_script_source = f"{similarity_fn_name}(params.query_vector,doc['{self.embedding_field}']) + 1000" + else: + similarity_script_source = f"{similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1000" + query = { "script_score": { "query": script_score_query, "script": { # offset score to ensure a positive range as required by Elasticsearch - "source": f"{similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1000", + "source": similarity_script_source, "params": {"query_vector": query_emb.tolist()}, }, } diff --git a/haystack/document_stores/elasticsearch/es7.py b/haystack/document_stores/elasticsearch/es7.py index aa54c4350..970a0abc1 100644 --- a/haystack/document_stores/elasticsearch/es7.py +++ b/haystack/document_stores/elasticsearch/es7.py @@ -178,6 +178,8 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): batch_size=batch_size, ) + self._validate_server_version(expected_version=7) + def _do_bulk(self, *args, **kwargs): """Override the base class method to use the Elasticsearch client""" return bulk(*args, **kwargs) diff --git a/haystack/document_stores/elasticsearch/es8.py b/haystack/document_stores/elasticsearch/es8.py index 852401408..47c96ab17 100644 --- a/haystack/document_stores/elasticsearch/es8.py +++ b/haystack/document_stores/elasticsearch/es8.py @@ -186,6 +186,8 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): batch_size=batch_size, ) + 
self._validate_server_version(expected_version=8) + def _do_bulk(self, *args, **kwargs): """Override the base class method to use the Elasticsearch client""" return bulk(*args, **kwargs) diff --git a/haystack/document_stores/search_engine.py b/haystack/document_stores/search_engine.py index 878c31a4e..1aad3a00c 100644 --- a/haystack/document_stores/search_engine.py +++ b/haystack/document_stores/search_engine.py @@ -106,6 +106,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore): raise DocumentStoreError( f"Invalid value {similarity} for similarity, choose between 'cosine', 'l2' and 'dot_product'" ) + client_info = self.client.info() + self.server_version = tuple(int(num) for num in client_info["version"]["number"].split(".")) self._init_indices( index=index, label_index=label_index, create_index=create_index, recreate_index=recreate_index diff --git a/test/document_stores/test_elasticsearch.py b/test/document_stores/test_elasticsearch.py index 1cc82c919..cd73198a8 100644 --- a/test/document_stores/test_elasticsearch.py +++ b/test/document_stores/test_elasticsearch.py @@ -56,7 +56,13 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine ElasticsearchDocumentStore equipped with a mocked client """ - with patch(f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"): + with patch( + f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client" + ) as mocked_init_client: + if VERSION[0] == 7: + mocked_init_client().info.return_value = {"version": {"number": "7.17.6"}} + else: + mocked_init_client().info.return_value = {"version": {"number": "8.8.0"}} class DSMock(ElasticsearchDocumentStore): # We mock a subclass to avoid messing up the actual class object @@ -376,6 +382,31 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine mocked_document_store.write_documents(documents) assert mocked_bulk.call_count == 5 + @pytest.mark.unit + def 
test_get_vector_similarity_query(self, mocked_document_store): + """ + Test that the source field of the vector similarity query is correctly formatted for ES 7.6 and above. + We test this to make sure we use the correct syntax for newer ES versions. + """ + vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10) + assert vec_sim_query["script_score"]["script"]["source"] == "dotProduct(params.query_vector,'embedding') + 1000" + + @pytest.mark.unit + def test_get_vector_similarity_query_es_7_5_and_below(self, mocked_document_store): + """ + Test that the source field of the vector similarity query is correctly formatted for ES 7.5 and below. + We test this to make sure we use the correct syntax for ES versions older than 7.6, as the syntax changed + in 7.6. + """ + # Patch server version to be 7.5.0 + mocked_document_store.server_version = (7, 5, 0) + + vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10) + assert ( + vec_sim_query["script_score"]["script"]["source"] + == "dotProduct(params.query_vector,doc['embedding']) + 1000" + ) + # The following tests are overridden only to be able to skip them depending on ES version @pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call") diff --git a/test/document_stores/test_opensearch.py b/test/document_stores/test_opensearch.py index 628a7f26b..fb55c2417 100644 --- a/test/document_stores/test_opensearch.py +++ b/test/document_stores/test_opensearch.py @@ -60,6 +60,7 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc opensearch_mock = MagicMock() opensearch_mock.indices.exists.return_value = True opensearch_mock.indices.get.return_value = {self.index_name: existing_index} + opensearch_mock.info.return_value = {"version": {"number": "1.3.5"}} DSMock._init_client = MagicMock() DSMock._init_client.configure_mock(return_value=opensearch_mock) dsMock = 
DSMock()