diff --git a/haystack/document_store/memory.py b/haystack/document_store/memory.py index 9814e7888..1ebbf9645 100644 --- a/haystack/document_store/memory.py +++ b/haystack/document_store/memory.py @@ -107,11 +107,6 @@ class InMemoryDocumentStore(BaseDocumentStore): from numpy import dot from numpy.linalg import norm - if filters: - raise NotImplementedError("Setting `filters` is currently not supported in " - "InMemoryDocumentStore.query_by_embedding(). Please remove filters or " - "use a different DocumentStore (e.g. ElasticsearchDocumentStore).") - index = index or self.index if return_embedding is None: return_embedding = self.return_embedding @@ -119,8 +114,9 @@ class InMemoryDocumentStore(BaseDocumentStore): if query_emb is None: return [] + document_to_search = self.get_all_documents(index=index, filters=filters, return_embedding=True) candidate_docs = [] - for idx, doc in self.indexes[index].items(): + for doc in document_to_search: curr_meta = deepcopy(doc.meta) new_document = Document( id=doc.id, @@ -185,7 +181,7 @@ class InMemoryDocumentStore(BaseDocumentStore): """ index = index or self.label_index return len(self.indexes[index].items()) - + def get_all_documents( self, index: Optional[str] = None, diff --git a/test/test_dpr_retriever.py b/test/test_dpr_retriever.py index 685ef191b..9433ed4dc 100644 --- a/test/test_dpr_retriever.py +++ b/test/test_dpr_retriever.py @@ -3,6 +3,7 @@ import time import numpy as np from haystack import Document +from haystack.document_store.faiss import FAISSDocumentStore from haystack.retriever.dense import DensePassageRetriever from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast @@ -62,6 +63,13 @@ def test_dpr_retrieval(document_store, retriever, return_embedding): else: assert res[0].embedding is None + # test filtering + if not isinstance(document_store, FAISSDocumentStore): + res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", filters={"name": ["0", "2"]}) + assert len(res) == 2 + for r in res: + assert r.meta["name"] in ["0", "2"] + @pytest.mark.parametrize("retriever", ["dpr"], indirect=True) @pytest.mark.parametrize("document_store", ["memory"], indirect=True)