mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-24 17:30:38 +00:00
feat: Adding filters param to MostSimilarDocumentsPipeline run and run_batch (#3301)
* Adding filters param to MostSimilarDocumentsPipeline run and run_batch * Adding index param to MostSimilarDocumentsPipeline run and run_batch * Adding index param documentation to MostSimilarDocumentsPipeline run and run_batch * Updated index param documentation to MostSimilarDocumentsPipeline run and run_batch. Updated type: ignore in run_batch
This commit is contained in:
parent
b84a6b1716
commit
797c20c966
@ -717,27 +717,43 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline):
|
||||
self.pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Query"])
|
||||
self.document_store = document_store
|
||||
|
||||
def run(
    self,
    document_ids: List[str],
    filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
    top_k: int = 5,
    index: Optional[str] = None,
):
    """
    Return, for each of the given document ids, the most similar documents in the store.

    :param document_ids: document ids
    :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain conditions
    :param top_k: How many documents id to return against single document
    :param index: Optionally specify the name of index to query the document from. If None, the DocumentStore's
                  default index (self.index) will be used.
    """
    similar_documents: list = []
    # Embeddings are needed to run the similarity query, so enable them for the
    # duration of this call. Save the store's current setting and restore it in
    # the `finally` block: the previous code unconditionally reset it to False,
    # which clobbered any pre-existing configuration and leaked the temporary
    # True setting if an exception was raised mid-loop.
    previous_return_embedding = self.document_store.return_embedding  # type: ignore
    self.document_store.return_embedding = True  # type: ignore
    try:
        for document in self.document_store.get_documents_by_id(ids=document_ids, index=index):
            similar_documents.append(
                self.document_store.query_by_embedding(
                    query_emb=document.embedding, filters=filters, return_embedding=False, top_k=top_k, index=index
                )
            )
    finally:
        self.document_store.return_embedding = previous_return_embedding  # type: ignore
    return similar_documents
|
||||
|
||||
def run_batch(  # type: ignore
    self,
    document_ids: List[str],
    filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
    top_k: int = 5,
    index: Optional[str] = None,
):
    """
    Batch variant of :meth:`run`; currently a straight delegation.

    :param document_ids: document ids
    :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain conditions
    :param top_k: How many documents id to return against single document
    :param index: Optionally specify the name of index to query the document from. If None, the DocumentStore's
                  default index (self.index) will be used.
    """
    results = self.run(document_ids=document_ids, filters=filters, top_k=top_k, index=index)
    return results
|
||||
|
@ -200,6 +200,39 @@ def test_most_similar_documents_pipeline(retriever, document_store):
|
||||
assert isinstance(document.content, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "retriever,document_store", [("embedding", "milvus1"), ("embedding", "elasticsearch")], indirect=True
)
def test_most_similar_documents_pipeline_with_filters(retriever, document_store):
    """MostSimilarDocumentsPipeline.run with metadata filters only returns documents matching the filter."""
    corpus = [
        {"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}},
        {"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}},
        {"content": "Sample text for document-3", "meta": {"source": "wiki3"}},
        {"content": "Sample text for document-4", "meta": {"source": "wiki4"}},
        {"content": "Sample text for document-5", "meta": {"source": "wiki5"}},
    ]
    document_store.write_documents(corpus)
    document_store.update_embeddings(retriever)

    query_ids: list = ["a", "b"]
    source_filter = {"source": ["wiki3", "wiki4", "wiki5"]}
    most_similar = MostSimilarDocumentsPipeline(document_store=document_store)
    results = most_similar.run(document_ids=query_ids, filters=source_filter)

    assert len(results[0]) > 1
    assert isinstance(results, list)
    assert len(results) == len(query_ids)

    for hits in results:
        assert isinstance(hits, list)
        for hit in hits:
            assert isinstance(hit, Document)
            assert isinstance(hit.id, str)
            assert isinstance(hit.content, str)
            # every returned document must come from the filtered sources
            assert hit.meta["source"] in ["wiki3", "wiki4", "wiki5"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
|
||||
def test_most_similar_documents_pipeline_batch(retriever, document_store):
|
||||
documents = [
|
||||
@ -229,6 +262,37 @@ def test_most_similar_documents_pipeline_batch(retriever, document_store):
|
||||
assert isinstance(document.content, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
def test_most_similar_documents_pipeline_with_filters_batch(retriever, document_store):
    """MostSimilarDocumentsPipeline.run_batch with metadata filters only returns documents matching the filter."""
    corpus = [
        {"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}},
        {"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}},
        {"content": "Sample text for document-3", "meta": {"source": "wiki3"}},
        {"content": "Sample text for document-4", "meta": {"source": "wiki4"}},
        {"content": "Sample text for document-5", "meta": {"source": "wiki5"}},
    ]
    document_store.write_documents(corpus)
    document_store.update_embeddings(retriever)

    query_ids: list = ["a", "b"]
    source_filter = {"source": ["wiki3", "wiki4", "wiki5"]}
    most_similar = MostSimilarDocumentsPipeline(document_store=document_store)
    results = most_similar.run_batch(document_ids=query_ids, filters=source_filter)

    assert len(results[0]) > 1
    assert isinstance(results, list)
    assert len(results) == len(query_ids)

    for hits in results:
        assert isinstance(hits, list)
        for hit in hits:
            assert isinstance(hit, Document)
            assert isinstance(hit.id, str)
            assert isinstance(hit.content, str)
            # every returned document must come from the filtered sources
            assert hit.meta["source"] in ["wiki3", "wiki4", "wiki5"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
def test_most_similar_documents_pipeline_save(tmpdir, document_store_with_docs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user