mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-29 08:49:07 +00:00
feat: add query_by_embedding_batch (#3546)
* add query_by_embedding_batch
* fix mypy
* fix pylint
* add test
* move query_by_embedding_batch to search_engine
* fix and add tests
* fix pylint
* remove Retriever query logs
* add test for multimodal batch retrieval
* allow for np.ndarray
This commit is contained in:
parent
25bf95d47f
commit
c1c1c97bb2
@ -12,7 +12,7 @@ import numpy as np
|
|||||||
|
|
||||||
from haystack.schema import Document, Label, MultiLabel
|
from haystack.schema import Document, Label, MultiLabel
|
||||||
from haystack.nodes.base import BaseComponent
|
from haystack.nodes.base import BaseComponent
|
||||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError
|
from haystack.errors import DuplicateDocumentError, DocumentStoreError, HaystackError
|
||||||
from haystack.nodes.preprocessor import PreProcessor
|
from haystack.nodes.preprocessor import PreProcessor
|
||||||
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
|
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
|
||||||
from haystack.utils.labels import aggregate_labels
|
from haystack.utils.labels import aggregate_labels
|
||||||
@ -359,6 +359,44 @@ class BaseDocumentStore(BaseComponent):
|
|||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def query_by_embedding_batch(
    self,
    query_embs: Union[List[np.ndarray], np.ndarray],
    filters: Optional[
        Union[
            Dict[str, Union[Dict, List, str, int, float, bool]],
            List[Dict[str, Union[Dict, List, str, int, float, bool]]],
        ]
    ] = None,
    top_k: int = 10,
    index: Optional[str] = None,
    return_embedding: Optional[bool] = None,
    headers: Optional[Dict[str, str]] = None,
    scale_score: bool = True,
) -> List[List[Document]]:
    """
    Run `query_by_embedding` once per query embedding and collect the results.

    This generic implementation issues the queries one after another.

    :param query_embs: Embeddings of the queries, either a list of arrays or an array that yields
                       one embedding per iteration step.
    :param filters: A single filter dict applied to every query, or a list with exactly one filter
                    dict per query embedding. `None` results in an empty filter for every query.
    :param top_k: Forwarded unchanged to `query_by_embedding` for every query.
    :param index: Forwarded unchanged to `query_by_embedding` for every query.
    :param return_embedding: Forwarded unchanged to `query_by_embedding` for every query.
    :param headers: Forwarded unchanged to `query_by_embedding` for every query.
    :param scale_score: Forwarded unchanged to `query_by_embedding` for every query.
    :return: One list of documents per query embedding, in the order of `query_embs`.
    :raises HaystackError: If `filters` is a list whose length differs from the number of `query_embs`.
    """
    if isinstance(filters, list):
        if len(filters) != len(query_embs):
            raise HaystackError(
                "Number of filters does not match number of query_embs. Please provide as many filters"
                " as query_embs or a single filter that will be applied to each query_emb."
            )
        filters_per_query = filters
    else:
        # A single filter (or the empty default) is shared by every query.
        filters_per_query = [filters if filters is not None else {}] * len(query_embs)

    return [
        self.query_by_embedding(
            query_emb=query_emb,
            filters=cur_filters,
            top_k=top_k,
            index=index,
            return_embedding=return_embedding,
            headers=headers,
            scale_score=scale_score,
        )
        for query_emb, cur_filters in zip(query_embs, filters_per_query)
    ]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> int:
|
def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> int:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -377,29 +377,7 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
if not self.embedding_field:
|
if not self.embedding_field:
|
||||||
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
|
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
|
||||||
|
|
||||||
# +1 in similarity to avoid negative numbers (for cosine sim)
|
body = self._construct_dense_query_body(query_emb, filters, top_k, return_embedding)
|
||||||
body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
|
|
||||||
if filters:
|
|
||||||
filter_ = {"bool": {"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()}}
|
|
||||||
if body["query"]["script_score"]["query"] == {"match_all": {}}:
|
|
||||||
body["query"]["script_score"]["query"] = filter_
|
|
||||||
else:
|
|
||||||
body["query"]["script_score"]["query"]["bool"]["filter"]["bool"]["must"].append(filter_)
|
|
||||||
|
|
||||||
excluded_meta_data: Optional[list] = None
|
|
||||||
|
|
||||||
if self.excluded_meta_data:
|
|
||||||
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
|
||||||
|
|
||||||
if return_embedding is True and self.embedding_field in excluded_meta_data:
|
|
||||||
excluded_meta_data.remove(self.embedding_field)
|
|
||||||
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
|
|
||||||
excluded_meta_data.append(self.embedding_field)
|
|
||||||
elif return_embedding is False:
|
|
||||||
excluded_meta_data = [self.embedding_field]
|
|
||||||
|
|
||||||
if excluded_meta_data:
|
|
||||||
body["_source"] = {"excludes": excluded_meta_data}
|
|
||||||
|
|
||||||
logger.debug("Retriever query: %s", body)
|
logger.debug("Retriever query: %s", body)
|
||||||
try:
|
try:
|
||||||
@ -428,6 +406,37 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
]
|
]
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
def _construct_dense_query_body(
|
||||||
|
self,
|
||||||
|
query_emb: np.ndarray,
|
||||||
|
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
return_embedding: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
|
||||||
|
if filters:
|
||||||
|
filter_ = {"bool": {"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()}}
|
||||||
|
if body["query"]["script_score"]["query"] == {"match_all": {}}:
|
||||||
|
body["query"]["script_score"]["query"] = filter_
|
||||||
|
else:
|
||||||
|
body["query"]["script_score"]["query"]["bool"]["filter"]["bool"]["must"].append(filter_)
|
||||||
|
|
||||||
|
excluded_meta_data: Optional[list] = None
|
||||||
|
|
||||||
|
if self.excluded_meta_data:
|
||||||
|
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
||||||
|
|
||||||
|
if return_embedding is True and self.embedding_field in excluded_meta_data:
|
||||||
|
excluded_meta_data.remove(self.embedding_field)
|
||||||
|
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
|
||||||
|
excluded_meta_data.append(self.embedding_field)
|
||||||
|
elif return_embedding is False:
|
||||||
|
excluded_meta_data = [self.embedding_field]
|
||||||
|
|
||||||
|
if excluded_meta_data:
|
||||||
|
body["_source"] = {"excludes": excluded_meta_data}
|
||||||
|
return body
|
||||||
|
|
||||||
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
|
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
|
||||||
"""
|
"""
|
||||||
Create a new index for storing documents. In case if an index with the name already exists, it ensures that
|
Create a new index for storing documents. In case if an index with the name already exists, it ensures that
|
||||||
|
|||||||
@ -439,7 +439,28 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
|
|
||||||
if not self.embedding_field:
|
if not self.embedding_field:
|
||||||
raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore")
|
raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore")
|
||||||
# +1 in similarity to avoid negative numbers (for cosine sim)
|
body = self._construct_dense_query_body(
|
||||||
|
query_emb=query_emb, filters=filters, top_k=top_k, return_embedding=return_embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Retriever query: %s", body)
|
||||||
|
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
self._convert_es_hit_to_document(
|
||||||
|
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
|
||||||
|
)
|
||||||
|
for hit in result
|
||||||
|
]
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def _construct_dense_query_body(
|
||||||
|
self,
|
||||||
|
query_emb: np.ndarray,
|
||||||
|
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
return_embedding: Optional[bool] = None,
|
||||||
|
):
|
||||||
body: Dict[str, Any] = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
|
body: Dict[str, Any] = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
|
||||||
if filters:
|
if filters:
|
||||||
filter_ = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
|
filter_ = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
|
||||||
@ -450,7 +471,6 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
body["query"]["bool"]["filter"] = filter_
|
body["query"]["bool"]["filter"] = filter_
|
||||||
|
|
||||||
excluded_meta_data: Optional[list] = None
|
excluded_meta_data: Optional[list] = None
|
||||||
|
|
||||||
if self.excluded_meta_data:
|
if self.excluded_meta_data:
|
||||||
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
excluded_meta_data = deepcopy(self.excluded_meta_data)
|
||||||
|
|
||||||
@ -463,17 +483,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
|
|
||||||
if excluded_meta_data:
|
if excluded_meta_data:
|
||||||
body["_source"] = {"excludes": excluded_meta_data}
|
body["_source"] = {"excludes": excluded_meta_data}
|
||||||
|
return body
|
||||||
logger.debug("Retriever query: %s", body)
|
|
||||||
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
|
|
||||||
|
|
||||||
documents = [
|
|
||||||
self._convert_es_hit_to_document(
|
|
||||||
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
|
|
||||||
)
|
|
||||||
for hit in result
|
|
||||||
]
|
|
||||||
return documents
|
|
||||||
|
|
||||||
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
|
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
# pylint: disable=too-many-public-methods
|
||||||
|
|
||||||
from typing import List, Optional, Union, Dict, Any, Generator
|
from typing import List, Optional, Union, Dict, Any, Generator
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import json
|
import json
|
||||||
@ -873,7 +875,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
all_terms_must_match=all_terms_must_match,
|
all_terms_must_match=all_terms_must_match,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Retriever query: %s", body)
|
|
||||||
result = self.client.search(index=index, body=body, headers=headers)["hits"]["hits"]
|
result = self.client.search(index=index, body=body, headers=headers)["hits"]["hits"]
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
@ -1012,7 +1013,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
body.append(headers)
|
body.append(headers)
|
||||||
body.append(cur_query_body)
|
body.append(cur_query_body)
|
||||||
|
|
||||||
logger.debug("Retriever query: %s", body)
|
|
||||||
responses = self.client.msearch(index=index, body=body)
|
responses = self.client.msearch(index=index, body=body)
|
||||||
|
|
||||||
all_documents = []
|
all_documents = []
|
||||||
@ -1142,6 +1142,155 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
) from e
|
) from e
|
||||||
return document
|
return document
|
||||||
|
|
||||||
|
def query_by_embedding_batch(
    self,
    query_embs: Union[List[np.ndarray], np.ndarray],
    filters: Optional[
        Union[
            Dict[str, Union[Dict, List, str, int, float, bool]],
            List[Dict[str, Union[Dict, List, str, int, float, bool]]],
        ]
    ] = None,
    top_k: int = 10,
    index: Optional[str] = None,
    return_embedding: Optional[bool] = None,
    headers: Optional[Dict[str, str]] = None,
    scale_score: bool = True,
) -> List[List[Document]]:
    """
    Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.

    All queries are sent to the search engine in a single `msearch` request.

    :param query_embs: Embeddings of the queries (e.g. gathered from DPR).
                    Can be a list of one-dimensional numpy arrays or a two-dimensional numpy array.
    :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain
                    conditions. Either a single filter that is applied to every query or a list with one
                    filter per query.
                    Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical
                    operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`,
                    `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name.
                    Logical operator keys take a dictionary of metadata field names and/or logical operators as
                    value. Metadata field names take a dictionary of comparison operators as value. Comparison
                    operator keys take a single value or (in case of `"$in"`) a list of values as value.
                    If no logical operator is provided, `"$and"` is used as default operation. If no comparison
                    operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default
                    operation.

                        __Example__:
                        ```python
                        filters = {
                            "$and": {
                                "type": {"$eq": "article"},
                                "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
                                "rating": {"$gte": 3},
                                "$or": {
                                    "genre": {"$in": ["economy", "politics"]},
                                    "publisher": {"$eq": "nytimes"}
                                }
                            }
                        }
                        # or simpler using default operators
                        filters = {
                            "type": "article",
                            "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
                            "rating": {"$gte": 3},
                            "$or": {
                                "genre": ["economy", "politics"],
                                "publisher": "nytimes"
                            }
                        }
                        ```

                        To use the same logical operator multiple times on the same level, logical operators take
                        optionally a list of dictionaries as value.

                        __Example__:
                        ```python
                        filters = {
                            "$or": [
                                {
                                    "$and": {
                                        "Type": "News Paper",
                                        "Date": {
                                            "$lt": "2019-01-01"
                                        }
                                    }
                                },
                                {
                                    "$and": {
                                        "Type": "Blog Post",
                                        "Date": {
                                            "$gte": "2019-01-01"
                                        }
                                    }
                                }
                            ]
                        }
                        ```
    :param top_k: How many documents to return
    :param index: Index name for storing the docs and metadata. Defaults to `self.index`.
    :param return_embedding: To return document embedding. Defaults to `self.return_embedding`.
    :param headers: Custom HTTP headers intended for the search client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
            Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
    :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
                        If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
                        Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
    :return: One list of documents per query embedding, in the order of the returned responses.
    :raises DocumentStoreError: If the document store has no `embedding_field` set.
    :raises HaystackError: If `filters` is a list whose length differs from the number of `query_embs`.
    """
    if index is None:
        index = self.index

    if return_embedding is None:
        return_embedding = self.return_embedding

    if headers is None:
        headers = {}

    if not self.embedding_field:
        # This method is shared by several stores, so name the concrete class in the message.
        raise DocumentStoreError(f"Please set a valid `embedding_field` for {type(self).__name__}")

    if isinstance(filters, list):
        if len(filters) != len(query_embs):
            raise HaystackError(
                "Number of filters does not match number of query_embs. Please provide as many filters"
                " as query_embs or a single filter that will be applied to each query_emb."
            )
        filters_per_query = filters
    else:
        # A single filter (or the empty default) is shared by every query.
        filters_per_query = [filters if filters is not None else {}] * len(query_embs)

    # Nothing to search for: skip the round trip instead of sending an empty msearch body.
    if len(query_embs) == 0:
        return []

    body = []
    for query_emb, cur_filters in zip(query_embs, filters_per_query):
        cur_query_body = self._construct_dense_query_body(
            query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding
        )
        # NOTE(review): `headers` is written into the msearch payload in front of every query body
        # and is not passed to `client.msearch` itself - confirm this is the intended use of
        # custom HTTP headers.
        body.append(headers)
        body.append(cur_query_body)

    logger.debug("Retriever query: %s", body)
    responses = self.client.msearch(index=index, body=body)

    all_documents = []
    for response in responses["responses"]:
        cur_documents = [
            self._convert_es_hit_to_document(
                # Use the resolved argument so an explicit `return_embedding` is honoured.
                hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
            )
            for hit in response["hits"]["hits"]
        ]
        all_documents.append(cur_documents)

    return all_documents
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _construct_dense_query_body(
|
||||||
|
self,
|
||||||
|
query_emb: np.ndarray,
|
||||||
|
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
return_embedding: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
def update_embeddings(
|
def update_embeddings(
|
||||||
self,
|
self,
|
||||||
retriever: DenseRetriever,
|
retriever: DenseRetriever,
|
||||||
|
|||||||
@ -467,22 +467,12 @@ class DensePassageRetriever(DenseRetriever):
|
|||||||
if scale_score is None:
|
if scale_score is None:
|
||||||
scale_score = self.scale_score
|
scale_score = self.scale_score
|
||||||
|
|
||||||
documents = []
|
|
||||||
query_embs: List[np.ndarray] = []
|
query_embs: List[np.ndarray] = []
|
||||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||||
query_embs.extend(self.embed_queries(queries=batch))
|
query_embs.extend(self.embed_queries(queries=batch))
|
||||||
for query_emb, cur_filters in tqdm(
|
documents = document_store.query_by_embedding_batch(
|
||||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
|
||||||
):
|
)
|
||||||
cur_docs = document_store.query_by_embedding(
|
|
||||||
query_emb=query_emb,
|
|
||||||
top_k=top_k,
|
|
||||||
filters=cur_filters,
|
|
||||||
index=index,
|
|
||||||
headers=headers,
|
|
||||||
scale_score=scale_score,
|
|
||||||
)
|
|
||||||
documents.append(cur_docs)
|
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
@ -1111,22 +1101,12 @@ class TableTextRetriever(DenseRetriever):
|
|||||||
if scale_score is None:
|
if scale_score is None:
|
||||||
scale_score = self.scale_score
|
scale_score = self.scale_score
|
||||||
|
|
||||||
documents = []
|
|
||||||
query_embs: List[np.ndarray] = []
|
query_embs: List[np.ndarray] = []
|
||||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||||
query_embs.extend(self.embed_queries(queries=batch))
|
query_embs.extend(self.embed_queries(queries=batch))
|
||||||
for query_emb, cur_filters in tqdm(
|
documents = document_store.query_by_embedding_batch(
|
||||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
|
||||||
):
|
)
|
||||||
cur_docs = document_store.query_by_embedding(
|
|
||||||
query_emb=query_emb,
|
|
||||||
top_k=top_k,
|
|
||||||
filters=cur_filters,
|
|
||||||
index=index,
|
|
||||||
headers=headers,
|
|
||||||
scale_score=scale_score,
|
|
||||||
)
|
|
||||||
documents.append(cur_docs)
|
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
@ -1823,22 +1803,12 @@ class EmbeddingRetriever(DenseRetriever):
|
|||||||
if scale_score is None:
|
if scale_score is None:
|
||||||
scale_score = self.scale_score
|
scale_score = self.scale_score
|
||||||
|
|
||||||
documents = []
|
|
||||||
query_embs: List[np.ndarray] = []
|
query_embs: List[np.ndarray] = []
|
||||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||||
query_embs.extend(self.embed_queries(queries=batch))
|
query_embs.extend(self.embed_queries(queries=batch))
|
||||||
for query_emb, cur_filters in tqdm(
|
documents = document_store.query_by_embedding_batch(
|
||||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
|
||||||
):
|
)
|
||||||
cur_docs = document_store.query_by_embedding(
|
|
||||||
query_emb=query_emb,
|
|
||||||
top_k=top_k,
|
|
||||||
filters=cur_filters,
|
|
||||||
index=index,
|
|
||||||
headers=headers,
|
|
||||||
scale_score=scale_score,
|
|
||||||
)
|
|
||||||
documents.append(cur_docs)
|
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
@ -2301,22 +2271,22 @@ class MultihopEmbeddingRetriever(EmbeddingRetriever):
|
|||||||
for it in range(self.num_iterations):
|
for it in range(self.num_iterations):
|
||||||
texts = [self._merge_query_and_context(q, c) for q, c in zip(batch, context_docs)]
|
texts = [self._merge_query_and_context(q, c) for q, c in zip(batch, context_docs)]
|
||||||
query_embs = self.embed_queries(texts)
|
query_embs = self.embed_queries(texts)
|
||||||
for idx, emb in enumerate(query_embs):
|
cur_docs_batch = document_store.query_by_embedding_batch(
|
||||||
cur_docs = document_store.query_by_embedding(
|
query_embs=query_embs,
|
||||||
query_emb=emb,
|
top_k=top_k,
|
||||||
top_k=top_k,
|
filters=cur_filters,
|
||||||
filters=cur_filters,
|
index=index,
|
||||||
index=index,
|
headers=headers,
|
||||||
headers=headers,
|
scale_score=scale_score,
|
||||||
scale_score=scale_score,
|
)
|
||||||
)
|
if it < self.num_iterations - 1:
|
||||||
if it < self.num_iterations - 1:
|
# add doc with highest score to context
|
||||||
# add doc with highest score to context
|
for idx, cur_docs in enumerate(cur_docs_batch):
|
||||||
if len(cur_docs) > 0:
|
if len(cur_docs) > 0:
|
||||||
context_docs[idx].append(cur_docs[0])
|
context_docs[idx].append(cur_docs[0])
|
||||||
else:
|
else:
|
||||||
# documents in the last iteration are final results
|
# documents in the last iteration are final results
|
||||||
documents.append(cur_docs)
|
documents.extend(cur_docs_batch)
|
||||||
pb.update(len(batch))
|
pb.update(len(batch))
|
||||||
pb.close()
|
pb.close()
|
||||||
|
|
||||||
|
|||||||
@ -205,18 +205,15 @@ class MultiModalRetriever(DenseRetriever):
|
|||||||
query_embeddings = self.query_embedder.embed(documents=query_docs, batch_size=batch_size)
|
query_embeddings = self.query_embedder.embed(documents=query_docs, batch_size=batch_size)
|
||||||
|
|
||||||
# Query documents by embedding (the actual retrieval step)
|
# Query documents by embedding (the actual retrieval step)
|
||||||
documents = []
|
documents = document_store.query_by_embedding_batch(
|
||||||
for query_embedding, query_filters in zip(query_embeddings, filters_list):
|
query_embs=query_embeddings,
|
||||||
docs = document_store.query_by_embedding(
|
top_k=top_k,
|
||||||
query_emb=query_embedding,
|
filters=filters_list, # type: ignore
|
||||||
top_k=top_k,
|
index=index,
|
||||||
filters=query_filters,
|
headers=headers,
|
||||||
index=index,
|
scale_score=scale_score,
|
||||||
headers=headers,
|
)
|
||||||
scale_score=scale_score,
|
|
||||||
)
|
|
||||||
|
|
||||||
documents.append(docs)
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def embed_documents(self, docs: List[Document]) -> np.ndarray:
|
def embed_documents(self, docs: List[Document]) -> np.ndarray:
|
||||||
|
|||||||
@ -685,15 +685,13 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline):
|
|||||||
:param top_k: How many documents id to return against single document
|
:param top_k: How many documents id to return against single document
|
||||||
:param index: Optionally specify the name of index to query the document from. If None, the DocumentStore's default index (self.index) will be used.
|
:param index: Optionally specify the name of index to query the document from. If None, the DocumentStore's default index (self.index) will be used.
|
||||||
"""
|
"""
|
||||||
similar_documents: list = []
|
|
||||||
self.document_store.return_embedding = True # type: ignore
|
self.document_store.return_embedding = True # type: ignore
|
||||||
|
|
||||||
for document in self.document_store.get_documents_by_id(ids=document_ids, index=index):
|
documents = self.document_store.get_documents_by_id(ids=document_ids, index=index)
|
||||||
similar_documents.append(
|
query_embs = [doc.embedding for doc in documents]
|
||||||
self.document_store.query_by_embedding(
|
similar_documents = self.document_store.query_by_embedding_batch(
|
||||||
query_emb=document.embedding, filters=filters, return_embedding=False, top_k=top_k, index=index
|
query_embs=query_embs, filters=filters, return_embedding=False, top_k=top_k, index=index
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
self.document_store.return_embedding = False # type: ignore
|
self.document_store.return_embedding = False # type: ignore
|
||||||
return similar_documents
|
return similar_documents
|
||||||
|
|||||||
@ -91,3 +91,14 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract):
|
|||||||
assert "A Foo Document" in docs[0][0].content
|
assert "A Foo Document" in docs[0][0].content
|
||||||
assert len(docs[1]) == 5
|
assert len(docs[1]) == 5
|
||||||
assert "A Bar Document" in docs[1][0].content
|
assert "A Bar Document" in docs[1][0].content
|
||||||
|
|
||||||
|
@pytest.mark.integration
def test_memory_query_by_embedding_batch(self, ds, documents):
    """Querying with each stored document's own embedding must return that document first."""
    embedded_docs = [doc for doc in documents if doc.embedding is not None]
    ds.write_documents(embedded_docs)

    query_embs = [doc.embedding for doc in embedded_docs]
    docs_batch = ds.query_by_embedding_batch(query_embs=query_embs, top_k=5)

    # One result list per query embedding.
    assert len(docs_batch) == 6
    for retrieved, query_emb in zip(docs_batch, query_embs):
        assert len(retrieved) == 5
        # The best hit carries exactly the embedding that was used as the query.
        assert (retrieved[0].embedding == query_emb).all()
|
||||||
|
|||||||
@ -178,6 +178,20 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
|
|||||||
)
|
)
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
|
|
||||||
|
@pytest.mark.integration
@pytest.mark.parametrize("use_ann", [True, False])
def test_query_embedding_batch_with_filters(self, ds: OpenSearchDocumentStore, documents, use_ann):
    """Batch embedding queries apply the per-query filters, with and without ANN support."""
    ds.embeddings_field_supports_similarity = use_ann
    ds.write_documents(documents)

    query_embs = [np.random.rand(768).astype(np.float32) for _ in range(2)]
    per_query_filters = [{"year": "2020"} for _ in range(2)]
    results = ds.query_by_embedding_batch(query_embs=query_embs, filters=per_query_filters, top_k=10)

    # One result list per query, each holding three documents.
    assert len(results) == 2
    for hits in results:
        assert len(hits) == 3
|
||||||
|
|
||||||
# Unit tests
|
# Unit tests
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@ -321,6 +335,13 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
|
|||||||
_, kwargs = mocked_document_store.client.search.call_args
|
_, kwargs = mocked_document_store.client.search.call_args
|
||||||
assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]}
|
assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]}
|
||||||
|
|
||||||
|
@pytest.mark.unit
def test_query_by_embedding_batch_uses_msearch(self, mocked_document_store):
    """A batch of embedding queries must be sent as one `msearch` request."""
    mocked_document_store.query_by_embedding_batch([self.query_emb for _ in range(10)])
    # all queries of the batch travel in a single msearch call
    mocked_document_store.client.msearch.assert_called_once()
    _, kwargs = mocked_document_store.client.msearch.call_args
    assert len(kwargs["body"]) == 20  # each search has headers and request
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test__create_document_index_with_alias(self, mocked_document_store, caplog):
|
def test__create_document_index_with_alias(self, mocked_document_store, caplog):
|
||||||
mocked_document_store.client.indices.exists_alias.return_value = True
|
mocked_document_store.client.indices.exists_alias.return_value = True
|
||||||
|
|||||||
@ -809,6 +809,22 @@ def test_multimodal_text_retrieval(text_docs: List[Document]):
|
|||||||
assert results[0].content == "My name is Christelle and I live in Paris"
|
assert results[0].content == "My name is Christelle and I live in Paris"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
def test_multimodal_text_retrieval_batch(text_docs: List[Document]):
    """`retrieve_batch` returns, for every query, the matching document as the top hit."""
    embedding_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True),
        query_embedding_model=embedding_model,
        document_embedding_models={"text": embedding_model},
    )
    store = retriever.document_store
    store.write_documents(text_docs)
    store.update_embeddings(retriever=retriever)

    results = retriever.retrieve_batch(queries=["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Madrid?"])

    expected_top_hits = [
        "My name is Christelle and I live in Paris",
        "My name is Carla and I live in Berlin",
        "My name is Camila and I live in Madrid",
    ]
    for position, expected_content in enumerate(expected_top_hits):
        assert results[position][0].content == expected_content
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_multimodal_table_retrieval(table_docs: List[Document]):
|
def test_multimodal_table_retrieval(table_docs: List[Document]):
|
||||||
retriever = MultiModalRetriever(
|
retriever = MultiModalRetriever(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user