mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-13 16:43:44 +00:00
feat: Add batching for querying in ElasticsearchDocumentStore and OpenSearchDocumentStore (#5063)
* Include benchmark config in output * Use queries from aggregated labels * Introduce batching for querying in ElasticsearchDocStore and OpenSearchDocStore * Fix mypy * Use self.batch_size in write_documents * Use 10_000 as default batch size * Add unit tests for write documents
This commit is contained in:
parent
c3e59914da
commit
a9a49e2c0a
@ -56,6 +56,7 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
synonyms: Optional[List] = None,
|
synonyms: Optional[List] = None,
|
||||||
synonym_type: str = "synonym",
|
synonym_type: str = "synonym",
|
||||||
use_system_proxy: bool = False,
|
use_system_proxy: bool = False,
|
||||||
|
batch_size: int = 10_000,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
A DocumentStore using Elasticsearch to store and query the documents for our search.
|
A DocumentStore using Elasticsearch to store and query the documents for our search.
|
||||||
@ -127,6 +128,8 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
Synonym or Synonym_graph to handle synonyms, including multi-word synonyms correctly during the analysis process.
|
Synonym or Synonym_graph to handle synonyms, including multi-word synonyms correctly during the analysis process.
|
||||||
More info at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-synonym-graph-tokenfilter.html
|
More info at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-synonym-graph-tokenfilter.html
|
||||||
:param use_system_proxy: Whether to use system proxy.
|
:param use_system_proxy: Whether to use system proxy.
|
||||||
|
:param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face
|
||||||
|
memory issues, decrease the batch_size.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Base constructor might need the client to be ready, create it first
|
# Base constructor might need the client to be ready, create it first
|
||||||
@ -167,6 +170,7 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
skip_missing_embeddings=skip_missing_embeddings,
|
skip_missing_embeddings=skip_missing_embeddings,
|
||||||
synonyms=synonyms,
|
synonyms=synonyms,
|
||||||
synonym_type=synonym_type,
|
synonym_type=synonym_type,
|
||||||
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Let the base class trap the right exception from the elasticpy client
|
# Let the base class trap the right exception from the elasticpy client
|
||||||
|
|||||||
@ -74,6 +74,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
knn_engine: str = "nmslib",
|
knn_engine: str = "nmslib",
|
||||||
knn_parameters: Optional[Dict] = None,
|
knn_parameters: Optional[Dict] = None,
|
||||||
ivf_train_size: Optional[int] = None,
|
ivf_train_size: Optional[int] = None,
|
||||||
|
batch_size: int = 10_000,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Document Store using OpenSearch (https://opensearch.org/). It is compatible with the Amazon OpenSearch Service.
|
Document Store using OpenSearch (https://opensearch.org/). It is compatible with the Amazon OpenSearch Service.
|
||||||
@ -165,6 +166,8 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
index type and knn parameters). If `0`, training doesn't happen automatically but needs
|
index type and knn parameters). If `0`, training doesn't happen automatically but needs
|
||||||
to be triggered manually via the `train_index` method.
|
to be triggered manually via the `train_index` method.
|
||||||
Default: `None`
|
Default: `None`
|
||||||
|
:param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face
|
||||||
|
memory issues, decrease the batch_size.
|
||||||
"""
|
"""
|
||||||
# These parameters aren't used by Opensearch at the moment but could be in the future, see
|
# These parameters aren't used by Opensearch at the moment but could be in the future, see
|
||||||
# https://github.com/opensearch-project/security/issues/1504. Let's not deprecate them for
|
# https://github.com/opensearch-project/security/issues/1504. Let's not deprecate them for
|
||||||
@ -243,6 +246,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
skip_missing_embeddings=skip_missing_embeddings,
|
skip_missing_embeddings=skip_missing_embeddings,
|
||||||
synonyms=synonyms,
|
synonyms=synonyms,
|
||||||
synonym_type=synonym_type,
|
synonym_type=synonym_type,
|
||||||
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Let the base class catch the right error from the Opensearch client
|
# Let the base class catch the right error from the Opensearch client
|
||||||
@ -321,7 +325,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
self,
|
self,
|
||||||
documents: Union[List[dict], List[Document]],
|
documents: Union[List[dict], List[Document]],
|
||||||
index: Optional[str] = None,
|
index: Optional[str] = None,
|
||||||
batch_size: int = 10_000,
|
batch_size: Optional[int] = None,
|
||||||
duplicate_documents: Optional[str] = None,
|
duplicate_documents: Optional[str] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
):
|
):
|
||||||
@ -358,6 +362,8 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
if index is None:
|
if index is None:
|
||||||
index = self.index
|
index = self.index
|
||||||
|
|
||||||
|
batch_size = batch_size or self.batch_size
|
||||||
|
|
||||||
if self.knn_engine == "faiss" and self.similarity == "cosine":
|
if self.knn_engine == "faiss" and self.similarity == "cosine":
|
||||||
field_map = self._create_document_field_map()
|
field_map = self._create_document_field_map()
|
||||||
documents = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
|
documents = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
|
||||||
@ -529,6 +535,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
return_embedding: Optional[bool] = None,
|
return_embedding: Optional[bool] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
scale_score: bool = True,
|
scale_score: bool = True,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
) -> List[List[Document]]:
|
) -> List[List[Document]]:
|
||||||
"""
|
"""
|
||||||
Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.
|
Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.
|
||||||
@ -605,17 +612,19 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
|
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
|
||||||
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
||||||
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
||||||
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used.
|
||||||
:return:
|
:param batch_size: Number of query embeddings to process at once. If not specified, self.batch_size is used.
|
||||||
"""
|
"""
|
||||||
if index is None:
|
if index is None:
|
||||||
index = self.index
|
index = self.index
|
||||||
|
|
||||||
|
batch_size = batch_size or self.batch_size
|
||||||
|
|
||||||
if self.index_type in ["ivf", "ivf_pq"] and not self._ivf_model_exists(index=index):
|
if self.index_type in ["ivf", "ivf_pq"] and not self._ivf_model_exists(index=index):
|
||||||
self._ivf_index_not_trained_error(index=index, headers=headers)
|
self._ivf_index_not_trained_error(index=index, headers=headers)
|
||||||
|
|
||||||
return super().query_by_embedding_batch(
|
return super().query_by_embedding_batch(
|
||||||
query_embs, filters, top_k, index, return_embedding, headers, scale_score
|
query_embs, filters, top_k, index, return_embedding, headers, scale_score, batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
|
|||||||
@ -70,6 +70,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
skip_missing_embeddings: bool = True,
|
skip_missing_embeddings: bool = True,
|
||||||
synonyms: Optional[List] = None,
|
synonyms: Optional[List] = None,
|
||||||
synonym_type: str = "synonym",
|
synonym_type: str = "synonym",
|
||||||
|
batch_size: int = 10_000,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -98,6 +99,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
self.skip_missing_embeddings: bool = skip_missing_embeddings
|
self.skip_missing_embeddings: bool = skip_missing_embeddings
|
||||||
self.duplicate_documents = duplicate_documents
|
self.duplicate_documents = duplicate_documents
|
||||||
self.refresh_type = refresh_type
|
self.refresh_type = refresh_type
|
||||||
|
self.batch_size = batch_size
|
||||||
if similarity in ["cosine", "dot_product", "l2"]:
|
if similarity in ["cosine", "dot_product", "l2"]:
|
||||||
self.similarity: str = similarity
|
self.similarity: str = similarity
|
||||||
else:
|
else:
|
||||||
@ -367,7 +369,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
self,
|
self,
|
||||||
documents: Union[List[dict], List[Document]],
|
documents: Union[List[dict], List[Document]],
|
||||||
index: Optional[str] = None,
|
index: Optional[str] = None,
|
||||||
batch_size: int = 10_000,
|
batch_size: Optional[int] = None,
|
||||||
duplicate_documents: Optional[str] = None,
|
duplicate_documents: Optional[str] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
):
|
):
|
||||||
@ -390,6 +392,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
to what you have set for self.content_field and self.name_field.
|
to what you have set for self.content_field and self.name_field.
|
||||||
:param index: search index where the documents should be indexed. If you don't specify it, self.index is used.
|
:param index: search index where the documents should be indexed. If you don't specify it, self.index is used.
|
||||||
:param batch_size: Number of documents that are passed to the bulk function at each round.
|
:param batch_size: Number of documents that are passed to the bulk function at each round.
|
||||||
|
If not specified, self.batch_size is used.
|
||||||
:param duplicate_documents: Handle duplicate documents based on parameter options.
|
:param duplicate_documents: Handle duplicate documents based on parameter options.
|
||||||
Parameter options: ( 'skip','overwrite','fail')
|
Parameter options: ( 'skip','overwrite','fail')
|
||||||
skip: Ignore the duplicate documents
|
skip: Ignore the duplicate documents
|
||||||
@ -407,6 +410,9 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
index = self.index
|
index = self.index
|
||||||
|
|
||||||
|
batch_size = batch_size or self.batch_size
|
||||||
|
|
||||||
duplicate_documents = duplicate_documents or self.duplicate_documents
|
duplicate_documents = duplicate_documents or self.duplicate_documents
|
||||||
assert (
|
assert (
|
||||||
duplicate_documents in self.duplicate_documents_options
|
duplicate_documents in self.duplicate_documents_options
|
||||||
@ -923,9 +929,10 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
all_terms_must_match: bool = False,
|
all_terms_must_match: bool = False,
|
||||||
scale_score: bool = True,
|
scale_score: bool = True,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
) -> List[List[Document]]:
|
) -> List[List[Document]]:
|
||||||
"""
|
"""
|
||||||
Scan through documents in DocumentStore and return a small number documents
|
Scan through documents in DocumentStore and return a small number of documents
|
||||||
that are most relevant to the provided queries as defined by keyword matching algorithms like BM25.
|
that are most relevant to the provided queries as defined by keyword matching algorithms like BM25.
|
||||||
|
|
||||||
This method lets you find relevant documents for list of query strings (output: List of Lists of Documents).
|
This method lets you find relevant documents for list of query strings (output: List of Lists of Documents).
|
||||||
@ -1005,17 +1012,19 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
|
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
|
||||||
:param all_terms_must_match: Whether all terms of the query must match the document.
|
:param all_terms_must_match: Whether all terms of the query must match the document.
|
||||||
If true all query terms must be present in a document in order to be retrieved (i.e the AND operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy AND fish AND restaurant").
|
If true all query terms must be present in a document in order to be retrieved (i.e the AND operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy AND fish AND restaurant").
|
||||||
Otherwise at least one query term must be present in a document in order to be retrieved (i.e the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant").
|
Otherwise, at least one query term must be present in a document in order to be retrieved (i.e the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant").
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
||||||
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
||||||
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used.
|
||||||
|
:param batch_size: Number of queries that are processed at once. If not specified, self.batch_size is used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
index = self.index
|
index = self.index
|
||||||
if headers is None:
|
if headers is None:
|
||||||
headers = {}
|
headers = {}
|
||||||
|
batch_size = batch_size or self.batch_size
|
||||||
|
|
||||||
if isinstance(filters, list):
|
if isinstance(filters, list):
|
||||||
if len(filters) != len(queries):
|
if len(filters) != len(queries):
|
||||||
@ -1027,6 +1036,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
filters = [filters] * len(queries)
|
filters = [filters] * len(queries)
|
||||||
|
|
||||||
body = []
|
body = []
|
||||||
|
all_documents = []
|
||||||
for query, cur_filters in zip(queries, filters):
|
for query, cur_filters in zip(queries, filters):
|
||||||
cur_query_body = self._construct_query_body(
|
cur_query_body = self._construct_query_body(
|
||||||
query=query,
|
query=query,
|
||||||
@ -1038,17 +1048,27 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
body.append(headers)
|
body.append(headers)
|
||||||
body.append(cur_query_body)
|
body.append(cur_query_body)
|
||||||
|
|
||||||
responses = self.client.msearch(index=index, body=body)
|
if len(body) == 2 * batch_size:
|
||||||
|
cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score)
|
||||||
|
all_documents.extend(cur_documents)
|
||||||
|
body = []
|
||||||
|
|
||||||
all_documents = []
|
if len(body) > 0:
|
||||||
cur_documents = []
|
cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score)
|
||||||
for response in responses["responses"]:
|
all_documents.extend(cur_documents)
|
||||||
cur_result = response["hits"]["hits"]
|
|
||||||
cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in cur_result]
|
|
||||||
all_documents.append(cur_documents)
|
|
||||||
|
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
|
def _execute_msearch(self, index: str, body: List[Dict[str, Any]], scale_score: bool) -> List[List[Document]]:
|
||||||
|
responses = self.client.msearch(index=index, body=body)
|
||||||
|
documents = []
|
||||||
|
for response in responses["responses"]:
|
||||||
|
result = response["hits"]["hits"]
|
||||||
|
cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in result]
|
||||||
|
documents.append(cur_documents)
|
||||||
|
|
||||||
|
return documents
|
||||||
|
|
||||||
def _construct_query_body(
|
def _construct_query_body(
|
||||||
self,
|
self,
|
||||||
query: Optional[str],
|
query: Optional[str],
|
||||||
@ -1188,6 +1208,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
return_embedding: Optional[bool] = None,
|
return_embedding: Optional[bool] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
scale_score: bool = True,
|
scale_score: bool = True,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
) -> List[List[Document]]:
|
) -> List[List[Document]]:
|
||||||
"""
|
"""
|
||||||
Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.
|
Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.
|
||||||
@ -1264,8 +1285,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
|
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
|
||||||
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
||||||
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
||||||
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used.
|
||||||
:return:
|
:param batch_size: Number of query embeddings to process at once. If not specified, self.batch_size is used.
|
||||||
"""
|
"""
|
||||||
if index is None:
|
if index is None:
|
||||||
index = self.index
|
index = self.index
|
||||||
@ -1276,6 +1297,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
if headers is None:
|
if headers is None:
|
||||||
headers = {}
|
headers = {}
|
||||||
|
|
||||||
|
batch_size = batch_size or self.batch_size
|
||||||
|
|
||||||
if not self.embedding_field:
|
if not self.embedding_field:
|
||||||
raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore")
|
raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore")
|
||||||
|
|
||||||
@ -1289,25 +1312,24 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
filters = [filters] * len(query_embs) if filters is not None else [{}] * len(query_embs)
|
filters = [filters] * len(query_embs) if filters is not None else [{}] * len(query_embs)
|
||||||
|
|
||||||
body = []
|
body = []
|
||||||
for query_emb, cur_filters in zip(query_embs, filters):
|
all_documents = []
|
||||||
|
for query_emb, cur_filters in tqdm(zip(query_embs, filters)):
|
||||||
cur_query_body = self._construct_dense_query_body(
|
cur_query_body = self._construct_dense_query_body(
|
||||||
query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding
|
query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding
|
||||||
)
|
)
|
||||||
body.append(headers)
|
body.append(headers)
|
||||||
body.append(cur_query_body)
|
body.append(cur_query_body)
|
||||||
|
|
||||||
|
if len(body) >= batch_size * 2:
|
||||||
logger.debug("Retriever query: %s", body)
|
logger.debug("Retriever query: %s", body)
|
||||||
responses = self.client.msearch(index=index, body=body)
|
cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score)
|
||||||
|
all_documents.extend(cur_documents)
|
||||||
|
body = []
|
||||||
|
|
||||||
all_documents = []
|
if len(body) > 0:
|
||||||
cur_documents = []
|
logger.debug("Retriever query: %s", body)
|
||||||
for response in responses["responses"]:
|
cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score)
|
||||||
cur_result = response["hits"]["hits"]
|
all_documents.extend(cur_documents)
|
||||||
cur_documents = [
|
|
||||||
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
|
|
||||||
for hit in cur_result
|
|
||||||
]
|
|
||||||
all_documents.append(cur_documents)
|
|
||||||
|
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
@ -1323,7 +1345,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
index: Optional[str] = None,
|
index: Optional[str] = None,
|
||||||
filters: Optional[FilterType] = None,
|
filters: Optional[FilterType] = None,
|
||||||
update_existing_embeddings: bool = True,
|
update_existing_embeddings: bool = True,
|
||||||
batch_size: int = 10_000,
|
batch_size: Optional[int] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -1370,6 +1392,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
if index is None:
|
if index is None:
|
||||||
index = self.index
|
index = self.index
|
||||||
|
|
||||||
|
batch_size = batch_size or self.batch_size
|
||||||
|
|
||||||
if self.refresh_type == "false":
|
if self.refresh_type == "false":
|
||||||
self.client.indices.refresh(index=index, headers=headers)
|
self.client.indices.refresh(index=index, headers=headers)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -344,3 +344,10 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
|
|||||||
# assert the resulting body is not affected by the `excluded_meta_data` value
|
# assert the resulting body is not affected by the `excluded_meta_data` value
|
||||||
_, kwargs = mocked_document_store.client.search.call_args
|
_, kwargs = mocked_document_store.client.search.call_args
|
||||||
assert kwargs["_source"] == {"excludes": ["embedding"]}
|
assert kwargs["_source"] == {"excludes": ["embedding"]}
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_write_documents_req_for_each_batch(self, mocked_document_store, documents):
|
||||||
|
mocked_document_store.batch_size = 2
|
||||||
|
with patch("haystack.document_stores.elasticsearch.bulk") as mocked_bulk:
|
||||||
|
mocked_document_store.write_documents(documents)
|
||||||
|
assert mocked_bulk.call_count == 5
|
||||||
|
|||||||
@ -1291,3 +1291,10 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
|
|||||||
# assert the resulting body is not affected by the `excluded_meta_data` value
|
# assert the resulting body is not affected by the `excluded_meta_data` value
|
||||||
_, kwargs = mocked_document_store.client.search.call_args
|
_, kwargs = mocked_document_store.client.search.call_args
|
||||||
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
|
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_write_documents_req_for_each_batch(self, mocked_document_store, documents):
|
||||||
|
mocked_document_store.batch_size = 2
|
||||||
|
with patch("haystack.document_stores.opensearch.bulk") as mocked_bulk:
|
||||||
|
mocked_document_store.write_documents(documents)
|
||||||
|
assert mocked_bulk.call_count == 5
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from haystack.document_stores.search_engine import SearchEngineDocumentStore, prepare_hosts
|
from haystack.document_stores.search_engine import SearchEngineDocumentStore, prepare_hosts
|
||||||
|
|
||||||
@ -167,6 +169,18 @@ class SearchEngineDocumentStoreTestAbstract:
|
|||||||
labels = mocked_document_store.get_all_labels()
|
labels = mocked_document_store.get_all_labels()
|
||||||
assert labels[0].answer.document_ids == ["fc18c987a8312e72a47fb1524f230bb0"]
|
assert labels[0].answer.document_ids == ["fc18c987a8312e72a47fb1524f230bb0"]
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_query_batch_req_for_each_batch(self, mocked_document_store):
|
||||||
|
mocked_document_store.batch_size = 2
|
||||||
|
mocked_document_store.query_batch([self.query] * 3)
|
||||||
|
assert mocked_document_store.client.msearch.call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_query_by_embedding_batch_req_for_each_batch(self, mocked_document_store):
|
||||||
|
mocked_document_store.batch_size = 2
|
||||||
|
mocked_document_store.query_by_embedding_batch([np.array([1, 2, 3])] * 3)
|
||||||
|
assert mocked_document_store.client.msearch.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.document_store
|
@pytest.mark.document_store
|
||||||
class TestSearchEngineDocumentStore:
|
class TestSearchEngineDocumentStore:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user