feat: Add prompt_template to conversational agent init params (#4994)

This commit is contained in:
Julian Risch 2023-05-24 09:22:29 +02:00 committed by GitHub
parent 524d2cba36
commit ae9f384a97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,8 @@
from typing import Optional, Callable
from typing import Optional, Callable, Union
from haystack.agents import Agent
from haystack.agents.memory import Memory, ConversationMemory
from haystack.nodes import PromptNode
from haystack.nodes import PromptNode, PromptTemplate
class ConversationalAgent(Agent):
@ -36,6 +36,7 @@ class ConversationalAgent(Agent):
def __init__(
self,
prompt_node: PromptNode,
prompt_template: Optional[Union[str, PromptTemplate]] = None,
memory: Optional[Memory] = None,
prompt_parameters_resolver: Optional[Callable] = None,
):
@ -43,6 +44,8 @@ class ConversationalAgent(Agent):
Creates a new ConversationalAgent instance
:param prompt_node: A PromptNode used to communicate with LLM.
:param prompt_template: A string or PromptTemplate instance to use as the prompt template. If no prompt_template
is provided, the agent will use the "conversational-agent" template.
:param memory: A memory instance for storing conversation history and other relevant data, defaults to
ConversationMemory.
:param prompt_parameters_resolver: An optional callable for resolving prompt template parameters,
@ -50,9 +53,7 @@ class ConversationalAgent(Agent):
"""
super().__init__(
prompt_node=prompt_node,
prompt_template=prompt_node.default_prompt_template
if prompt_node.default_prompt_template is not None
else "conversational-agent",
prompt_template=prompt_template or "conversational-agent",
max_steps=2,
memory=memory if memory else ConversationMemory(),
prompt_parameters_resolver=prompt_parameters_resolver