feat: add query_by_embedding_batch (#3546)

* add query_by_embedding_batch

* fix mypy

* fix pylint

* add test

* move query_by_embedding_batch to search_engine

* fix and add tests

* fix pylint

* remove Retriever query logs

* add test for multimodal batch retrieval

* allow for np.ndarray
tstadel 2022-12-08 08:28:43 +01:00 committed by GitHub
parent 25bf95d47f
commit c1c1c97bb2
10 changed files with 329 additions and 110 deletions
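In short: every document store gains a batched variant of `query_by_embedding` that takes many query embeddings at once and returns one result list per embedding. A minimal usage sketch (the store setup and the random embeddings are illustrative assumptions, not part of this commit):

```python
import numpy as np
from haystack.document_stores import InMemoryDocumentStore

document_store = InMemoryDocumentStore(embedding_dim=768)

# Either a list of 1-D arrays or a single 2-D array is accepted.
query_embs = np.random.rand(2, 768).astype(np.float32)

# One list of up to top_k Documents per query embedding.
results = document_store.query_by_embedding_batch(query_embs=query_embs, top_k=5)
assert len(results) == 2
```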

View File

@@ -12,7 +12,7 @@ import numpy as np
from haystack.schema import Document, Label, MultiLabel
from haystack.nodes.base import BaseComponent
from haystack.errors import DuplicateDocumentError, DocumentStoreError
from haystack.errors import DuplicateDocumentError, DocumentStoreError, HaystackError
from haystack.nodes.preprocessor import PreProcessor
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
from haystack.utils.labels import aggregate_labels
@@ -359,6 +359,44 @@ class BaseDocumentStore(BaseComponent):
) -> List[Document]:
pass
def query_by_embedding_batch(
self,
query_embs: Union[List[np.ndarray], np.ndarray],
filters: Optional[
Union[
Dict[str, Union[Dict, List, str, int, float, bool]],
List[Dict[str, Union[Dict, List, str, int, float, bool]]],
]
] = None,
top_k: int = 10,
index: Optional[str] = None,
return_embedding: Optional[bool] = None,
headers: Optional[Dict[str, str]] = None,
scale_score: bool = True,
) -> List[List[Document]]:
if isinstance(filters, list):
if len(filters) != len(query_embs):
raise HaystackError(
"Number of filters does not match number of query_embs. Please provide as many filters"
" as query_embs or a single filter that will be applied to each query_emb."
)
else:
filters = [filters] * len(query_embs) if filters is not None else [{}] * len(query_embs)
results = []
for query_emb, cur_filters in zip(query_embs, filters):
results.append(
self.query_by_embedding(
query_emb=query_emb,
filters=cur_filters,
top_k=top_k,
index=index,
return_embedding=return_embedding,
headers=headers,
scale_score=scale_score,
)
)
return results
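The default implementation above broadcasts a single filter dict to every embedding and validates a per-query filter list against `len(query_embs)`. A short sketch of the two accepted call shapes, reusing the hypothetical `document_store` and `query_embs` from the first sketch:

```python
# One dict: the same filter is applied to every query embedding.
results = document_store.query_by_embedding_batch(
    query_embs=query_embs,
    filters={"type": {"$eq": "article"}},
)

# One dict per embedding: the list length must equal len(query_embs),
# otherwise a HaystackError is raised.
results = document_store.query_by_embedding_batch(
    query_embs=query_embs,
    filters=[{"type": "article"}, {"type": "blog"}],
)
```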
@abstractmethod
def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> int:
pass

View File

@@ -377,29 +377,7 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
if not self.embedding_field:
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
# +1 in similarity to avoid negative numbers (for cosine sim)
body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
if filters:
filter_ = {"bool": {"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()}}
if body["query"]["script_score"]["query"] == {"match_all": {}}:
body["query"]["script_score"]["query"] = filter_
else:
body["query"]["script_score"]["query"]["bool"]["filter"]["bool"]["must"].append(filter_)
excluded_meta_data: Optional[list] = None
if self.excluded_meta_data:
excluded_meta_data = deepcopy(self.excluded_meta_data)
if return_embedding is True and self.embedding_field in excluded_meta_data:
excluded_meta_data.remove(self.embedding_field)
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
excluded_meta_data.append(self.embedding_field)
elif return_embedding is False:
excluded_meta_data = [self.embedding_field]
if excluded_meta_data:
body["_source"] = {"excludes": excluded_meta_data}
body = self._construct_dense_query_body(query_emb, filters, top_k, return_embedding)
logger.debug("Retriever query: %s", body)
try:
@@ -428,6 +406,37 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
]
return documents
def _construct_dense_query_body(
self,
query_emb: np.ndarray,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
top_k: int = 10,
return_embedding: Optional[bool] = None,
):
body = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
if filters:
filter_ = {"bool": {"filter": LogicalFilterClause.parse(filters).convert_to_elasticsearch()}}
if body["query"]["script_score"]["query"] == {"match_all": {}}:
body["query"]["script_score"]["query"] = filter_
else:
body["query"]["script_score"]["query"]["bool"]["filter"]["bool"]["must"].append(filter_)
excluded_meta_data: Optional[list] = None
if self.excluded_meta_data:
excluded_meta_data = deepcopy(self.excluded_meta_data)
if return_embedding is True and self.embedding_field in excluded_meta_data:
excluded_meta_data.remove(self.embedding_field)
elif return_embedding is False and self.embedding_field not in excluded_meta_data:
excluded_meta_data.append(self.embedding_field)
elif return_embedding is False:
excluded_meta_data = [self.embedding_field]
if excluded_meta_data:
body["_source"] = {"excludes": excluded_meta_data}
return body
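For orientation, the helper extracted above assembles a plain Elasticsearch request body around the `script_score` query produced by `_get_vector_similarity_query`. Roughly (an illustrative sketch only; the actual script source depends on the configured similarity metric):

```python
# Approximate shape of the body returned by _construct_dense_query_body:
body = {
    "size": 10,  # top_k
    "query": {
        "script_score": {
            "query": {"match_all": {}},  # replaced or extended when filters are given
            "script": {"source": "<similarity script>", "params": {"query_vector": [0.1, 0.2]}},
        }
    },
    "_source": {"excludes": ["embedding"]},  # only set when embeddings are excluded
}
```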
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
"""
Create a new index for storing documents. If an index with that name already exists, it ensures that

View File

@@ -439,7 +439,28 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
if not self.embedding_field:
raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore")
# +1 in similarity to avoid negative numbers (for cosine sim)
body = self._construct_dense_query_body(
query_emb=query_emb, filters=filters, top_k=top_k, return_embedding=return_embedding
)
logger.debug("Retriever query: %s", body)
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
documents = [
self._convert_es_hit_to_document(
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
)
for hit in result
]
return documents
def _construct_dense_query_body(
self,
query_emb: np.ndarray,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
top_k: int = 10,
return_embedding: Optional[bool] = None,
):
body: Dict[str, Any] = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
if filters:
filter_ = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
@@ -450,7 +471,6 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
body["query"]["bool"]["filter"] = filter_
excluded_meta_data: Optional[list] = None
if self.excluded_meta_data:
excluded_meta_data = deepcopy(self.excluded_meta_data)
@@ -463,17 +483,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
if excluded_meta_data:
body["_source"] = {"excludes": excluded_meta_data}
logger.debug("Retriever query: %s", body)
result = self.client.search(index=index, body=body, request_timeout=300, headers=headers)["hits"]["hits"]
documents = [
self._convert_es_hit_to_document(
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
)
for hit in result
]
return documents
return body
def _create_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
"""

View File

@@ -1,3 +1,5 @@
# pylint: disable=too-many-public-methods
from typing import List, Optional, Union, Dict, Any, Generator
from abc import abstractmethod
import json
@@ -873,7 +875,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
all_terms_must_match=all_terms_must_match,
)
logger.debug("Retriever query: %s", body)
result = self.client.search(index=index, body=body, headers=headers)["hits"]["hits"]
documents = [
@@ -1012,7 +1013,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
body.append(headers)
body.append(cur_query_body)
logger.debug("Retriever query: %s", body)
responses = self.client.msearch(index=index, body=body)
all_documents = []
@@ -1142,6 +1142,155 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
) from e
return document
def query_by_embedding_batch(
self,
query_embs: Union[List[np.ndarray], np.ndarray],
filters: Optional[
Union[
Dict[str, Union[Dict, List, str, int, float, bool]],
List[Dict[str, Union[Dict, List, str, int, float, bool]]],
]
] = None,
top_k: int = 10,
index: Optional[str] = None,
return_embedding: Optional[bool] = None,
headers: Optional[Dict[str, str]] = None,
scale_score: bool = True,
) -> List[List[Document]]:
"""
Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.
:param query_embs: Embeddings of the queries (e.g. gathered from DPR).
Can be a list of one-dimensional numpy arrays or a two-dimensional numpy array.
:param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain
conditions.
Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical
operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`,
`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name.
Logical operator keys take a dictionary of metadata field names and/or logical operators as
value. Metadata field names take a dictionary of comparison operators as value. Comparison
operator keys take a single value or (in case of `"$in"`) a list of values as value.
If no logical operator is provided, `"$and"` is used as default operation. If no comparison
operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default
operation.
__Example__:
```python
filters = {
"$and": {
"type": {"$eq": "article"},
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": {"$in": ["economy", "politics"]},
"publisher": {"$eq": "nytimes"}
}
}
}
# or simpler using default operators
filters = {
"type": "article",
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": ["economy", "politics"],
"publisher": "nytimes"
}
}
```
To use the same logical operator multiple times on the same level, logical operators can take a list of dictionaries as value.
__Example__:
```python
filters = {
"$or": [
{
"$and": {
"Type": "News Paper",
"Date": {
"$lt": "2019-01-01"
}
}
},
{
"$and": {
"Type": "Blog Post",
"Date": {
"$gte": "2019-01-01"
}
}
}
]
}
```
:param top_k: How many documents to return per query embedding.
:param index: Index name to query. If None, the DocumentStore's default index (self.index) is used.
:param return_embedding: Whether to return the document embeddings.
:param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
If true (default), similarity scores (e.g. cosine or dot_product) which naturally have a different value range are scaled to [0,1], where 1 means extremely relevant.
Otherwise, raw similarity scores (e.g. cosine or dot_product) are used.
:return: One list of up to top_k Documents per query embedding.
"""
if index is None:
index = self.index
if return_embedding is None:
return_embedding = self.return_embedding
if headers is None:
headers = {}
if not self.embedding_field:
raise DocumentStoreError(f"Please set a valid `embedding_field` for {self.__class__.__name__}")
if isinstance(filters, list):
if len(filters) != len(query_embs):
raise HaystackError(
"Number of filters does not match number of query_embs. Please provide as many filters"
" as query_embs or a single filter that will be applied to each query_emb."
)
else:
filters = [filters] * len(query_embs) if filters is not None else [{}] * len(query_embs)
body = []
for query_emb, cur_filters in zip(query_embs, filters):
cur_query_body = self._construct_dense_query_body(
query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding
)
body.append(headers)
body.append(cur_query_body)
logger.debug("Retriever query: %s", body)
responses = self.client.msearch(index=index, body=body)
all_documents = []
cur_documents = []
for response in responses["responses"]:
cur_result = response["hits"]["hits"]
cur_documents = [
self._convert_es_hit_to_document(
hit, adapt_score_for_embedding=True, return_embedding=return_embedding, scale_score=scale_score
)
for hit in cur_result
]
all_documents.append(cur_documents)
return all_documents
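Implementation note: `query_by_embedding_batch` above goes through Elasticsearch's Multi Search API. For each query embedding it appends a header entry followed by a request entry, so the final `body` holds two items per query (the unit test further down asserts `len(body) == 20` for 10 queries). A minimal sketch of that layout:

```python
# Two entries per embedding: a header dict, then the dense query body.
body = []
for query_emb in query_embs:
    body.append({})  # per-request headers; empty when none are passed
    body.append({"size": 10, "query": {"match_all": {}}})  # stand-in for the dense query body
assert len(body) == 2 * len(query_embs)
```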
@abstractmethod
def _construct_dense_query_body(
self,
query_emb: np.ndarray,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
top_k: int = 10,
return_embedding: Optional[bool] = None,
):
pass
def update_embeddings(
self,
retriever: DenseRetriever,

View File

@@ -467,22 +467,12 @@ class DensePassageRetriever(DenseRetriever):
if scale_score is None:
scale_score = self.scale_score
documents = []
query_embs: List[np.ndarray] = []
for batch in self._get_batches(queries=queries, batch_size=batch_size):
query_embs.extend(self.embed_queries(queries=batch))
for query_emb, cur_filters in tqdm(
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
):
cur_docs = document_store.query_by_embedding(
query_emb=query_emb,
top_k=top_k,
filters=cur_filters,
index=index,
headers=headers,
scale_score=scale_score,
documents = document_store.query_by_embedding_batch(
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
)
documents.append(cur_docs)
return documents
@@ -1111,22 +1101,12 @@ class TableTextRetriever(DenseRetriever):
if scale_score is None:
scale_score = self.scale_score
documents = []
query_embs: List[np.ndarray] = []
for batch in self._get_batches(queries=queries, batch_size=batch_size):
query_embs.extend(self.embed_queries(queries=batch))
for query_emb, cur_filters in tqdm(
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
):
cur_docs = document_store.query_by_embedding(
query_emb=query_emb,
top_k=top_k,
filters=cur_filters,
index=index,
headers=headers,
scale_score=scale_score,
documents = document_store.query_by_embedding_batch(
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
)
documents.append(cur_docs)
return documents
@@ -1823,22 +1803,12 @@ class EmbeddingRetriever(DenseRetriever):
if scale_score is None:
scale_score = self.scale_score
documents = []
query_embs: List[np.ndarray] = []
for batch in self._get_batches(queries=queries, batch_size=batch_size):
query_embs.extend(self.embed_queries(queries=batch))
for query_emb, cur_filters in tqdm(
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
):
cur_docs = document_store.query_by_embedding(
query_emb=query_emb,
top_k=top_k,
filters=cur_filters,
index=index,
headers=headers,
scale_score=scale_score,
documents = document_store.query_by_embedding_batch(
query_embs=query_embs, top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
)
documents.append(cur_docs)
return documents
@@ -2301,9 +2271,8 @@ class MultihopEmbeddingRetriever(EmbeddingRetriever):
for it in range(self.num_iterations):
texts = [self._merge_query_and_context(q, c) for q, c in zip(batch, context_docs)]
query_embs = self.embed_queries(texts)
for idx, emb in enumerate(query_embs):
cur_docs = document_store.query_by_embedding(
query_emb=emb,
cur_docs_batch = document_store.query_by_embedding_batch(
query_embs=query_embs,
top_k=top_k,
filters=cur_filters,
index=index,
@@ -2312,11 +2281,12 @@ class MultihopEmbeddingRetriever(EmbeddingRetriever):
)
if it < self.num_iterations - 1:
# add doc with highest score to context
for idx, cur_docs in enumerate(cur_docs_batch):
if len(cur_docs) > 0:
context_docs[idx].append(cur_docs[0])
else:
# documents in the last iteration are final results
documents.append(cur_docs)
documents.extend(cur_docs_batch)
pb.update(len(batch))
pb.close()
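With this change, dense retrievers still embed their queries in batches, but then issue a single `query_by_embedding_batch` call instead of one store query per embedding. An illustrative call (reusing the hypothetical `document_store` from the first sketch; the model name is borrowed from the tests below):

```python
from haystack.nodes import EmbeddingRetriever

retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
)

# One list of Documents per query.
docs_per_query = retriever.retrieve_batch(
    queries=["Who lives in Paris?", "Who lives in Berlin?"],
    top_k=5,
)
```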

View File

@@ -205,18 +205,15 @@ class MultiModalRetriever(DenseRetriever):
query_embeddings = self.query_embedder.embed(documents=query_docs, batch_size=batch_size)
# Query documents by embedding (the actual retrieval step)
documents = []
for query_embedding, query_filters in zip(query_embeddings, filters_list):
docs = document_store.query_by_embedding(
query_emb=query_embedding,
documents = document_store.query_by_embedding_batch(
query_embs=query_embeddings,
top_k=top_k,
filters=query_filters,
filters=filters_list, # type: ignore
index=index,
headers=headers,
scale_score=scale_score,
)
documents.append(docs)
return documents
def embed_documents(self, docs: List[Document]) -> np.ndarray:

View File

@@ -685,14 +685,12 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline):
:param top_k: How many documents id to return against single document
:param index: Optionally specify the name of index to query the document from. If None, the DocumentStore's default index (self.index) will be used.
"""
similar_documents: list = []
self.document_store.return_embedding = True # type: ignore
for document in self.document_store.get_documents_by_id(ids=document_ids, index=index):
similar_documents.append(
self.document_store.query_by_embedding(
query_emb=document.embedding, filters=filters, return_embedding=False, top_k=top_k, index=index
)
documents = self.document_store.get_documents_by_id(ids=document_ids, index=index)
query_embs = [doc.embedding for doc in documents]
similar_documents = self.document_store.query_by_embedding_batch(
query_embs=query_embs, filters=filters, return_embedding=False, top_k=top_k, index=index
)
self.document_store.return_embedding = False # type: ignore
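Pipeline usage is unchanged; only the internals are batched now. A hedged sketch, assuming the standard run signature (the document IDs are placeholders):

```python
from haystack.pipelines import MostSimilarDocumentsPipeline

pipeline = MostSimilarDocumentsPipeline(document_store=document_store)

# One list of similar documents per input document ID.
similar_docs = pipeline.run(document_ids=["doc_1", "doc_2"], top_k=3)
```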

View File

@@ -91,3 +91,14 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract):
assert "A Foo Document" in docs[0][0].content
assert len(docs[1]) == 5
assert "A Bar Document" in docs[1][0].content
@pytest.mark.integration
def test_memory_query_by_embedding_batch(self, ds, documents):
documents = [doc for doc in documents if doc.embedding is not None]
ds.write_documents(documents)
query_embs = [doc.embedding for doc in documents]
docs_batch = ds.query_by_embedding_batch(query_embs=query_embs, top_k=5)
assert len(docs_batch) == 6
for docs, query_emb in zip(docs_batch, query_embs):
assert len(docs) == 5
assert (docs[0].embedding == query_emb).all()

View File

@@ -178,6 +178,20 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
)
assert len(results) == 3
@pytest.mark.integration
@pytest.mark.parametrize("use_ann", [True, False])
def test_query_embedding_batch_with_filters(self, ds: OpenSearchDocumentStore, documents, use_ann):
ds.embeddings_field_supports_similarity = use_ann
ds.write_documents(documents)
results = ds.query_by_embedding_batch(
query_embs=[np.random.rand(768).astype(np.float32) for _ in range(2)],
filters=[{"year": "2020"} for _ in range(2)],
top_k=10,
)
assert len(results) == 2
for result in results:
assert len(result) == 3
# Unit tests
@pytest.mark.unit
@@ -321,6 +335,13 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]}
@pytest.mark.unit
def test_query_by_embedding_batch_uses_msearch(self, mocked_document_store):
mocked_document_store.query_by_embedding_batch([self.query_emb for _ in range(10)])
# assert that the whole batch is sent as a single msearch request
_, kwargs = mocked_document_store.client.msearch.call_args
assert len(kwargs["body"]) == 20  # each search contributes a header entry and a request entry
@pytest.mark.unit
def test__create_document_index_with_alias(self, mocked_document_store, caplog):
mocked_document_store.client.indices.exists_alias.return_value = True

View File

@@ -809,6 +809,22 @@ def test_multimodal_text_retrieval(text_docs: List[Document]):
assert results[0].content == "My name is Christelle and I live in Paris"
@pytest.mark.integration
def test_multimodal_text_retrieval_batch(text_docs: List[Document]):
retriever = MultiModalRetriever(
document_store=InMemoryDocumentStore(return_embedding=True),
query_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
document_embedding_models={"text": "sentence-transformers/multi-qa-mpnet-base-dot-v1"},
)
retriever.document_store.write_documents(text_docs)
retriever.document_store.update_embeddings(retriever=retriever)
results = retriever.retrieve_batch(queries=["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Madrid?"])
assert results[0][0].content == "My name is Christelle and I live in Paris"
assert results[1][0].content == "My name is Carla and I live in Berlin"
assert results[2][0].content == "My name is Camila and I live in Madrid"
@pytest.mark.integration
def test_multimodal_table_retrieval(table_docs: List[Document]):
retriever = MultiModalRetriever(