mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-26 23:38:58 +00:00
feat: add query_by_embedding_batch (#3546)
* add query_by_embedding_batch
* fix mypy
* fix pylint
* add test
* move query_by_embedding_batch to search_engine
* fix and add tests
* fix pylint
* remove Retriever query logs
* add test for multimodal batch retrieval
* allow for np.ndarray
This commit is contained in:
parent
25bf95d47f
commit
c1c1c97bb2
@ -12,7 +12,7 @@ import numpy as np
|
||||
|
||||
from haystack.schema import Document, Label, MultiLabel
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError
|
||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError, HaystackError
|
||||
from haystack.nodes.preprocessor import PreProcessor
|
||||
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
|
||||
from haystack.utils.labels import aggregate_labels
|
||||
@ -359,6 +359,44 @@ class BaseDocumentStore(BaseComponent):
|
||||
) -> List[Document]:
|
||||
pass
|
||||
|
||||
def query_by_embedding_batch(
    self,
    query_embs: Union[List[np.ndarray], np.ndarray],
    filters: Optional[
        Union[
            Dict[str, Union[Dict, List, str, int, float, bool]],
            List[Dict[str, Union[Dict, List, str, int, float, bool]]],
        ]
    ] = None,
    top_k: int = 10,
    index: Optional[str] = None,
    return_embedding: Optional[bool] = None,
    headers: Optional[Dict[str, str]] = None,
    scale_score: bool = True,
) -> List[List[Document]]:
    """
    Run `query_by_embedding` once per query embedding and collect the results.

    This is the generic fallback: it issues one `query_by_embedding` call per entry of `query_embs`.

    :param query_embs: Embeddings of the queries. Either a list of embeddings or an array that yields one
                       embedding per iteration step.
    :param filters: Either a single filter dictionary that is applied to every query, or a list with exactly
                    one filter dictionary per query embedding. If None, an empty filter is used for every query.
    :param top_k: Forwarded unchanged to `query_by_embedding`.
    :param index: Forwarded unchanged to `query_by_embedding`.
    :param return_embedding: Forwarded unchanged to `query_by_embedding`.
    :param headers: Forwarded unchanged to `query_by_embedding`.
    :param scale_score: Forwarded unchanged to `query_by_embedding`.
    :return: One result of `query_by_embedding` per query embedding, in the order of `query_embs`.
    :raises HaystackError: If `filters` is a list whose length differs from the number of query embeddings.
    """
    if isinstance(filters, list):
        if len(filters) != len(query_embs):
            raise HaystackError(
                "Number of filters does not match number of query_embs. Please provide as many filters"
                " as query_embs or a single filter that will be applied to each query_emb."
            )
        filters_per_query = filters
    elif filters is not None:
        filters_per_query = [filters] * len(query_embs)
    else:
        # A fresh dict per query, so an implementation that mutates its filters cannot leak state
        # from one query into the next.
        filters_per_query = [{} for _ in range(len(query_embs))]

    return [
        self.query_by_embedding(
            query_emb=query_emb,
            filters=cur_filters,
            top_k=top_k,
            index=index,
            return_embedding=return_embedding,
            headers=headers,
            scale_score=scale_score,
        )
        for query_emb, cur_filters in zip(query_embs, filters_per_query)
    ]
|
||||
|
||||
@abstractmethod
|
||||
def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> int:
|
||||
pass
|
||||
|
||||
@ -377,29 +377,7 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
||||
if not self.embedding_field:
|
||||
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
|
||||
|
||||
# +1 in similarity to avoid negative numbers (for cosine sim)
|
||||
body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
|
||||
if filters:
|
||||
filter_ = {"bool": {"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()}}
|
||||
if body["query"]["script_score"]["query"] == {"match_all": {}}:
|
||||
body["query"]["script_score"]["query"] = filter_
|
||||
else:
|
||||
body["query"]["script_score"]["query"]["bool"]["filter"]["bool"]["must"].append(filter_)
|
||||
|
||||
excluded_meta_data: Optional[list] = None
|
||||
|
||||
if self.excluded_meta_data:
|
||||
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
||||
|
||||
if return_embedding is True and self.embedding_field in excluded_meta_data:
|
||||
excluded_meta_data.remove(self.embedding_field)
|
||||
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
|
||||
excluded_meta_data.append(self.embedding_field)
|
||||
elif return_embedding is False:
|
||||
excluded_meta_data = [self.embedding_field]
|
||||
|
||||
if excluded_meta_data:
|
||||
body["_source"] = {"excludes": excluded_meta_data}
|
||||
body = self._construct_dense_query_body(query_emb, filters, top_k, return_embedding)
|
||||
|
||||
logger.debug("Retriever query: %s", body)
|
||||
try:
|
||||
@ -428,6 +406,37 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
||||
]
|
||||
return documents
|
||||
|
||||
def _construct_dense_query_body(
    self,
    query_emb: np.ndarray,
    filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
    top_k: int = 10,
    return_embedding: Optional[bool] = None,
):
    """
    Assemble the search request body for a single embedding query.

    :param query_emb: Query embedding, handed to `_get_vector_similarity_query`.
    :param filters: Optional filter dictionary. When truthy it is converted with `LogicalFilterClause` and
                    attached to the inner query of the `script_score` query.
    :param top_k: Used as the `size` of the request and handed to `_get_vector_similarity_query`.
    :param return_embedding: Controls whether the embedding field is listed in the `_source` excludes.
    :return: The request body as a dictionary.
    """
    similarity_query = self._get_vector_similarity_query(query_emb, top_k)
    body = {"size": top_k, "query": similarity_query}

    if filters:
        filter_clause = {"bool": {"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()}}
        script_score = body["query"]["script_score"]
        if script_score["query"] == {"match_all": {}}:
            # Nothing but the catch-all query yet: the filter takes its place.
            script_score["query"] = filter_clause
        else:
            script_score["query"]["bool"]["filter"]["bool"]["must"].append(filter_clause)

    # Work on a copy so the store-wide exclusion list is never modified by a single query.
    excludes: Optional[list] = deepcopy(self.excluded_meta_data) if self.excluded_meta_data else None
    if excludes is None:
        if return_embedding is False:
            excludes = [self.embedding_field]
    elif return_embedding is True and self.embedding_field in excludes:
        excludes.remove(self.embedding_field)
    elif return_embedding is False and self.embedding_field not in excludes:
        excludes.append(self.embedding_field)

    if excludes:
        body["_source"] = {"excludes": excludes}

    return body
|
||||
|
||||
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
Create a new index for storing documents. In case if an index with the name already exists, it ensures that
|
||||
|
||||
@ -439,7 +439,28 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
||||
|
||||
if not self.embedding_field:
|
||||
raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore")
|
||||
# +1 in similarity to avoid negative numbers (for cosine sim)
|
||||
body = self._construct_dense_query_body(
|
||||
query_emb=query_emb, filters=filters, top_k=top_k, return_embedding=return_embedding
|
||||
)
|
||||
|
||||
logger.debug("Retriever query: %s", body)
|
||||
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
|
||||
|
||||
documents = [
|
||||
self._convert_es_hit_to_document(
|
||||
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
|
||||
)
|
||||
for hit in result
|
||||
]
|
||||
return documents
|
||||
|
||||
def _construct_dense_query_body(
|
||||
self,
|
||||
query_emb: np.ndarray,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: int = 10,
|
||||
return_embedding: Optional[bool] = None,
|
||||
):
|
||||
body: Dict[str, Any] = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
|
||||
if filters:
|
||||
filter_ = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
|
||||
@ -450,7 +471,6 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
||||
body["query"]["bool"]["filter"] = filter_
|
||||
|
||||
excluded_meta_data: Optional[list] = None
|
||||
|
||||
if self.excluded_meta_data:
|
||||
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
||||
|
||||
@ -463,17 +483,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
||||
|
||||
if excluded_meta_data:
|
||||
body["_source"] = {"excludes": excluded_meta_data}
|
||||
|
||||
logger.debug("Retriever query: %s", body)
|
||||
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
|
||||
|
||||
documents = [
|
||||
self._convert_es_hit_to_document(
|
||||
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
|
||||
)
|
||||
for hit in result
|
||||
]
|
||||
return documents
|
||||
return body
|
||||
|
||||
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# pylint: disable=too-many-public-methods
|
||||
|
||||
from typing import List, Optional, Union, Dict, Any, Generator
|
||||
from abc import abstractmethod
|
||||
import json
|
||||
@ -873,7 +875,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
all_terms_must_match=all_terms_must_match,
|
||||
)
|
||||
|
||||
logger.debug("Retriever query: %s", body)
|
||||
result = self.client.search(index=index, body=body, headers=headers)["hits"]["hits"]
|
||||
|
||||
documents = [
|
||||
@ -1012,7 +1013,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
body.append(headers)
|
||||
body.append(cur_query_body)
|
||||
|
||||
logger.debug("Retriever query: %s", body)
|
||||
responses = self.client.msearch(index=index, body=body)
|
||||
|
||||
all_documents = []
|
||||
@ -1142,6 +1142,155 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
) from e
|
||||
return document
|
||||
|
||||
def query_by_embedding_batch(
    self,
    query_embs: Union[List[np.ndarray], np.ndarray],
    filters: Optional[
        Union[
            Dict[str, Union[Dict, List, str, int, float, bool]],
            List[Dict[str, Union[Dict, List, str, int, float, bool]]],
        ]
    ] = None,
    top_k: int = 10,
    index: Optional[str] = None,
    return_embedding: Optional[bool] = None,
    headers: Optional[Dict[str, str]] = None,
    scale_score: bool = True,
) -> List[List[Document]]:
    """
    Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.

    All queries are sent to the search engine in a single `msearch` request.

    :param query_embs: Embeddings of the queries (e.g. gathered from DPR).
                       Can be a list of one-dimensional numpy arrays or a two-dimensional numpy array.
    :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain
                    conditions. Can be a single filter that is applied to every query, or a list with one
                    filter per query.
                    Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical
                    operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`,
                    `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name.
                    Logical operator keys take a dictionary of metadata field names and/or logical operators as
                    value. Metadata field names take a dictionary of comparison operators as value. Comparison
                    operator keys take a single value or (in case of `"$in"`) a list of values as value.
                    If no logical operator is provided, `"$and"` is used as default operation. If no comparison
                    operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default
                    operation.

                        __Example__:
                        ```python
                        filters = {
                            "$and": {
                                "type": {"$eq": "article"},
                                "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
                                "rating": {"$gte": 3},
                                "$or": {
                                    "genre": {"$in": ["economy", "politics"]},
                                    "publisher": {"$eq": "nytimes"}
                                }
                            }
                        }
                        # or simpler using default operators
                        filters = {
                            "type": "article",
                            "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
                            "rating": {"$gte": 3},
                            "$or": {
                                "genre": ["economy", "politics"],
                                "publisher": "nytimes"
                            }
                        }
                        ```

                        To use the same logical operator multiple times on the same level, logical operators take
                        optionally a list of dictionaries as value.

                        __Example__:
                        ```python
                        filters = {
                            "$or": [
                                {
                                    "$and": {
                                        "Type": "News Paper",
                                        "Date": {
                                            "$lt": "2019-01-01"
                                        }
                                    }
                                },
                                {
                                    "$and": {
                                        "Type": "Blog Post",
                                        "Date": {
                                            "$gte": "2019-01-01"
                                        }
                                    }
                                }
                            ]
                        }
                        ```
    :param top_k: How many documents to return
    :param index: Index name for storing the docs and metadata
    :param return_embedding: To return document embedding
    :param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
            Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
    :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
                        If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
                        Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
    :return: One list of documents per query embedding, in the order of `query_embs`.
    :raises DocumentStoreError: If no `embedding_field` is configured on the document store.
    :raises HaystackError: If `filters` is a list whose length differs from the number of query embeddings.
    """
    if index is None:
        index = self.index

    if return_embedding is None:
        return_embedding = self.return_embedding

    if headers is None:
        headers = {}

    if not self.embedding_field:
        # This method is shared by all search-engine document stores, so name the concrete class.
        raise DocumentStoreError(f"Please set a valid `embedding_field` for {type(self).__name__}")

    if isinstance(filters, list):
        if len(filters) != len(query_embs):
            raise HaystackError(
                "Number of filters does not match number of query_embs. Please provide as many filters"
                " as query_embs or a single filter that will be applied to each query_emb."
            )
    else:
        filters = [filters] * len(query_embs) if filters is not None else [{}] * len(query_embs)

    # No queries means nothing to search for; skip the request instead of sending an empty msearch body.
    if len(query_embs) == 0:
        return []

    body = []
    for query_emb, cur_filters in zip(query_embs, filters):
        cur_query_body = self._construct_dense_query_body(
            query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding
        )
        # Each search in an msearch body is a pair of entries: a header entry followed by the request entry.
        # NOTE(review): `headers` is documented as HTTP headers but is used here as the msearch header entry
        # and is not passed to the client call below - confirm this is intended.
        body.append(headers)
        body.append(cur_query_body)

    logger.debug("Retriever query: %s", body)
    responses = self.client.msearch(index=index, body=body)

    all_documents = []
    for response in responses["responses"]:
        cur_result = response["hits"]["hits"]
        cur_documents = [
            self._convert_es_hit_to_document(
                # Use the resolved `return_embedding` so an explicit argument is honoured.
                hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
            )
            for hit in cur_result
        ]
        all_documents.append(cur_documents)

    return all_documents
|
||||
|
||||
@abstractmethod
def _construct_dense_query_body(
    self,
    query_emb: np.ndarray,
    filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
    top_k: int = 10,
    return_embedding: Optional[bool] = None,
):
    """
    Build the request body for a single embedding query.

    Abstract hook without a default implementation; concrete document stores provide the body.

    :param query_emb: Embedding of the query.
    :param filters: Optional filter dictionary for this query.
    :param top_k: How many documents the query should return.
    :param return_embedding: Whether the document embedding should be returned.
    """
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
retriever: DenseRetriever,
|
||||
|
||||
@ -467,22 +467,12 @@ class DensePassageRetriever(DenseRetriever):
|
||||
if scale_score is None:
|
||||
scale_score = self.scale_score
|
||||
|
||||
documents = []
|
||||
query_embs: List[np.ndarray] = []
|
||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||
query_embs.extend(self.embed_queries(queries=batch))
|
||||
for query_emb, cur_filters in tqdm(
|
||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
||||
):
|
||||
cur_docs = document_store.query_by_embedding(
|
||||
query_emb=query_emb,
|
||||
top_k=top_k,
|
||||
filters=cur_filters,
|
||||
index=index,
|
||||
headers=headers,
|
||||
scale_score=scale_score,
|
||||
)
|
||||
documents.append(cur_docs)
|
||||
documents = document_store.query_by_embedding_batch(
|
||||
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@ -1111,22 +1101,12 @@ class TableTextRetriever(DenseRetriever):
|
||||
if scale_score is None:
|
||||
scale_score = self.scale_score
|
||||
|
||||
documents = []
|
||||
query_embs: List[np.ndarray] = []
|
||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||
query_embs.extend(self.embed_queries(queries=batch))
|
||||
for query_emb, cur_filters in tqdm(
|
||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
||||
):
|
||||
cur_docs = document_store.query_by_embedding(
|
||||
query_emb=query_emb,
|
||||
top_k=top_k,
|
||||
filters=cur_filters,
|
||||
index=index,
|
||||
headers=headers,
|
||||
scale_score=scale_score,
|
||||
)
|
||||
documents.append(cur_docs)
|
||||
documents = document_store.query_by_embedding_batch(
|
||||
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@ -1823,22 +1803,12 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
if scale_score is None:
|
||||
scale_score = self.scale_score
|
||||
|
||||
documents = []
|
||||
query_embs: List[np.ndarray] = []
|
||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||
query_embs.extend(self.embed_queries(queries=batch))
|
||||
for query_emb, cur_filters in tqdm(
|
||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
||||
):
|
||||
cur_docs = document_store.query_by_embedding(
|
||||
query_emb=query_emb,
|
||||
top_k=top_k,
|
||||
filters=cur_filters,
|
||||
index=index,
|
||||
headers=headers,
|
||||
scale_score=scale_score,
|
||||
)
|
||||
documents.append(cur_docs)
|
||||
documents = document_store.query_by_embedding_batch(
|
||||
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@ -2301,22 +2271,22 @@ class MultihopEmbeddingRetriever(EmbeddingRetriever):
|
||||
for it in range(self.num_iterations):
|
||||
texts = [self._merge_query_and_context(q, c) for q, c in zip(batch, context_docs)]
|
||||
query_embs = self.embed_queries(texts)
|
||||
for idx, emb in enumerate(query_embs):
|
||||
cur_docs = document_store.query_by_embedding(
|
||||
query_emb=emb,
|
||||
top_k=top_k,
|
||||
filters=cur_filters,
|
||||
index=index,
|
||||
headers=headers,
|
||||
scale_score=scale_score,
|
||||
)
|
||||
if it < self.num_iterations - 1:
|
||||
# add doc with highest score to context
|
||||
cur_docs_batch = document_store.query_by_embedding_batch(
|
||||
query_embs=query_embs,
|
||||
top_k=top_k,
|
||||
filters=cur_filters,
|
||||
index=index,
|
||||
headers=headers,
|
||||
scale_score=scale_score,
|
||||
)
|
||||
if it < self.num_iterations - 1:
|
||||
# add doc with highest score to context
|
||||
for idx, cur_docs in enumerate(cur_docs_batch):
|
||||
if len(cur_docs) > 0:
|
||||
context_docs[idx].append(cur_docs[0])
|
||||
else:
|
||||
# documents in the last iteration are final results
|
||||
documents.append(cur_docs)
|
||||
else:
|
||||
# documents in the last iteration are final results
|
||||
documents.extend(cur_docs_batch)
|
||||
pb.update(len(batch))
|
||||
pb.close()
|
||||
|
||||
|
||||
@ -205,18 +205,15 @@ class MultiModalRetriever(DenseRetriever):
|
||||
query_embeddings = self.query_embedder.embed(documents=query_docs, batch_size=batch_size)
|
||||
|
||||
# Query documents by embedding (the actual retrieval step)
|
||||
documents = []
|
||||
for query_embedding, query_filters in zip(query_embeddings, filters_list):
|
||||
docs = document_store.query_by_embedding(
|
||||
query_emb=query_embedding,
|
||||
top_k=top_k,
|
||||
filters=query_filters,
|
||||
index=index,
|
||||
headers=headers,
|
||||
scale_score=scale_score,
|
||||
)
|
||||
documents = document_store.query_by_embedding_batch(
|
||||
query_embs=query_embeddings,
|
||||
top_k=top_k,
|
||||
filters=filters_list, # type: ignore
|
||||
index=index,
|
||||
headers=headers,
|
||||
scale_score=scale_score,
|
||||
)
|
||||
|
||||
documents.append(docs)
|
||||
return documents
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> np.ndarray:
|
||||
|
||||
@ -685,15 +685,13 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline):
|
||||
:param top_k: How many documents id to return against single document
|
||||
:param index: Optionally specify the name of index to query the document from. If None, the DocumentStore's default index (self.index) will be used.
|
||||
"""
|
||||
similar_documents: list = []
|
||||
self.document_store.return_embedding = True # type: ignore
|
||||
|
||||
for document in self.document_store.get_documents_by_id(ids=document_ids, index=index):
|
||||
similar_documents.append(
|
||||
self.document_store.query_by_embedding(
|
||||
query_emb=document.embedding, filters=filters, return_embedding=False, top_k=top_k, index=index
|
||||
)
|
||||
)
|
||||
documents = self.document_store.get_documents_by_id(ids=document_ids, index=index)
|
||||
query_embs = [doc.embedding for doc in documents]
|
||||
similar_documents = self.document_store.query_by_embedding_batch(
|
||||
query_embs=query_embs, filters=filters, return_embedding=False, top_k=top_k, index=index
|
||||
)
|
||||
|
||||
self.document_store.return_embedding = False # type: ignore
|
||||
return similar_documents
|
||||
|
||||
@ -91,3 +91,14 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
assert "A Foo Document" in docs[0][0].content
|
||||
assert len(docs[1]) == 5
|
||||
assert "A Bar Document" in docs[1][0].content
|
||||
|
||||
@pytest.mark.integration
def test_memory_query_by_embedding_batch(self, ds, documents):
    """Each embedded document, used as its own query, must come back as the top hit of its result list."""
    # Only documents that carry an embedding can be used as queries.
    embedded_docs = []
    for doc in documents:
        if doc.embedding is not None:
            embedded_docs.append(doc)
    ds.write_documents(embedded_docs)

    query_embs = [doc.embedding for doc in embedded_docs]
    batch_results = ds.query_by_embedding_batch(query_embs=query_embs, top_k=5)

    assert len(batch_results) == 6
    for hits, query_emb in zip(batch_results, query_embs):
        assert len(hits) == 5
        best_hit = hits[0]
        assert (best_hit.embedding == query_emb).all()
|
||||
|
||||
@ -178,6 +178,20 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
|
||||
)
|
||||
assert len(results) == 3
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("use_ann", [True, False])
def test_query_embedding_batch_with_filters(self, ds: OpenSearchDocumentStore, documents, use_ann):
    """Batch embedding queries with one filter per query return filtered results for every query."""
    ds.embeddings_field_supports_similarity = use_ann
    ds.write_documents(documents)

    random_query_embs = [np.random.rand(768).astype(np.float32) for _ in range(2)]
    per_query_filters = [{"year": "2020"} for _ in range(2)]
    results = ds.query_by_embedding_batch(query_embs=random_query_embs, filters=per_query_filters, top_k=10)

    assert len(results) == 2
    for hits in results:
        assert len(hits) == 3
|
||||
|
||||
# Unit tests
|
||||
|
||||
@pytest.mark.unit
|
||||
@ -321,6 +335,13 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]}
|
||||
|
||||
@pytest.mark.unit
def test_query_by_embedding_batch_uses_msearch(self, mocked_document_store):
    """A batch of embedding queries is handed to the client's `msearch` with two body entries per query."""
    query_embs = [self.query_emb for _ in range(10)]
    mocked_document_store.query_by_embedding_batch(query_embs)
    # inspect the body that was passed to the client's msearch call
    _, kwargs = mocked_document_store.client.msearch.call_args
    assert len(kwargs["body"]) == 2 * len(query_embs)  # each search has headers and request
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_alias(self, mocked_document_store, caplog):
|
||||
mocked_document_store.client.indices.exists_alias.return_value = True
|
||||
|
||||
@ -809,6 +809,22 @@ def test_multimodal_text_retrieval(text_docs: List[Document]):
|
||||
assert results[0].content == "My name is Christelle and I live in Paris"
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_multimodal_text_retrieval_batch(text_docs: List[Document]):
    """Batch retrieval with the multimodal retriever returns the matching text document first for each query."""
    model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True),
        query_embedding_model=model_name,
        document_embedding_models={"text": model_name},
    )
    store = retriever.document_store
    store.write_documents(text_docs)
    store.update_embeddings(retriever=retriever)

    queries = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Madrid?"]
    expected_top_hits = [
        "My name is Christelle and I live in Paris",
        "My name is Carla and I live in Berlin",
        "My name is Camila and I live in Madrid",
    ]

    results = retriever.retrieve_batch(queries=queries)
    # Index explicitly so a missing result list fails the test instead of being skipped.
    for position, expected_content in enumerate(expected_top_hits):
        assert results[position][0].content == expected_content
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_multimodal_table_retrieval(table_docs: List[Document]):
|
||||
retriever = MultiModalRetriever(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user