diff --git a/haystack/agents/conversational.py b/haystack/agents/conversational.py index adb42c980..82c22c639 100644 --- a/haystack/agents/conversational.py +++ b/haystack/agents/conversational.py @@ -1,8 +1,8 @@ -from typing import Optional, Callable +from typing import Optional, Callable, Union from haystack.agents import Agent from haystack.agents.memory import Memory, ConversationMemory -from haystack.nodes import PromptNode +from haystack.nodes import PromptNode, PromptTemplate class ConversationalAgent(Agent): @@ -36,6 +36,7 @@ class ConversationalAgent(Agent): def __init__( self, prompt_node: PromptNode, + prompt_template: Optional[Union[str, PromptTemplate]] = None, memory: Optional[Memory] = None, prompt_parameters_resolver: Optional[Callable] = None, ): @@ -43,6 +44,8 @@ class ConversationalAgent(Agent): Creates a new ConversationalAgent instance :param prompt_node: A PromptNode used to communicate with LLM. + :param prompt_template: A string or PromptTemplate instance to use as the prompt template. If no prompt_template + is provided, the agent will use the "conversational-agent" template. :param memory: A memory instance for storing conversation history and other relevant data, defaults to ConversationMemory. :param prompt_parameters_resolver: An optional callable for resolving prompt template parameters, @@ -50,9 +53,7 @@ class ConversationalAgent(Agent): """ super().__init__( prompt_node=prompt_node, - prompt_template=prompt_node.default_prompt_template - if prompt_node.default_prompt_template is not None - else "conversational-agent", + prompt_template=prompt_template or "conversational-agent", max_steps=2, memory=memory if memory else ConversationMemory(), prompt_parameters_resolver=prompt_parameters_resolver