From ec85207cf7b584efd14de076adaa68ab042a0c1b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 26 Jan 2023 13:38:35 +0100 Subject: [PATCH] Remove __eq__ and __hash__ from PromptNode (#3923) --- haystack/nodes/prompt/prompt_node.py | 10 -------- test/nodes/test_prompt_node.py | 38 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py index 4e5063531..24c2e0c12 100644 --- a/haystack/nodes/prompt/prompt_node.py +++ b/haystack/nodes/prompt/prompt_node.py @@ -877,16 +877,6 @@ class PromptNode(BaseComponent): return list(self.prompt_templates[prompt_template].prompt_params) - def __eq__(self, other): - if isinstance(other, PromptNode): - if self.default_prompt_template != other.default_prompt_template: - return False - return self.model_name_or_path == other.model_name_or_path - return False - - def __hash__(self): - return hash((self.default_prompt_template, self.model_name_or_path)) - def run( self, query: Optional[str] = None, diff --git a/test/nodes/test_prompt_node.py b/test/nodes/test_prompt_node.py index 3d9f97f6d..cde3af9f7 100644 --- a/test/nodes/test_prompt_node.py +++ b/test/nodes/test_prompt_node.py @@ -645,3 +645,41 @@ def test_complex_pipeline_with_all_features(tmp_path): assert len(result["invocation_context"]) > 0 assert len(result["questions"]) > 0 assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0 + + +def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path): + # p2 and p3 are essentially the same PromptNode component, make sure we can use them both as is in the pipeline + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: ignore + components: + - name: p1 + params: + default_prompt_template: question-generation + output_variable: questions + type: PromptNode + - name: p2 + params: + default_prompt_template: question-answering + type: PromptNode + - name: p3 + params: + default_prompt_template: question-answering + type: PromptNode + pipelines: + - name: query + nodes: + - name: p1 + inputs: + - Query + - name: p2 + inputs: + - p1 + - name: p3 + inputs: + - p2 + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + assert pipeline is not None