mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 04:56:45 +00:00
Support filters for DensePassageRetriever + InMemoryDocumentStore (#754)
This commit is contained in:
parent
35dcf23a4b
commit
aa8a3666c3
@ -107,11 +107,6 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
from numpy import dot
|
||||
from numpy.linalg import norm
|
||||
|
||||
if filters:
|
||||
raise NotImplementedError("Setting `filters` is currently not supported in "
|
||||
"InMemoryDocumentStore.query_by_embedding(). Please remove filters or "
|
||||
"use a different DocumentStore (e.g. ElasticsearchDocumentStore).")
|
||||
|
||||
index = index or self.index
|
||||
if return_embedding is None:
|
||||
return_embedding = self.return_embedding
|
||||
@ -119,8 +114,9 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
if query_emb is None:
|
||||
return []
|
||||
|
||||
document_to_search = self.get_all_documents(index=index, filters=filters, return_embedding=True)
|
||||
candidate_docs = []
|
||||
for idx, doc in self.indexes[index].items():
|
||||
for doc in document_to_search:
|
||||
curr_meta = deepcopy(doc.meta)
|
||||
new_document = Document(
|
||||
id=doc.id,
|
||||
@ -185,7 +181,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
"""
|
||||
index = index or self.label_index
|
||||
return len(self.indexes[index].items())
|
||||
|
||||
|
||||
def get_all_documents(
|
||||
self,
|
||||
index: Optional[str] = None,
|
||||
|
||||
@ -3,6 +3,7 @@ import time
|
||||
import numpy as np
|
||||
|
||||
from haystack import Document
|
||||
from haystack.document_store.faiss import FAISSDocumentStore
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
|
||||
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
|
||||
@ -62,6 +63,13 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
|
||||
else:
|
||||
assert res[0].embedding is None
|
||||
|
||||
# test filtering
|
||||
if not isinstance(document_store, FAISSDocumentStore):
|
||||
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", filters={"name": ["0", "2"]})
|
||||
assert len(res) == 2
|
||||
for r in res:
|
||||
assert r.meta["name"] in ["0", "2"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user