fix:MostSimilarDocumentsPipeline doesn't have pipeline property (#3265)

* Add comments and a unit test

* More unit tests for MostSimilarDocumentsPipeline
This commit is contained in:
Vladimir Blagojevic 2022-09-23 15:46:48 +02:00 committed by GitHub
parent eba7cf51b1
commit 9ca3ccae98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 2 deletions

View File

@ -706,6 +706,12 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline):
:param document_store: Document Store instance with already stored embeddings. :param document_store: Document Store instance with already stored embeddings.
""" """
# we create a pipeline and add the document store as a node
# however, we do not want to use the document store's run method,
# but rather the query_by_embedding method
# pipeline property is here so the superclass methods that rely on pipeline property work
self.pipeline = Pipeline()
self.pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Query"])
self.document_store = document_store self.document_store = document_store
def run(self, document_ids: List[str], top_k: int = 5): def run(self, document_ids: List[str], top_k: int = 5):

View File

@ -813,8 +813,8 @@ def test_pipeline_classify_type():
) )
pipe.get_type().startswith("TranslationWrapperPipeline") pipe.get_type().startswith("TranslationWrapperPipeline")
# pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore()) pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore())
# assert pipe.get_type().startswith("MostSimilarDocumentsPipeline") assert pipe.get_type().startswith("MostSimilarDocumentsPipeline")
@pytest.mark.usefixtures(deepset_cloud_fixture.__name__) @pytest.mark.usefixtures(deepset_cloud_fixture.__name__)

View File

@ -229,6 +229,15 @@ def test_most_similar_documents_pipeline_batch(retriever, document_store):
assert isinstance(document.content, str) assert isinstance(document.content, str)
@pytest.mark.integration
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
def test_most_similar_documents_pipeline_save(tmpdir, document_store_with_docs):
pipeline = MostSimilarDocumentsPipeline(document_store=document_store_with_docs)
path = Path(tmpdir, "most_similar_document_pipeline.yml")
pipeline.save_to_yaml(path)
os.path.exists(path)
@pytest.mark.elasticsearch @pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True) @pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
def test_join_merge_no_weights(document_store_dot_product_with_docs): def test_join_merge_no_weights(document_store_dot_product_with_docs):