diff --git a/haystack/document_stores/elasticsearch/base.py b/haystack/document_stores/elasticsearch/base.py index 113cf1f23..44edc085b 100644 --- a/haystack/document_stores/elasticsearch/base.py +++ b/haystack/document_stores/elasticsearch/base.py @@ -116,7 +116,7 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore): ) try: - result = self.client.search(index=index, **body, headers=headers)["hits"]["hits"] + result = self._search(index=index, **body, headers=headers)["hits"]["hits"] if len(result) == 0: count_documents = self.get_document_count(index=index, headers=headers) if count_documents == 0: @@ -197,7 +197,7 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore): } try: - self.client.indices.create(index=index_name, **mapping, headers=headers) + self._index_create(index=index_name, **mapping, headers=headers) except self._RequestError as e: # With multiple workers we need to avoid race conditions, where: # - there's no index in the beginning @@ -226,7 +226,7 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore): } } try: - self.client.indices.create(index=index_name, **mapping, headers=headers) + self._index_create(index=index_name, **mapping, headers=headers) except self._RequestError as e: # With multiple workers we need to avoid race conditions, where: # - there's no index in the beginning @@ -239,7 +239,7 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore): """ Validates an existing document index. If there's no embedding field, we'll add it. """ - indices = self.client.indices.get(index=index_name, headers=headers) + indices = self._index_get(index=index_name, headers=headers) if not any(indices): logger.warning( @@ -267,7 +267,7 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore): mapping["properties"][search_field] = ( {"type": "text", "analyzer": "synonym"} if self.synonyms else {"type": "text"} ) - self.client.indices.put_mapping(index=index_id, body=mapping, headers=headers) + self._index_put_mapping(index=index_id, body=mapping, headers=headers) if self.embedding_field: if ( @@ -280,7 +280,7 @@ class _ElasticsearchDocumentStore(SearchEngineDocumentStore): f"of type '{mapping['properties'][self.embedding_field]['type']}'." ) mapping["properties"][self.embedding_field] = {"type": "dense_vector", "dims": self.embedding_dim} - self.client.indices.put_mapping(index=index_id, body=mapping, headers=headers) + self._index_put_mapping(index=index_id, body=mapping, headers=headers) def _get_vector_similarity_query(self, query_emb: np.ndarray, top_k: int): """ diff --git a/haystack/document_stores/elasticsearch/es8.py b/haystack/document_stores/elasticsearch/es8.py index 438fb9346..852401408 100644 --- a/haystack/document_stores/elasticsearch/es8.py +++ b/haystack/document_stores/elasticsearch/es8.py @@ -1,7 +1,8 @@ import logging -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict, Any from haystack.lazy_imports import LazyImport +from haystack import Document with LazyImport("Run 'pip install farm-haystack[elasticsearch8]'") as es_import: from elasticsearch import Elasticsearch, RequestError @@ -294,3 +295,63 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): f"correct credentials if you are using a secured Elasticsearch instance." ) return client + + def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool: + if logger.isEnabledFor(logging.DEBUG): + if self.client.options(headers=headers).indices.exists_alias(name=index_name): + logger.debug("Index name %s is an alias.", index_name) + + return self.client.options(headers=headers).indices.exists(index=index_name) + + def _index_delete(self, index): + if self._index_exists(index): + self.client.options(ignore_status=[400, 404]).indices.delete(index=index) + logger.info("Index '%s' deleted.", index) + + def _index_refresh(self, index, headers): + if self._index_exists(index): + self.client.options(headers=headers).indices.refresh(index=index) + + def _index_create(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + return self.client.options(headers=headers).indices.create(*args, **kwargs) + + def _index_get(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + return self.client.options(headers=headers).indices.get(*args, **kwargs) + + def _index_put_mapping(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + body = kwargs.pop("body", {}) + return self.client.options(headers=headers).indices.put_mapping(*args, **kwargs, **body) + + def _search(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + return self.client.options(headers=headers).search(*args, **kwargs) + + def _update(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + return self.client.options(headers=headers).update(*args, **kwargs) + + def _count(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + body = kwargs.pop("body", {}) + return self.client.options(headers=headers).count(*args, **kwargs, **body) + + def _delete_by_query(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + ignore_status = kwargs.pop("ignore", []) + body = kwargs.pop("body", {}) + return self.client.options(headers=headers, ignore_status=ignore_status).delete_by_query( + *args, **kwargs, **body + ) + + def _execute_msearch(self, index: str, body: List[Dict[str, Any]], scale_score: bool) -> List[List[Document]]: + responses = self.client.msearch(index=index, body=body) + documents = [] + for response in responses["responses"]: + result = response["hits"]["hits"] + cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in result] + documents.append(cur_documents) + + return documents diff --git a/haystack/document_stores/search_engine.py b/haystack/document_stores/search_engine.py index 753cd8f61..6016237ab 100644 --- a/haystack/document_stores/search_engine.py +++ b/haystack/document_stores/search_engine.py @@ -1636,6 +1636,15 @@ class SearchEngineDocumentStore(KeywordDocumentStore): if self._index_exists(index): self.client.indices.refresh(index=index, headers=headers) + def _index_create(self, *args, **kwargs): + return self.client.indices.create(*args, **kwargs) + + def _index_get(self, *args, **kwargs): + return self.client.indices.get(*args, **kwargs) + + def _index_put_mapping(self, *args, **kwargs): + return self.client.indices.put_mapping(*args, **kwargs) + def _search(self, *args, **kwargs): return self.client.search(*args, **kwargs)