feat: Add agent memory (#4829)

This commit is contained in:
Vladimir Blagojevic 2023-05-15 18:08:44 +02:00 committed by GitHub
parent d4bbde2d9d
commit 4c9843017c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 397 additions and 4 deletions

View File

@ -0,0 +1,4 @@
from haystack.agents.memory.base import Memory
from haystack.agents.memory.no_memory import NoMemory
from haystack.agents.memory.conversation_memory import ConversationMemory
from haystack.agents.memory.conversation_summary_memory import ConversationSummaryMemory

View File

@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
class Memory(ABC):
    """
    Interface that Agent memory implementations must follow.

    A memory tracks context across Agent runs; concrete subclasses decide how
    that context is stored, retrieved, and discarded.
    """

    @abstractmethod
    def load(self, keys: Optional[List[str]] = None, **kwargs) -> Any:
        """
        Retrieve previously stored context.

        :param keys: Optional list of keys selecting which data to load.
        :return: The loaded data.
        """

    @abstractmethod
    def save(self, data: Dict[str, Any]) -> None:
        """
        Persist the given context for later runs.

        :param data: A dictionary with the values to store.
        """

    @abstractmethod
    def clear(self) -> None:
        """
        Discard everything held in memory.
        """

View File

@ -0,0 +1,60 @@
import collections
from typing import OrderedDict, List, Optional, Any, Dict
from haystack.agents.memory import Memory
class ConversationMemory(Memory):
    """
    Keeps a chronological record of Human/AI exchanges for an Agent.
    """

    def __init__(self, input_key: str = "input", output_key: str = "output"):
        """
        Create an empty conversation memory.

        :param input_key: Key under which the user utterance appears in saved data.
        :param output_key: Key under which the model reply appears in saved data.
        """
        self.list: List[OrderedDict] = []
        self.input_key = input_key
        self.output_key = output_key

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Render the stored conversation as a transcript string.

        :param keys: Ignored by this implementation.
        :param kwargs: Optional keyword arguments
            - window_size: number of most recent snippets to include.
        :return: The formatted conversation transcript.
        """
        window_size = kwargs.get("window_size", None)
        if window_size is None:
            snippets = self.list
        else:
            snippets = self.list[-window_size:]  # pylint: disable=invalid-unary-operand-type
        parts = []
        for snippet in snippets:
            parts.append(f"Human: {snippet['Human']}\n")
            parts.append(f"AI: {snippet['AI']}\n")
        return "".join(parts)

    def save(self, data: Dict[str, Any]) -> None:
        """
        Append one Human/AI exchange to the history.

        :param data: A dictionary holding the user input and the model output.
        """
        snippet = collections.OrderedDict()
        snippet["Human"] = data[self.input_key]
        snippet["AI"] = data[self.output_key]
        self.list.append(snippet)

    def clear(self) -> None:
        """
        Drop the entire conversation history.
        """
        self.list = []

View File

@ -0,0 +1,108 @@
from typing import Optional, Union, Dict, Any, List
from haystack.agents.memory import ConversationMemory
from haystack.nodes import PromptTemplate, PromptNode
class ConversationSummaryMemory(ConversationMemory):
    """
    Conversation memory that periodically condenses the accumulated exchanges
    into a running summary produced by a PromptNode, keeping loaded context short.
    """

    def __init__(
        self,
        prompt_node: PromptNode,
        prompt_template: Optional[Union[str, PromptTemplate]] = None,
        input_key: str = "input",
        output_key: str = "output",
        summary_frequency: int = 3,
    ):
        """
        Initialize ConversationSummaryMemory with a PromptNode, optional prompt_template,
        input and output keys, and a summary_frequency.

        :param prompt_node: A PromptNode object for generating conversation summaries.
        :param prompt_template: Optional prompt template as a string or PromptTemplate object.
        :param input_key: input key, default is "input".
        :param output_key: output key, default is "output".
        :param summary_frequency: integer specifying how often to generate a summary (default is 3).
        """
        super().__init__(input_key, output_key)
        self.save_count = 0
        self.prompt_node = prompt_node

        # Resolution order: explicit argument > node's default template >
        # built-in "conversational-summary" template.
        template = (
            prompt_template
            if prompt_template is not None
            else prompt_node.default_prompt_template or "conversational-summary"
        )
        self.template = prompt_node.get_prompt_template(template)
        self.summary_frequency = summary_frequency
        self.summary = ""

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Load conversation history as a formatted string, including the latest summary.

        :param keys: Optional list of keys (ignored in this implementation).
        :param kwargs: Optional keyword arguments
            - window_size: integer specifying the number of most recent conversation snippets to load.
        :return: A formatted string containing the conversation history with the latest summary.
        """
        if self.has_unsummarized_snippets():
            unsummarized = super().load(keys=keys, window_size=self.unsummarized_snippets())
            return f"{self.summary}\n{unsummarized}"
        return self.summary

    def summarize(self) -> str:
        """
        Generate a summary of the most recent conversation snippets, folding in
        the previous summary (if any) so earlier context is not lost.

        :return: A string containing the generated summary.
        """
        # Read the raw transcript straight from the parent class. At the moment
        # this runs, save_count % summary_frequency == 0, so calling self.load()
        # here would return only the previous summary and silently drop the very
        # snippets that are supposed to be condensed.
        most_recent_chat_snippets = super().load(window_size=self.summary_frequency)
        if self.summary:
            # Prepend the old summary so the new one covers the whole conversation.
            most_recent_chat_snippets = f"{self.summary}\n{most_recent_chat_snippets}"
        pn_response = self.prompt_node.prompt(self.template, chat_transcript=most_recent_chat_snippets)
        return pn_response[0]

    def needs_summary(self) -> bool:
        """
        Determine if a new summary should be generated.

        :return: True if a new summary should be generated, otherwise False.
        """
        return self.save_count % self.summary_frequency == 0

    def unsummarized_snippets(self) -> int:
        """
        Returns how many conversation snippets have not been summarized.

        :return: The number of conversation snippets that have not been summarized.
        """
        return self.save_count % self.summary_frequency

    def has_unsummarized_snippets(self) -> bool:
        """
        Returns True if there are any conversation snippets that have not been summarized.

        :return: True if there are unsummarized snippets, otherwise False.
        """
        return self.unsummarized_snippets() != 0

    def save(self, data: Dict[str, Any]) -> None:
        """
        Save a conversation snippet to memory and update the save count.
        Generate a summary if needed.

        :param data: A dictionary containing the conversation snippet to save.
        """
        super().save(data)
        self.save_count += 1
        if self.needs_summary():
            self.summary = self.summarize()

    def clear(self) -> None:
        """
        Clear the conversation history and the summary.
        """
        super().clear()
        self.save_count = 0
        self.summary = ""

View File

@ -0,0 +1,32 @@
from typing import Optional, List, Any, Dict
from haystack.agents.memory import Memory
class NoMemory(Memory):
    """
    A stateless Memory: nothing is ever retained between calls.
    """

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Return an empty transcript.

        :param keys: Ignored in this implementation.
        :return: Always the empty string.
        """
        return ""

    def save(self, data: Dict[str, Any]) -> None:
        """
        Discard the given data without storing it.

        :param data: Ignored in this implementation.
        """

    def clear(self) -> None:
        """
        No-op; there is nothing to clear.
        """

View File

@ -434,4 +434,8 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
"Question: {query}\n"
"Thought: Let's think step-by-step, I first need to ",
),
PromptTemplate(
name="conversational-summary",
prompt_text="Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:",
),
]

View File

@ -0,0 +1,45 @@
import pytest
from typing import Dict, Any
from haystack.agents.memory import NoMemory, ConversationMemory
@pytest.mark.unit
def test_no_memory():
    """NoMemory loads an empty transcript and save/clear are harmless no-ops."""
    memory = NoMemory()
    assert memory.load() == ""
    memory.save({"key": "value"})
    memory.clear()
@pytest.mark.unit
def test_conversation_memory():
    """ConversationMemory records exchanges and renders them as a transcript."""
    memory = ConversationMemory()
    assert memory.load() == ""

    memory.save({"input": "Hello", "output": "Hi there"})
    assert memory.load() == "Human: Hello\nAI: Hi there\n"

    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "Human: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    # window_size limits the transcript to the most recent exchange.
    assert memory.load(window_size=1) == "Human: How are you?\nAI: I'm doing well, thanks.\n"

    memory.clear()
    assert memory.load() == ""
@pytest.mark.unit
def test_conversation_memory_window_size():
    """window_size restricts load() to the most recent snippets, even after clear()."""
    memory = ConversationMemory()
    assert memory.load() == ""

    memory.save({"input": "Hello", "output": "Hi there"})
    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "Human: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    assert memory.load(window_size=1) == "Human: How are you?\nAI: I'm doing well, thanks.\n"

    # Clearing empties both the full and the windowed views.
    memory.clear()
    assert memory.load() == ""
    assert memory.load(window_size=1) == ""

View File

@ -0,0 +1,109 @@
from unittest.mock import MagicMock
from haystack.nodes import PromptNode, PromptTemplate
import pytest
from typing import Dict, Any
from haystack.agents.memory import ConversationSummaryMemory
@pytest.fixture
def mocked_prompt_node():
    """A PromptNode stand-in with a default summary template and a canned reply."""
    node = MagicMock(spec=PromptNode)
    node.default_prompt_template = PromptTemplate(
        "conversational-summary", "Summarize the conversation: {chat_transcript}"
    )
    node.prompt.return_value = ["This is a summary."]
    return node
@pytest.mark.unit
def test_conversation_summary_memory(mocked_prompt_node):
    """A summary replaces the transcript once summary_frequency saves accumulate."""
    canned_summary = "This is a fake summary definitely."
    mocked_prompt_node.prompt.return_value = [canned_summary]
    memory = ConversationSummaryMemory(mocked_prompt_node)

    # Below the frequency threshold, the raw transcript is returned.
    memory.save({"input": "Hello", "output": "Hi there"})
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 1

    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 2

    # The third save reaches the default frequency and triggers summarization.
    memory.save({"input": "What's the weather like?", "output": "It's sunny outside."})
    assert memory.load() == canned_summary
    assert not memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 0

    memory.clear()
    assert memory.load() == ""
@pytest.mark.unit
def test_conversation_summary_memory_lower_summary_frequency(mocked_prompt_node):
    """With summary_frequency=2, a summary is produced after every second save."""
    canned_summary = "This is a fake summary definitely."
    mocked_prompt_node.prompt.return_value = [canned_summary]
    memory = ConversationSummaryMemory(mocked_prompt_node, summary_frequency=2)

    greeting = {"input": "Hello", "output": "Hi there"}
    memory.save(greeting)
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 1

    # The second save reaches the threshold and triggers summarization.
    followup = {"input": "How are you?", "output": "I'm doing well, thanks."}
    memory.save(followup)
    assert memory.load() == canned_summary
    assert not memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 0

    memory.save({"input": "What's the weather like?", "output": "It's sunny outside."})
    assert memory.load() == canned_summary + "\nHuman: What's the weather like?\nAI: It's sunny outside.\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 1

    memory.clear()
    assert memory.load() == ""

    # After clearing, the save/summarize cycle starts over from scratch.
    memory.save(greeting)
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 1

    memory.save(followup)
    assert memory.load() == canned_summary
    assert not memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 0
@pytest.mark.unit
def test_conversation_summary_memory_with_template(mocked_prompt_node):
    """An explicitly supplied PromptTemplate is accepted and summarization still works."""
    template = PromptTemplate("conversational-summary", "Summarize the conversation: {chat_transcript}")
    memory = ConversationSummaryMemory(mocked_prompt_node, prompt_template=template)

    memory.save({"input": "Hello", "output": "Hi there"})
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\n"

    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"

    # Third save triggers summarization; the mocked node returns the canned text.
    memory.save({"input": "What's the weather like?", "output": "It's sunny outside."})
    assert memory.load() == "This is a summary."

    memory.clear()
    assert memory.load() == ""

View File

@ -28,14 +28,14 @@ def get_api_key(request):
def test_add_and_remove_template():
with patch("haystack.nodes.prompt.prompt_node.PromptModel"):
node = PromptNode()
total_count = 15
# Verifies default
assert len(node.get_prompt_template_names()) == 14
assert len(node.get_prompt_template_names()) == total_count
# Add a fake template
fake_template = PromptTemplate(name="fake-template", prompt_text="Fake prompt")
node.add_prompt_template(fake_template)
assert len(node.get_prompt_template_names()) == 15
assert len(node.get_prompt_template_names()) == total_count + 1
assert "fake-template" in node.get_prompt_template_names()
# Verify that adding the same template throws an exception
@ -47,7 +47,7 @@ def test_add_and_remove_template():
# Verify template is correctly removed
assert node.remove_prompt_template("fake-template")
assert len(node.get_prompt_template_names()) == 14
assert len(node.get_prompt_template_names()) == total_count
assert "fake-template" not in node.get_prompt_template_names()
# Verify that removing the same template throws an exception