From 9ca3ccae987e62743b3cf7fb66cb4671a7390fff Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 23 Sep 2022 15:46:48 +0200 Subject: [PATCH] fix:MostSimilarDocumentsPipeline doesn't have pipeline property (#3265) * Add comments and a unit test * More unit tests for MostSimilarDocumentsPipeline --- haystack/pipelines/standard_pipelines.py | 6 ++++++ test/pipelines/test_pipeline.py | 4 ++-- test/pipelines/test_standard_pipelines.py | 9 +++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/haystack/pipelines/standard_pipelines.py b/haystack/pipelines/standard_pipelines.py index 81d7a5453..6275b0272 100644 --- a/haystack/pipelines/standard_pipelines.py +++ b/haystack/pipelines/standard_pipelines.py @@ -706,6 +706,12 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline): :param document_store: Document Store instance with already stored embeddings. """ + # we create a pipeline and add the document store as a node + # however, we do not want to use the document store's run method, + # but rather the query_by_embedding method + # pipeline property is here so the superclass methods that rely on pipeline property work + self.pipeline = Pipeline() + self.pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Query"]) self.document_store = document_store def run(self, document_ids: List[str], top_k: int = 5): diff --git a/test/pipelines/test_pipeline.py b/test/pipelines/test_pipeline.py index d9027a424..08a930b1b 100644 --- a/test/pipelines/test_pipeline.py +++ b/test/pipelines/test_pipeline.py @@ -813,8 +813,8 @@ def test_pipeline_classify_type(): ) pipe.get_type().startswith("TranslationWrapperPipeline") - # pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore()) - # assert pipe.get_type().startswith("MostSimilarDocumentsPipeline") + pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore()) + assert pipe.get_type().startswith("MostSimilarDocumentsPipeline") 
@pytest.mark.usefixtures(deepset_cloud_fixture.__name__) diff --git a/test/pipelines/test_standard_pipelines.py b/test/pipelines/test_standard_pipelines.py index 57e59665a..21034f22e 100644 --- a/test/pipelines/test_standard_pipelines.py +++ b/test/pipelines/test_standard_pipelines.py @@ -229,6 +229,15 @@ def test_most_similar_documents_pipeline_batch(retriever, document_store): assert isinstance(document.content, str) +@pytest.mark.integration +@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) +def test_most_similar_documents_pipeline_save(tmpdir, document_store_with_docs): + pipeline = MostSimilarDocumentsPipeline(document_store=document_store_with_docs) + path = Path(tmpdir, "most_similar_document_pipeline.yml") + pipeline.save_to_yaml(path) + assert os.path.exists(path) + + @pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True) def test_join_merge_no_weights(document_store_dot_product_with_docs):