fix: authenticate with aws4auth if set in OpenSearchDocumentStore (#3741)

* bug(OpenSearchDocumentStore): fix authenticate with aws4auth if set.

Rearrange check to authenticate with aws4auth before username
and password, as the username is set to "admin" by default.

* Make username check less restrictive

* Fix test, do not used mocked _init_client function

* Add warning for aws4auth and username to ElasticSearchDocumentStore

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
Fabian 2023-01-24 10:01:39 +01:00 committed by GitHub
parent e954230ae7
commit 61ebe4b5dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 14 deletions

View File

@ -231,6 +231,10 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
elif aws4auth:
# aws elasticsearch with IAM
# see https://elasticsearch-py.readthedocs.io/en/v7.12.0/index.html?highlight=http_auth#running-on-aws-with-iam
if username:
logger.warning(
"aws4auth and a username are passed to the ElasticsearchDocumentStore. The username will be ignored and aws4auth will be used for authentication."
)
client = Elasticsearch(
hosts=hosts,
http_auth=aws4auth,

View File

@ -238,7 +238,22 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
if use_system_proxy:
connection_class = RequestsHttpConnection # type: ignore [assignment]
if username:
if aws4auth:
# Sign requests to Opensearch with IAM credentials
# see https://docs.aws.amazon.com/opensearch-service/latest/developerguide/request-signing.html#request-signing-python
if username:
logger.warning(
"aws4auth and a username or the default username 'admin' are passed to the OpenSearchDocumentStore. The username will be ignored and aws4auth will be used for authentication."
)
client = OpenSearch(
hosts=hosts,
http_auth=aws4auth,
connection_class=RequestsHttpConnection,
use_ssl=True,
verify_certs=True,
timeout=timeout,
)
elif username:
# standard http_auth
client = OpenSearch(
hosts=hosts,
@ -249,17 +264,6 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
timeout=timeout,
connection_class=connection_class,
)
elif aws4auth:
# Sign requests to Opensearch with IAM credentials
# see https://docs.aws.amazon.com/opensearch-service/latest/developerguide/request-signing.html#request-signing-python
client = OpenSearch(
hosts=hosts,
http_auth=aws4auth,
connection_class=RequestsHttpConnection,
use_ssl=True,
verify_certs=True,
timeout=timeout,
)
else:
# no authentication needed
client = OpenSearch(

View File

@ -1,11 +1,12 @@
import logging
import os
from unittest.mock import MagicMock
import pytest
import numpy as np
import pytest
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore, Elasticsearch
from haystack.schema import Document
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from .test_base import DocumentStoreBaseTestAbstract
from .test_search_engine import SearchEngineDocumentStoreTestAbstract
@ -32,6 +33,18 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
ds.delete_index(self.index_name)
ds.delete_index(labels_index_name)
@pytest.fixture
def mocked_elastic_search_init(self, monkeypatch):
mocked_init = MagicMock(return_value=None)
monkeypatch.setattr(Elasticsearch, "__init__", mocked_init)
return mocked_init
@pytest.fixture
def mocked_elastic_search_ping(self, monkeypatch):
mocked_ping = MagicMock(return_value=True)
monkeypatch.setattr(Elasticsearch, "ping", mocked_ping)
return mocked_ping
@pytest.fixture
def mocked_document_store(self):
"""
@ -239,3 +252,38 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
assert ds.get_document_count(only_documents_without_embedding=True) == 3
assert ds.get_document_count(only_documents_without_embedding=True, filters={"month": ["01"]}) == 0
assert ds.get_document_count(only_documents_without_embedding=True, filters={"month": ["03"]}) == 3
@pytest.mark.unit
def test__init_elastic_client_aws4auth_and_username_raises_warning(
self, caplog, mocked_elastic_search_init, mocked_elastic_search_ping
):
_init_client_remaining_kwargs = {
"host": "host",
"port": 443,
"password": "pass",
"api_key_id": None,
"api_key": None,
"scheme": "https",
"ca_certs": None,
"verify_certs": True,
"timeout": 10,
"use_system_proxy": False,
}
with caplog.at_level(logging.WARN, logger="haystack.document_stores.elasticsearch"):
ElasticsearchDocumentStore._init_elastic_client(
username="admin", aws4auth="foo", **_init_client_remaining_kwargs
)
assert len(caplog.records) == 1
for r in caplog.records:
assert r.levelname == "WARNING"
caplog.clear()
with caplog.at_level(logging.WARN, logger="haystack.document_stores.elasticsearch"):
ElasticsearchDocumentStore._init_elastic_client(
username=None, aws4auth="foo", **_init_client_remaining_kwargs
)
ElasticsearchDocumentStore._init_elastic_client(
username="", aws4auth="foo", **_init_client_remaining_kwargs
)
assert len(caplog.records) == 0

View File

@ -199,6 +199,32 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
for r in caplog.records:
assert r.levelname == "WARNING"
@pytest.mark.unit
def test__init_client_aws4auth_and_username_raises_warning(self, mocked_open_search_init, caplog):
_init_client_remaining_kwargs = {
"host": "host",
"port": 443,
"password": "pass",
"scheme": "https",
"ca_certs": None,
"verify_certs": True,
"timeout": 10,
"use_system_proxy": False,
}
with caplog.at_level(logging.WARN, logger="haystack.document_stores.opensearch"):
OpenSearchDocumentStore._init_client(username="admin", aws4auth="foo", **_init_client_remaining_kwargs)
OpenSearchDocumentStore._init_client(username="bar", aws4auth="foo", **_init_client_remaining_kwargs)
assert len(caplog.records) == 2
for r in caplog.records:
assert r.levelname == "WARNING"
caplog.clear()
with caplog.at_level(logging.WARN, logger="haystack.document_stores.opensearch"):
OpenSearchDocumentStore._init_client(username=None, aws4auth="foo", **_init_client_remaining_kwargs)
OpenSearchDocumentStore._init_client(username="foo", aws4auth=None, **_init_client_remaining_kwargs)
assert len(caplog.records) == 0
@pytest.mark.unit
def test___init___connection_test_fails(self, mocked_document_store):
failing_client = MagicMock()