diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py index 3d453d41a..9d7ce2dfe 100644 --- a/haystack/document_stores/elasticsearch.py +++ b/haystack/document_stores/elasticsearch.py @@ -56,6 +56,7 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore): synonyms: Optional[List] = None, synonym_type: str = "synonym", use_system_proxy: bool = False, + batch_size: int = 10_000, ): """ A DocumentStore using Elasticsearch to store and query the documents for our search. @@ -127,6 +128,8 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore): Synonym or Synonym_graph to handle synonyms, including multi-word synonyms correctly during the analysis process. More info at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-synonym-graph-tokenfilter.html :param use_system_proxy: Whether to use system proxy. + :param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face + memory issues, decrease the batch_size. """ # Base constructor might need the client to be ready, create it first @@ -167,6 +170,7 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore): skip_missing_embeddings=skip_missing_embeddings, synonyms=synonyms, synonym_type=synonym_type, + batch_size=batch_size, ) # Let the base class trap the right exception from the elasticpy client diff --git a/haystack/document_stores/opensearch.py b/haystack/document_stores/opensearch.py index a3cd45261..6978438d8 100644 --- a/haystack/document_stores/opensearch.py +++ b/haystack/document_stores/opensearch.py @@ -74,6 +74,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): knn_engine: str = "nmslib", knn_parameters: Optional[Dict] = None, ivf_train_size: Optional[int] = None, + batch_size: int = 10_000, ): """ Document Store using OpenSearch (https://opensearch.org/). It is compatible with the Amazon OpenSearch Service. @@ -165,6 +166,8 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): index type and knn parameters). If `0`, training doesn't happen automatically but needs to be triggered manually via the `train_index` method. Default: `None` + :param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face + memory issues, decrease the batch_size. """ # These parameters aren't used by Opensearch at the moment but could be in the future, see # https://github.com/opensearch-project/security/issues/1504. Let's not deprecate them for @@ -243,6 +246,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): skip_missing_embeddings=skip_missing_embeddings, synonyms=synonyms, synonym_type=synonym_type, + batch_size=batch_size, ) # Let the base class catch the right error from the Opensearch client @@ -321,7 +325,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, duplicate_documents: Optional[str] = None, headers: Optional[Dict[str, str]] = None, ): @@ -358,6 +362,8 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): if index is None: index = self.index + batch_size = batch_size or self.batch_size + if self.knn_engine == "faiss" and self.similarity == "cosine": field_map = self._create_document_field_map() documents = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents] @@ -529,6 +535,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): return_embedding: Optional[bool] = None, headers: Optional[Dict[str, str]] = None, scale_score: bool = True, + batch_size: Optional[int] = None, ) -> List[List[Document]]: """ Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric. @@ -605,17 +612,19 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore): Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. - Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. - :return: + Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used. + :param batch_size: Number of query embeddings to process at once. If not specified, self.batch_size is used. """ if index is None: index = self.index + batch_size = batch_size or self.batch_size + if self.index_type in ["ivf", "ivf_pq"] and not self._ivf_model_exists(index=index): self._ivf_index_not_trained_error(index=index, headers=headers) return super().query_by_embedding_batch( - query_embs, filters, top_k, index, return_embedding, headers, scale_score + query_embs, filters, top_k, index, return_embedding, headers, scale_score, batch_size ) def query( diff --git a/haystack/document_stores/search_engine.py b/haystack/document_stores/search_engine.py index d4acf0507..87edbb500 100644 --- a/haystack/document_stores/search_engine.py +++ b/haystack/document_stores/search_engine.py @@ -70,6 +70,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): skip_missing_embeddings: bool = True, synonyms: Optional[List] = None, synonym_type: str = "synonym", + batch_size: int = 10_000, ): super().__init__() @@ -98,6 +99,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): self.skip_missing_embeddings: bool = skip_missing_embeddings self.duplicate_documents = duplicate_documents self.refresh_type = refresh_type + self.batch_size = batch_size if similarity in ["cosine", "dot_product", "l2"]: self.similarity: str = similarity else: @@ -367,7 +369,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, duplicate_documents: Optional[str] = None, headers: Optional[Dict[str, str]] = None, ): @@ -390,6 +392,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): to what you have set for self.content_field and self.name_field. :param index: search index where the documents should be indexed. If you don't specify it, self.index is used. :param batch_size: Number of documents that are passed to the bulk function at each round. + If not specified, self.batch_size is used. :param duplicate_documents: Handle duplicate documents based on parameter options. Parameter options: ( 'skip','overwrite','fail') skip: Ignore the duplicate documents @@ -407,6 +410,9 @@ class SearchEngineDocumentStore(KeywordDocumentStore): if index is None: index = self.index + + batch_size = batch_size or self.batch_size + duplicate_documents = duplicate_documents or self.duplicate_documents assert ( duplicate_documents in self.duplicate_documents_options @@ -923,9 +929,10 @@ class SearchEngineDocumentStore(KeywordDocumentStore): headers: Optional[Dict[str, str]] = None, all_terms_must_match: bool = False, scale_score: bool = True, + batch_size: Optional[int] = None, ) -> List[List[Document]]: """ - Scan through documents in DocumentStore and return a small number documents + Scan through documents in DocumentStore and return a small number of documents that are most relevant to the provided queries as defined by keyword matching algorithms like BM25. This method lets you find relevant documents for list of query strings (output: List of Lists of Documents). @@ -1005,17 +1012,19 @@ class SearchEngineDocumentStore(KeywordDocumentStore): :param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication) :param all_terms_must_match: Whether all terms of the query must match the document. If true all query terms must be present in a document in order to be retrieved (i.e the AND operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy AND fish AND restaurant"). - Otherwise at least one query term must be present in a document in order to be retrieved (i.e the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant"). + Otherwise, at least one query term must be present in a document in order to be retrieved (i.e the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant"). Defaults to False. :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. - Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used. + :param batch_size: Number of queries that are processed at once. If not specified, self.batch_size is used. """ if index is None: index = self.index if headers is None: headers = {} + batch_size = batch_size or self.batch_size if isinstance(filters, list): if len(filters) != len(queries): @@ -1027,6 +1036,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): filters = [filters] * len(queries) body = [] + all_documents = [] for query, cur_filters in zip(queries, filters): cur_query_body = self._construct_query_body( query=query, @@ -1038,17 +1048,27 @@ class SearchEngineDocumentStore(KeywordDocumentStore): body.append(headers) body.append(cur_query_body) - responses = self.client.msearch(index=index, body=body) + if len(body) == 2 * batch_size: + cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score) + all_documents.extend(cur_documents) + body = [] - all_documents = [] - cur_documents = [] - for response in responses["responses"]: - cur_result = response["hits"]["hits"] - cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in cur_result] - all_documents.append(cur_documents) + if len(body) > 0: + cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score) + all_documents.extend(cur_documents) return all_documents + def _execute_msearch(self, index: str, body: List[Dict[str, Any]], scale_score: bool) -> List[List[Document]]: + responses = self.client.msearch(index=index, body=body) + documents = [] + for response in responses["responses"]: + result = response["hits"]["hits"] + cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in result] + documents.append(cur_documents) + + return documents + def _construct_query_body( self, query: Optional[str], @@ -1188,6 +1208,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): return_embedding: Optional[bool] = None, headers: Optional[Dict[str, str]] = None, scale_score: bool = True, + batch_size: Optional[int] = None, ) -> List[List[Document]]: """ Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric. @@ -1264,8 +1285,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore): Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. - Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. - :return: + Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used. + :param batch_size: Number of query embeddings to process at once. If not specified, self.batch_size is used. """ if index is None: index = self.index @@ -1276,6 +1297,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore): if headers is None: headers = {} + batch_size = batch_size or self.batch_size + if not self.embedding_field: raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore") @@ -1289,25 +1312,24 @@ class SearchEngineDocumentStore(KeywordDocumentStore): filters = [filters] * len(query_embs) if filters is not None else [{}] * len(query_embs) body = [] - for query_emb, cur_filters in zip(query_embs, filters): + all_documents = [] + for query_emb, cur_filters in tqdm(zip(query_embs, filters)): cur_query_body = self._construct_dense_query_body( query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding ) body.append(headers) body.append(cur_query_body) - logger.debug("Retriever query: %s", body) - responses = self.client.msearch(index=index, body=body) + if len(body) >= batch_size * 2: + logger.debug("Retriever query: %s", body) + cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score) + all_documents.extend(cur_documents) + body = [] - all_documents = [] - cur_documents = [] - for response in responses["responses"]: - cur_result = response["hits"]["hits"] - cur_documents = [ - self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score) - for hit in cur_result - ] - all_documents.append(cur_documents) + if len(body) > 0: + logger.debug("Retriever query: %s", body) + cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score) + all_documents.extend(cur_documents) return all_documents @@ -1323,7 +1345,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore): index: Optional[str] = None, filters: Optional[FilterType] = None, update_existing_embeddings: bool = True, - batch_size: int = 10_000, + batch_size: Optional[int] = None, headers: Optional[Dict[str, str]] = None, ): """ @@ -1370,6 +1392,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore): if index is None: index = self.index + batch_size = batch_size or self.batch_size + if self.refresh_type == "false": self.client.indices.refresh(index=index, headers=headers) diff --git a/test/document_stores/test_elasticsearch.py b/test/document_stores/test_elasticsearch.py index 2398d4758..27b2588dc 100644 --- a/test/document_stores/test_elasticsearch.py +++ b/test/document_stores/test_elasticsearch.py @@ -1,6 +1,6 @@ import logging import os -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -344,3 +344,10 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine # assert the resulting body is not affected by the `excluded_meta_data` value _, kwargs = mocked_document_store.client.search.call_args assert kwargs["_source"] == {"excludes": ["embedding"]} + + @pytest.mark.unit + def test_write_documents_req_for_each_batch(self, mocked_document_store, documents): + mocked_document_store.batch_size = 2 + with patch("haystack.document_stores.elasticsearch.bulk") as mocked_bulk: + mocked_document_store.write_documents(documents) + assert mocked_bulk.call_count == 5 diff --git a/test/document_stores/test_opensearch.py b/test/document_stores/test_opensearch.py index 47bc06c31..628a7f26b 100644 --- a/test/document_stores/test_opensearch.py +++ b/test/document_stores/test_opensearch.py @@ -1291,3 +1291,10 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc # assert the resulting body is not affected by the `excluded_meta_data` value _, kwargs = mocked_document_store.client.search.call_args assert kwargs["body"]["_source"] == {"excludes": ["embedding"]} + + @pytest.mark.unit + def test_write_documents_req_for_each_batch(self, mocked_document_store, documents): + mocked_document_store.batch_size = 2 + with patch("haystack.document_stores.opensearch.bulk") as mocked_bulk: + mocked_document_store.write_documents(documents) + assert mocked_bulk.call_count == 5 diff --git a/test/document_stores/test_search_engine.py b/test/document_stores/test_search_engine.py index e83adad24..a99d03c44 100644 --- a/test/document_stores/test_search_engine.py +++ b/test/document_stores/test_search_engine.py @@ -1,4 +1,6 @@ from unittest.mock import MagicMock + +import numpy as np import pytest from haystack.document_stores.search_engine import SearchEngineDocumentStore, prepare_hosts @@ -167,6 +169,18 @@ class SearchEngineDocumentStoreTestAbstract: labels = mocked_document_store.get_all_labels() assert labels[0].answer.document_ids == ["fc18c987a8312e72a47fb1524f230bb0"] + @pytest.mark.unit + def test_query_batch_req_for_each_batch(self, mocked_document_store): + mocked_document_store.batch_size = 2 + mocked_document_store.query_batch([self.query] * 3) + assert mocked_document_store.client.msearch.call_count == 2 + + @pytest.mark.unit + def test_query_by_embedding_batch_req_for_each_batch(self, mocked_document_store): + mocked_document_store.batch_size = 2 + mocked_document_store.query_by_embedding_batch([np.array([1, 2, 3])] * 3) + assert mocked_document_store.client.msearch.call_count == 2 + @pytest.mark.document_store class TestSearchEngineDocumentStore: