Remove __eq__ and __hash__ from PromptNode (#3923)

This commit is contained in:
Vladimir Blagojevic 2023-01-26 13:38:35 +01:00 committed by GitHub
parent addebcd256
commit ec85207cf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 10 deletions

View File

@ -877,16 +877,6 @@ class PromptNode(BaseComponent):
return list(self.prompt_templates[prompt_template].prompt_params) return list(self.prompt_templates[prompt_template].prompt_params)
def __eq__(self, other):
if isinstance(other, PromptNode):
if self.default_prompt_template != other.default_prompt_template:
return False
return self.model_name_or_path == other.model_name_or_path
return False
def __hash__(self):
return hash((self.default_prompt_template, self.model_name_or_path))
def run( def run(
self, self,
query: Optional[str] = None, query: Optional[str] = None,

View File

@ -645,3 +645,41 @@ def test_complex_pipeline_with_all_features(tmp_path):
assert len(result["invocation_context"]) > 0 assert len(result["invocation_context"]) > 0
assert len(result["questions"]) > 0 assert len(result["questions"]) > 0
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0 assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):
# p2 and p3 are essentially the same PromptNode component, make sure we can use them both as is in the pipeline
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: p1
params:
default_prompt_template: question-generation
output_variable: questions
type: PromptNode
- name: p2
params:
default_prompt_template: question-answering
type: PromptNode
- name: p3
params:
default_prompt_template: question-answering
type: PromptNode
pipelines:
- name: query
nodes:
- name: p1
inputs:
- Query
- name: p2
inputs:
- p1
- name: p3
inputs:
- p2
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
assert pipeline is not None