bug: The PromptNode handles all parameters as lists without checking if they are in fact lists (#3820)

This commit is contained in:
Zoltan Fedor 2023-01-10 02:08:17 -05:00 committed by GitHub
parent 897e89c9b1
commit 0288e1be76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 0 deletions

View File

@ -645,6 +645,8 @@ class PromptNode(BaseComponent):
template = Template(prompt_prepared["prompt_template"])
prompt_context_copy = prompt_prepared.copy()
prompt_context_copy.pop("prompt_template")
# the prompt context values should all be lists, as they will be split as one
prompt_context_copy = {k: v if isinstance(v, list) else [v] for k, v in prompt_context_copy.items()}
for prompt_context_values in zip(*prompt_context_copy.values()):
template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())}
template_prepared: str = template.substitute(template_input)

View File

@ -268,6 +268,35 @@ def test_complex_pipeline(prompt_model):
assert "berlin" in result["results"][0].casefold()
@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
def test_complex_pipeline_with_qa(prompt_model):
"""Test the PromptNode where the `query` is a string instead of a list what the PromptNode would expects,
because in a question-answering pipeline the retrievers need `query` as a string, so the PromptNode
need to be able to handle the `query` being a string instead of a list."""
if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
pytest.skip("No API key found for OpenAI, skipping test")
prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
prompt_params=["documents", "query"],
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(
query="Who lives in Berlin?", # this being a string instead of a list what is being tested
documents=[
Document("My name is Carla and I live in Berlin"),
Document("My name is Christelle and I live in Paris"),
],
)
assert len(result["results"]) == 1
assert "carla" in result["results"][0].casefold()
def test_complex_pipeline_with_shared_model():
model = PromptModel()
node = PromptNode(