refactor: add a new Document Store supporting Elasticsearch 8 (#5231)

* introduce es8

* prepare tests

* fix unit tests

* adjust tests

* install elastic_transport package

* make mypy happy

* fix opensearch tests
This commit is contained in:
Massimiliano Pippi 2023-06-29 16:40:10 +02:00 committed by GitHub
parent d5c13aa71d
commit 037e4f24ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 335 additions and 5 deletions

View File

@ -1 +1,12 @@
from .es7 import ElasticsearchDocumentStore
try:
# Use appropriate ElasticsearchDocumentStore depending on ES client version
from elasticsearch import VERSION
if VERSION[0] == 8:
from .es8 import ElasticsearchDocumentStore # type: ignore
else:
from .es7 import ElasticsearchDocumentStore # type: ignore
except (ModuleNotFoundError, ImportError):
# Import ES 7 as default if ES is not installed to raise the error message that elasticsearch extra is needed
from .es7 import ElasticsearchDocumentStore # type: ignore

View File

@ -0,0 +1,298 @@
import logging
from typing import List, Optional, Union
from haystack.lazy_imports import LazyImport
with LazyImport("Run 'pip install farm-haystack[elasticsearch8]'") as es_import:
from elasticsearch import Elasticsearch, RequestError
from elasticsearch.helpers import bulk, scan
from elastic_transport import RequestsHttpNode, Urllib3HttpNode
from .base import _ElasticsearchDocumentStore
logger = logging.getLogger(__name__)
def _prepare_hosts(host: Union[str, List[str]], port: Union[int, List[int]], scheme: str):
"""
Create a list of host(s), port(s) and scheme to allow direct client connections to multiple nodes,
in the format expected by the client.
"""
if isinstance(host, list):
if isinstance(port, list):
if not len(port) == len(host):
raise ValueError("Length of list `host` must match length of list `port`")
hosts = [{"host": h, "port": p, "scheme": scheme} for h, p in zip(host, port)]
else:
hosts = [{"host": h, "port": port, "scheme": scheme} for h in host]
else:
hosts = [{"host": host, "port": port, "scheme": scheme}]
return hosts
class ElasticsearchDocumentStore(_ElasticsearchDocumentStore):
def __init__(
self,
host: Union[str, List[str]] = "localhost",
port: Union[int, List[int]] = 9200,
username: str = "",
password: str = "",
api_key_id: Optional[str] = None,
api_key: Optional[str] = None,
aws4auth=None,
index: str = "document",
label_index: str = "label",
search_fields: Union[str, list] = "content",
content_field: str = "content",
name_field: str = "name",
embedding_field: str = "embedding",
embedding_dim: int = 768,
custom_mapping: Optional[dict] = None,
excluded_meta_data: Optional[list] = None,
analyzer: str = "standard",
scheme: str = "http",
ca_certs: Optional[str] = None,
verify_certs: bool = True,
recreate_index: bool = False,
create_index: bool = True,
refresh_type: str = "wait_for",
similarity: str = "dot_product",
timeout: int = 300,
return_embedding: bool = False,
duplicate_documents: str = "overwrite",
scroll: str = "1d",
skip_missing_embeddings: bool = True,
synonyms: Optional[List] = None,
synonym_type: str = "synonym",
use_system_proxy: bool = False,
batch_size: int = 10_000,
):
"""
A DocumentStore using Elasticsearch to store and query the documents for our search.
* Keeps all the logic to store and query documents from Elastic, incl. mapping of fields, adding filters or boosts to your queries, and storing embeddings
* You can either use an existing Elasticsearch index or create a new one via haystack
* Retrievers operate on top of this DocumentStore to find the relevant documents for a query
:param host: url(s) of elasticsearch nodes
:param port: port(s) of elasticsearch nodes
:param username: username (standard authentication via http_auth)
:param password: password (standard authentication via http_auth)
:param api_key_id: ID of the API key (altenative authentication mode to the above http_auth)
:param api_key: Secret value of the API key (altenative authentication mode to the above http_auth)
:param aws4auth: Authentication for usage with aws elasticsearch (can be generated with the requests-aws4auth package)
:param index: Name of index in elasticsearch to use for storing the documents that we want to search. If not existing yet, we will create one.
:param label_index: Name of index in elasticsearch to use for storing labels. If not existing yet, we will create one.
:param search_fields: Name of fields used by BM25Retriever to find matches in the docs to our incoming query (using elastic's multi_match query), e.g. ["title", "full_text"]
:param content_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text").
If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned.
:param name_field: Name of field that contains the title of the the doc
:param embedding_field: Name of field containing an embedding vector (Only needed when using a dense retriever (e.g. DensePassageRetriever, EmbeddingRetriever) on top)
:param embedding_dim: Dimensionality of embedding vector (Only needed when using a dense retriever (e.g. DensePassageRetriever, EmbeddingRetriever) on top)
:param custom_mapping: If you want to use your own custom mapping for creating a new index in Elasticsearch, you can supply it here as a dictionary.
:param analyzer: Specify the default analyzer from one of the built-ins when creating a new Elasticsearch Index.
Elasticsearch also has built-in analyzers for different languages (e.g. impacting tokenization). More info at:
https://www.elastic.co/guide/en/elasticsearch/reference/7.9/analysis-analyzers.html
:param excluded_meta_data: Name of fields in Elasticsearch that should not be returned (e.g. [field_one, field_two]).
Helpful if you have fields with long, irrelevant content that you don't want to display in results (e.g. embedding vectors).
:param scheme: 'https' or 'http', protocol used to connect to your elasticsearch instance
:param ca_certs: Root certificates for SSL: it is a path to certificate authority (CA) certs on disk. You can use certifi package with certifi.where() to find where the CA certs file is located in your machine.
:param verify_certs: Whether to be strict about ca certificates
:param recreate_index: If set to True, an existing elasticsearch index will be deleted and a new one will be
created using the config you are using for initialization. Be aware that all data in the old index will be
lost if you choose to recreate the index. Be aware that both the document_index and the label_index will
be recreated.
:param create_index:
Whether to try creating a new index (If the index of that name is already existing, we will just continue in any case)
..deprecated:: 2.0
This param is deprecated. In the next major version we will always try to create an index if there is no
existing index (the current behaviour when create_index=True). If you are looking to recreate an
existing index by deleting it first if it already exist use param recreate_index.
:param refresh_type: Type of ES refresh used to control when changes made by a request (e.g. bulk) are made visible to search.
If set to 'wait_for', continue only after changes are visible (slow, but safe).
If set to 'false', continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion).
More info at https://www.elastic.co/guide/en/elasticsearch/reference/6.8/docs-refresh.html
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
:param timeout: Number of seconds after which an ElasticSearch request times out.
:param return_embedding: To return document embedding
:param duplicate_documents: Handle duplicates document based on parameter options.
Parameter options : ( 'skip','overwrite','fail')
skip: Ignore the duplicates documents
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
:param scroll: Determines how long the current index is fixed, e.g. during updating all documents with embeddings.
Defaults to "1d" and should not be larger than this. Can also be in minutes "5m" or hours "15h"
For details, see https://www.elastic.co/guide/en/elasticsearch/reference/current/scroll-api.html
:param skip_missing_embeddings: Parameter to control queries based on vector similarity when indexed documents miss embeddings.
Parameter options: (True, False)
False: Raises exception if one or more documents do not have embeddings at query time
True: Query will ignore all documents without embeddings (recommended if you concurrently index and query)
:param synonyms: List of synonyms can be passed while elasticsearch initialization.
For example: [ "foo, bar => baz",
"foozball , foosball" ]
More info at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-synonym-tokenfilter.html
:param synonym_type: Synonym filter type can be passed.
Synonym or Synonym_graph to handle synonyms, including multi-word synonyms correctly during the analysis process.
More info at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-synonym-graph-tokenfilter.html
:param use_system_proxy: Whether to use system proxy.
:param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face
memory issues, decrease the batch_size.
"""
es_import.check()
# Base constructor might need the client to be ready, create it first
client = self._init_elastic_client(
host=host,
port=port,
username=username,
password=password,
api_key=api_key,
api_key_id=api_key_id,
scheme=scheme,
ca_certs=ca_certs,
verify_certs=verify_certs,
timeout=timeout,
use_system_proxy=use_system_proxy,
)
super().__init__(
client=client,
index=index,
label_index=label_index,
search_fields=search_fields,
content_field=content_field,
name_field=name_field,
embedding_field=embedding_field,
embedding_dim=embedding_dim,
custom_mapping=custom_mapping,
excluded_meta_data=excluded_meta_data,
analyzer=analyzer,
recreate_index=recreate_index,
create_index=create_index,
refresh_type=refresh_type,
similarity=similarity,
return_embedding=return_embedding,
duplicate_documents=duplicate_documents,
scroll=scroll,
skip_missing_embeddings=skip_missing_embeddings,
synonyms=synonyms,
synonym_type=synonym_type,
batch_size=batch_size,
)
# Let the base class trap the right exception from the elasticpy client
self._RequestError = RequestError
def _do_bulk(self, *args, **kwargs):
"""Override the base class method to use the Elasticsearch client"""
return bulk(*args, **kwargs)
def _do_scan(self, *args, **kwargs):
"""Override the base class method to use the Elasticsearch client"""
return scan(*args, **kwargs)
@classmethod
def _init_elastic_client(
cls,
host: Union[str, List[str]],
port: Union[int, List[int]],
username: str,
password: str,
api_key_id: Optional[str],
api_key: Optional[str],
scheme: str,
ca_certs: Optional[str],
verify_certs: bool,
timeout: int,
use_system_proxy: bool,
aws4auth: Optional[str] = "",
) -> Elasticsearch:
hosts = _prepare_hosts(host, port, scheme)
if aws4auth:
logger.warning("AWS authentication is not supported in Elasticsearch version 8 and later!")
if (api_key or api_key_id) and not (api_key and api_key_id):
raise ValueError("You must provide either both or none of `api_key_id` and `api_key`.")
node_class = RequestsHttpNode if use_system_proxy else Urllib3HttpNode
if api_key_id and api_key:
# api key authentication
if ca_certs is not None:
client = Elasticsearch(
hosts=hosts,
api_key=(api_key_id, api_key),
ca_certs=ca_certs,
verify_certs=verify_certs,
request_timeout=timeout,
node_class=node_class,
)
else:
client = Elasticsearch(
hosts=hosts,
api_key=(api_key_id, api_key),
verify_certs=verify_certs,
request_timeout=timeout,
node_class=node_class,
)
elif username:
# standard http_auth
if ca_certs is not None:
client = Elasticsearch(
hosts=hosts,
basic_auth=(username, password),
ca_certs=ca_certs,
verify_certs=verify_certs,
request_timeout=timeout,
node_class=node_class,
)
else:
client = Elasticsearch(
hosts=hosts,
basic_auth=(username, password),
verify_certs=verify_certs,
request_timeout=timeout,
node_class=node_class,
)
else:
# there is no authentication for this elasticsearch instance
if ca_certs is not None:
client = Elasticsearch(
hosts=hosts,
ca_certs=ca_certs,
verify_certs=verify_certs,
request_timeout=timeout,
node_class=node_class,
)
else:
client = Elasticsearch(
hosts=hosts,
basic_auth=(username, password),
verify_certs=verify_certs,
request_timeout=timeout,
node_class=node_class,
)
# Test connection
try:
# ping uses a HEAD request on the root URI. In some cases, the user might not have permissions for that,
# resulting in a HTTP Forbidden 403 response.
if username in ["", "elastic"]:
status = client.ping()
if not status:
raise ConnectionError(
f"Initial connection to Elasticsearch failed. Make sure you run an Elasticsearch instance "
f"at `{hosts}` and that it has finished the initial ramp up (can take > 30s). Also, make sure "
f"you are using the correct credentials if you are using a secured Elasticsearch instance."
)
except Exception:
raise ConnectionError(
f"Initial connection to Elasticsearch failed. Make sure you run an Elasticsearch instance at `{hosts}` "
f"and that it has finished the initial ramp up (can take > 30s). Also, make sure you are using the "
f"correct credentials if you are using a secured Elasticsearch instance."
)
return client

View File

@ -95,6 +95,7 @@ inference = [
]
elasticsearch = [
"elasticsearch>=7.17,<8",
"elastic_transport<8"
]
sql = [
"sqlalchemy>=1.4.2,<2",

View File

@ -4,9 +4,9 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from elasticsearch import Elasticsearch
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.elasticsearch.es7 import Elasticsearch
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore, VERSION
from haystack.document_stores.es_converter import elasticsearch_index_to_document_store
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes import PreProcessor
@ -294,6 +294,7 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
# Check if number of transferred_documents is equal to number of unique words.
assert len(transferred_documents) == len(set(" ".join(original_content).split()))
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 is not supported")
@pytest.mark.unit
def test__init_elastic_client_aws4auth_and_username_raises_warning(
self, caplog, mocked_elastic_search_init, mocked_elastic_search_ping
@ -329,6 +330,7 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
)
assert len(caplog.records) == 0
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 is not supported")
@pytest.mark.unit
def test_get_document_by_id_return_embedding_false(self, mocked_document_store):
mocked_document_store.return_embedding = False
@ -337,6 +339,7 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["_source"] == {"excludes": ["embedding"]}
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 is not supported")
@pytest.mark.unit
def test_get_document_by_id_excluded_meta_data_has_no_influence(self, mocked_document_store):
mocked_document_store.excluded_meta_data = ["foo"]
@ -349,6 +352,23 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
@pytest.mark.unit
def test_write_documents_req_for_each_batch(self, mocked_document_store, documents):
mocked_document_store.batch_size = 2
with patch("haystack.document_stores.elasticsearch.es7.bulk") as mocked_bulk:
with patch(f"{ElasticsearchDocumentStore.__module__}.bulk") as mocked_bulk:
mocked_document_store.write_documents(documents)
assert mocked_bulk.call_count == 5
# The following tests are overridden only to be able to skip them depending on ES version
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 is not supported")
@pytest.mark.unit
def test_get_all_documents_return_embedding_true(self, mocked_document_store):
super().test_get_all_documents_return_embedding_true(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 is not supported")
@pytest.mark.unit
def test_get_all_documents_return_embedding_false(self, mocked_document_store):
super().test_get_all_documents_return_embedding_false(mocked_document_store)
@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 is not supported")
@pytest.mark.unit
def test_get_all_documents_excluded_meta_data_has_no_influence(self, mocked_document_store):
super().test_get_all_documents_excluded_meta_data_has_no_influence(mocked_document_store)

View File

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
import numpy as np
import pytest
from haystack.document_stores.search_engine import SearchEngineDocumentStore, prepare_hosts
from haystack.document_stores.search_engine import SearchEngineDocumentStore
@pytest.mark.unit