fix: Fix Pipeline.run() getting stuck in a loop even though there are components that can run (#7434)

This commit is contained in:
Silvano Cerza 2024-03-28 12:31:36 +01:00 committed by GitHub
parent 6fcb62ae34
commit 6e289698e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 5 deletions

View File

@ -897,12 +897,15 @@ class Pipeline:
and last_waiting_for_input is not None
and before_last_waiting_for_input == last_waiting_for_input
):
# Are we actually stuck or there's a lazy variadic waiting for input?
# This is our last resort, if there's no lazy variadic waiting for input
# Are we actually stuck or there's a lazy variadic or a component that has only default inputs waiting for input?
# This is our last resort, if there's no lazy variadic or component with only default inputs waiting for input
# we're stuck for real and we can't make any progress.
for name, comp in waiting_for_input:
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
has_only_defaults = all(
not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
)
if is_variadic and not comp.__haystack_is_greedy__ or has_only_defaults: # type: ignore[attr-defined]
break
else:
# We're stuck in a loop for real, we can't make any progress.
@ -910,13 +913,13 @@ class Pipeline:
break
if len(waiting_for_input) == 1:
# We have a single component with variadic input waiting for input.
# We have a single component with variadic input or only default inputs waiting for input.
# If we're at this point it means it has been waiting for input for at least 2 iterations.
# This will never run.
# BAIL!
break
# There was a lazy variadic waiting for input, we can run it
# There was a lazy variadic or a component with only default inputs waiting for input, we can run it
waiting_for_input.remove((name, comp))
to_run.append((name, comp))
continue

View File

@ -9,6 +9,7 @@ import pytest
from haystack import Document
from haystack.components.builders import PromptBuilder
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.others import Multiplexer
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.core.component import component
@ -807,3 +808,39 @@ def test_correct_execution_order_of_components_with_only_defaults(spying_tracer)
"Question: What is the capital of France?"
}
}
def test_pipeline_is_not_stuck_with_components_with_only_defaults():
    """Check that a retriever -> prompt builder -> generator -> answer builder
    pipeline runs to completion when only ``query`` is passed to ``run()``.

    Regression test for ``Pipeline.run()`` getting stuck in a loop even though
    there are components that can run.
    """
    # Stand-in generator that always replies "Paris", so the final assertion
    # does not depend on a real LLM.
    FakeGenerator = component_class(
        "FakeGenerator",
        input_types={"prompt": str},
        output_types={"replies": List[str]},
        output={"replies": ["Paris"]},
    )

    document_store = InMemoryDocumentStore()
    document_store.write_documents(
        [
            Document(content="Rome is the capital of Italy"),
            Document(content="Paris is the capital of France"),
        ]
    )

    prompt_template = (
        "Given the following information, answer the question.\n"
        "Context:\n"
        "{% for document in documents %}"
        "    {{ document.content }}\n"
        "{% endfor %}"
        "Question: {{ query }}"
    )

    # Register the components in the same order as they are wired below.
    pipeline = Pipeline()
    components = (
        ("retriever", InMemoryBM25Retriever(document_store=document_store)),
        ("prompt_builder", PromptBuilder(template=prompt_template)),
        ("generator", FakeGenerator()),
        ("answer_builder", AnswerBuilder()),
    )
    for component_name, instance in components:
        pipeline.add_component(component_name, instance)

    # The retriever output fans out to both the prompt builder and the answer builder.
    connections = (
        ("retriever", "prompt_builder.documents"),
        ("prompt_builder.prompt", "generator.prompt"),
        ("generator.replies", "answer_builder.replies"),
        ("retriever.documents", "answer_builder.documents"),
    )
    for sender, receiver in connections:
        pipeline.connect(sender, receiver)

    # Only the bare `query` is supplied; no per-component inputs are given.
    question = "What is the capital of France?"
    result = pipeline.run({"query": question})

    # Exactly one component reports output, and it carries the canned reply.
    assert len(result) == 1
    built_answers = result["answer_builder"]["answers"]
    assert len(built_answers) == 1
    assert built_answers[0].data == "Paris"