From 361cb1d240d31ec2a2d06ad5c6581143e8b26e9a Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Mon, 22 May 2023 12:33:47 +0200
Subject: [PATCH] fix: Remove streaming LLM tracking; they are all streaming
 now (#4944)

* Remove streaming LLM tracking; they are all streaming now

* PR feedback
---
 haystack/agents/base.py  | 15 +++++++--------
 haystack/agents/utils.py |  2 --
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/haystack/agents/base.py b/haystack/agents/base.py
index 538caa02b..960e610a2 100644
--- a/haystack/agents/base.py
+++ b/haystack/agents/base.py
@@ -13,7 +13,7 @@ from haystack.agents.memory import Memory, NoMemory
 from haystack.telemetry import send_event
 from haystack.agents.agent_step import AgentStep
 from haystack.agents.types import Color, AgentTokenStreamingHandler
-from haystack.agents.utils import print_text, STREAMING_CAPABLE_MODELS
+from haystack.agents.utils import print_text
 from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
 from haystack.pipelines import (
     BaseStandardPipeline,
@@ -228,6 +228,7 @@ class Agent:
         prompt_parameters_resolver: Optional[Callable] = None,
         max_steps: int = 8,
         final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
+        streaming: bool = True,
     ):
         """
         Creates an Agent instance.
@@ -248,6 +249,9 @@
             Set it to at least 2, so that the Agent can run a tool once and then infer it knows the final answer.
             The default is 8.
         :param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
+        :param streaming: Whether to use streaming or not. If True, the Agent will stream response tokens from the LLM.
+            If False, the Agent will wait for the LLM to finish generating the response and then process it. The default is
+            True.
         """
         self.max_steps = max_steps
         self.tm = tools_manager or ToolsManager()
@@ -273,12 +277,7 @@ class Agent:
             prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver
         )
         self.final_answer_pattern = final_answer_pattern
-        # Resolve model name to check if it's a streaming model
-        if isinstance(self.prompt_node.model_name_or_path, str):
-            model_name = self.prompt_node.model_name_or_path
-        else:
-            model_name = self.prompt_node.model_name_or_path.model_name_or_path
-        self.add_default_logging_callbacks(streaming=any(m for m in STREAMING_CAPABLE_MODELS if m in model_name))
+        self.add_default_logging_callbacks(streaming=streaming)
         self.hash = None
         self.last_hash = None
         self.update_hash()
@@ -318,7 +317,7 @@ class Agent:
             self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
         else:
             self.callback_manager.on_agent_step += lambda agent_step: print_text(
-                agent_step.prompt_node_response, color=agent_color
+                agent_step.prompt_node_response, end="\n", color=agent_color
             )
 
     def add_tool(self, tool: Tool):
diff --git a/haystack/agents/utils.py b/haystack/agents/utils.py
index 4b59bd160..f86f9b58d 100644
--- a/haystack/agents/utils.py
+++ b/haystack/agents/utils.py
@@ -2,8 +2,6 @@ from typing import Optional
 
 from haystack.agents.types import Color
 
-STREAMING_CAPABLE_MODELS = ["text-davinci-003", "gpt-3.5-turbo", "gpt-35-turbo", "gpt-4"]
-
 
 def print_text(text: str, end="", color: Optional[Color] = None) -> None:
     """
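
A minimal usage sketch of the new flag. The model name, API key, and query below are illustrative assumptions, not part of the patch:

    from haystack.agents import Agent
    from haystack.nodes import PromptNode

    # Any PromptNode-supported model works here, since streaming behavior no
    # longer depends on the removed STREAMING_CAPABLE_MODELS allowlist.
    prompt_node = PromptNode(
        model_name_or_path="gpt-3.5-turbo",  # assumption: an OpenAI model and key
        api_key="YOUR_API_KEY",
        stop_words=["Observation:"],
    )

    # streaming=True (the default) registers a token-level callback that prints
    # tokens as they arrive; streaming=False prints each complete step response
    # (now followed by a newline) instead.
    agent = Agent(prompt_node=prompt_node, streaming=False)

    result = agent.run(query="What is the capital of France?")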