Support filters for DensePassageRetriever + InMemoryDocumentStore (#754)

This commit is contained in:
Tanay Soni 2021-01-20 12:52:52 +01:00 committed by GitHub
parent 35dcf23a4b
commit aa8a3666c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 7 deletions

View File

@@ -107,11 +107,6 @@ class InMemoryDocumentStore(BaseDocumentStore):
from numpy import dot
from numpy.linalg import norm
if filters:
raise NotImplementedError("Setting `filters` is currently not supported in "
"InMemoryDocumentStore.query_by_embedding(). Please remove filters or "
"use a different DocumentStore (e.g. ElasticsearchDocumentStore).")
index = index or self.index
if return_embedding is None:
return_embedding = self.return_embedding
@@ -119,8 +114,9 @@ class InMemoryDocumentStore(BaseDocumentStore):
if query_emb is None:
return []
document_to_search = self.get_all_documents(index=index, filters=filters, return_embedding=True)
candidate_docs = []
for idx, doc in self.indexes[index].items():
for doc in document_to_search:
curr_meta = deepcopy(doc.meta)
new_document = Document(
id=doc.id,
@@ -185,7 +181,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
"""
index = index or self.label_index
return len(self.indexes[index].items())
def get_all_documents(
self,
index: Optional[str] = None,

View File

@@ -3,6 +3,7 @@ import time
import numpy as np
from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import DensePassageRetriever
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
@@ -62,6 +63,13 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
else:
assert res[0].embedding is None
# test filtering
if not isinstance(document_store, FAISSDocumentStore):
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", filters={"name": ["0", "2"]})
assert len(res) == 2
for r in res:
assert r.meta["name"] in ["0", "2"]
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)