diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py index 673ef6ce6..b24ea44b6 100644 --- a/haystack/document_stores/elasticsearch.py +++ b/haystack/document_stores/elasticsearch.py @@ -231,6 +231,10 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore): elif aws4auth: # aws elasticsearch with IAM # see https://elasticsearch-py.readthedocs.io/en/v7.12.0/index.html?highlight=http_auth#running-on-aws-with-iam + if username: + logger.warning( + "aws4auth and a username are passed to the ElasticsearchDocumentStore. The username will be ignored and aws4auth will be used for authentication." + ) client = Elasticsearch( hosts=hosts, http_auth=aws4auth, diff --git a/haystack/document_stores/opensearch.py b/haystack/document_stores/opensearch.py index f1099089c..6adf391cc 100644 --- a/haystack/document_stores/opensearch.py +++ b/haystack/document_stores/opensearch.py @@ -238,7 +238,22 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): if use_system_proxy: connection_class = RequestsHttpConnection # type: ignore [assignment] - if username: + if aws4auth: + # Sign requests to Opensearch with IAM credentials + # see https://docs.aws.amazon.com/opensearch-service/latest/developerguide/request-signing.html#request-signing-python + if username: + logger.warning( + "aws4auth and a username or the default username 'admin' are passed to the OpenSearchDocumentStore. The username will be ignored and aws4auth will be used for authentication." + ) + client = OpenSearch( + hosts=hosts, + http_auth=aws4auth, + connection_class=RequestsHttpConnection, + use_ssl=True, + verify_certs=True, + timeout=timeout, + ) + elif username: # standard http_auth client = OpenSearch( hosts=hosts, @@ -249,17 +264,6 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): timeout=timeout, connection_class=connection_class, ) - elif aws4auth: - # Sign requests to Opensearch with IAM credentials - # see https://docs.aws.amazon.com/opensearch-service/latest/developerguide/request-signing.html#request-signing-python - client = OpenSearch( - hosts=hosts, - http_auth=aws4auth, - connection_class=RequestsHttpConnection, - use_ssl=True, - verify_certs=True, - timeout=timeout, - ) else: # no authentication needed client = OpenSearch( diff --git a/test/document_stores/test_elasticsearch.py b/test/document_stores/test_elasticsearch.py index cca979eff..5b9ee1577 100644 --- a/test/document_stores/test_elasticsearch.py +++ b/test/document_stores/test_elasticsearch.py @@ -1,11 +1,12 @@ +import logging import os from unittest.mock import MagicMock -import pytest import numpy as np +import pytest +from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore, Elasticsearch from haystack.schema import Document -from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore from .test_base import DocumentStoreBaseTestAbstract from .test_search_engine import SearchEngineDocumentStoreTestAbstract @@ -32,6 +33,18 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine ds.delete_index(self.index_name) ds.delete_index(labels_index_name) + @pytest.fixture + def mocked_elastic_search_init(self, monkeypatch): + mocked_init = MagicMock(return_value=None) + monkeypatch.setattr(Elasticsearch, "__init__", mocked_init) + return mocked_init + + @pytest.fixture + def mocked_elastic_search_ping(self, monkeypatch): + mocked_ping = MagicMock(return_value=True) + monkeypatch.setattr(Elasticsearch, "ping", mocked_ping) + return mocked_ping + @pytest.fixture def mocked_document_store(self): """ @@ -239,3 +252,38 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine assert ds.get_document_count(only_documents_without_embedding=True) == 3 assert ds.get_document_count(only_documents_without_embedding=True, filters={"month": ["01"]}) == 0 assert ds.get_document_count(only_documents_without_embedding=True, filters={"month": ["03"]}) == 3 + + @pytest.mark.unit + def test__init_elastic_client_aws4auth_and_username_raises_warning( + self, caplog, mocked_elastic_search_init, mocked_elastic_search_ping + ): + _init_client_remaining_kwargs = { + "host": "host", + "port": 443, + "password": "pass", + "api_key_id": None, + "api_key": None, + "scheme": "https", + "ca_certs": None, + "verify_certs": True, + "timeout": 10, + "use_system_proxy": False, + } + + with caplog.at_level(logging.WARN, logger="haystack.document_stores.elasticsearch"): + ElasticsearchDocumentStore._init_elastic_client( + username="admin", aws4auth="foo", **_init_client_remaining_kwargs + ) + assert len(caplog.records) == 1 + for r in caplog.records: + assert r.levelname == "WARNING" + + caplog.clear() + with caplog.at_level(logging.WARN, logger="haystack.document_stores.elasticsearch"): + ElasticsearchDocumentStore._init_elastic_client( + username=None, aws4auth="foo", **_init_client_remaining_kwargs + ) + ElasticsearchDocumentStore._init_elastic_client( + username="", aws4auth="foo", **_init_client_remaining_kwargs + ) + assert len(caplog.records) == 0 diff --git a/test/document_stores/test_opensearch.py b/test/document_stores/test_opensearch.py index 53a1d97a5..f043536be 100644 --- a/test/document_stores/test_opensearch.py +++ b/test/document_stores/test_opensearch.py @@ -199,6 +199,32 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc for r in caplog.records: assert r.levelname == "WARNING" + @pytest.mark.unit + def test__init_client_aws4auth_and_username_raises_warning(self, mocked_open_search_init, caplog): + _init_client_remaining_kwargs = { + "host": "host", + "port": 443, + "password": "pass", + "scheme": "https", + "ca_certs": None, + "verify_certs": True, + "timeout": 10, + "use_system_proxy": False, + } + + with caplog.at_level(logging.WARN, logger="haystack.document_stores.opensearch"): + OpenSearchDocumentStore._init_client(username="admin", aws4auth="foo", **_init_client_remaining_kwargs) + OpenSearchDocumentStore._init_client(username="bar", aws4auth="foo", **_init_client_remaining_kwargs) + assert len(caplog.records) == 2 + for r in caplog.records: + assert r.levelname == "WARNING" + + caplog.clear() + with caplog.at_level(logging.WARN, logger="haystack.document_stores.opensearch"): + OpenSearchDocumentStore._init_client(username=None, aws4auth="foo", **_init_client_remaining_kwargs) + OpenSearchDocumentStore._init_client(username="foo", aws4auth=None, **_init_client_remaining_kwargs) + assert len(caplog.records) == 0 + @pytest.mark.unit def test___init___connection_test_fails(self, mocked_document_store): failing_client = MagicMock()