diff --git a/haystack/document_stores/elasticsearch/es7.py b/haystack/document_stores/elasticsearch/es7.py index 0d66ae96a..455861332 100644 --- a/haystack/document_stores/elasticsearch/es7.py +++ b/haystack/document_stores/elasticsearch/es7.py @@ -125,10 +125,12 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): memory issues, decrease the batch_size. """ + # Ensure all the required inputs were successful es_import.check() - - # Base constructor might need the client to be ready, create it first - client = self._init_elastic_client( + # Let the base class trap the right exception from the specific client + self._RequestError = RequestError + # Initiate the Elasticsearch client for version 7.x + client = ElasticsearchDocumentStore._init_elastic_client( host=host, port=port, username=username, @@ -168,9 +170,6 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): batch_size=batch_size, ) - # Let the base class trap the right exception from the elasticpy client - self._RequestError = RequestError - def _do_bulk(self, *args, **kwargs): """Override the base class method to use the Elasticsearch client""" return bulk(*args, **kwargs) @@ -179,9 +178,8 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): """Override the base class method to use the Elasticsearch client""" return scan(*args, **kwargs) - @classmethod + @staticmethod def _init_elastic_client( - cls, host: Union[str, List[str]], port: Union[int, List[int]], username: str, diff --git a/haystack/document_stores/elasticsearch/es8.py b/haystack/document_stores/elasticsearch/es8.py index 5852da393..438fb9346 100644 --- a/haystack/document_stores/elasticsearch/es8.py +++ b/haystack/document_stores/elasticsearch/es8.py @@ -141,10 +141,12 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): memory issues, decrease the batch_size. """ + # Ensure all the required inputs were successful es_import.check() - - # Base constructor might need the client to be ready, create it first - client = self._init_elastic_client( + # Let the base class trap the right exception from the specific client + self._RequestError = RequestError + # Initiate the Elasticsearch client for version 8.x + client = ElasticsearchDocumentStore._init_elastic_client( host=host, port=port, username=username, @@ -183,9 +185,6 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): batch_size=batch_size, ) - # Let the base class trap the right exception from the elasticpy client - self._RequestError = RequestError - def _do_bulk(self, *args, **kwargs): """Override the base class method to use the Elasticsearch client""" return bulk(*args, **kwargs) @@ -194,9 +193,8 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore): """Override the base class method to use the Elasticsearch client""" return scan(*args, **kwargs) - @classmethod + @staticmethod def _init_elastic_client( - cls, host: Union[str, List[str]], port: Union[int, List[int]], username: str, diff --git a/test/document_stores/test_elasticsearch.py b/test/document_stores/test_elasticsearch.py index edf9a310c..9a0628cec 100644 --- a/test/document_stores/test_elasticsearch.py +++ b/test/document_stores/test_elasticsearch.py @@ -56,13 +56,14 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine ElasticsearchDocumentStore equipped with a mocked client """ - class DSMock(ElasticsearchDocumentStore): - # We mock a subclass to avoid messing up the actual class object - pass + with patch(f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"): - DSMock._init_elastic_client = MagicMock() - DSMock.client = MagicMock() - return DSMock() + class DSMock(ElasticsearchDocumentStore): + # We mock a subclass to avoid messing up the actual class object + pass + + DSMock.client = MagicMock() + yield DSMock() @pytest.mark.integration def test___init__(self):