Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-25 06:26:05 +00:00).
Commit 4c9843017c — "feat: Add agent memory (#4829)" (parent d4bbde2d9d).
4
haystack/agents/memory/__init__.py
Normal file
4
haystack/agents/memory/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from haystack.agents.memory.base import Memory
|
||||
from haystack.agents.memory.no_memory import NoMemory
|
||||
from haystack.agents.memory.conversation_memory import ConversationMemory
|
||||
from haystack.agents.memory.conversation_summary_memory import ConversationSummaryMemory
|
||||
31
haystack/agents/memory/base.py
Normal file
31
haystack/agents/memory/base.py
Normal file
@ -0,0 +1,31 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
|
||||
class Memory(ABC):
    """
    Abstract interface that all Agent memory backends implement.
    """

    @abstractmethod
    def load(self, keys: Optional[List[str]] = None, **kwargs) -> Any:
        """
        Load the context of this model run from memory.

        :param keys: Optional list of keys selecting which data to load.
        :return: The loaded data.
        """

    @abstractmethod
    def save(self, data: Dict[str, Any]) -> None:
        """
        Save the context of this model run to memory.

        :param data: A dictionary with the data to save.
        """

    @abstractmethod
    def clear(self) -> None:
        """
        Clear memory contents.
        """
60
haystack/agents/memory/conversation_memory.py
Normal file
60
haystack/agents/memory/conversation_memory.py
Normal file
@ -0,0 +1,60 @@
|
||||
import collections
|
||||
from typing import OrderedDict, List, Optional, Any, Dict
|
||||
|
||||
from haystack.agents.memory import Memory
|
||||
|
||||
|
||||
class ConversationMemory(Memory):
    """
    Keeps a running transcript of Human/AI conversation snippets.
    """

    def __init__(self, input_key: str = "input", output_key: str = "output"):
        """
        Initialize ConversationMemory with input and output keys.

        :param input_key: Key under which the user input is read from saved data.
        :param output_key: Key under which the model output is read from saved data.
        """
        self.list: List[OrderedDict] = []
        self.input_key = input_key
        self.output_key = output_key

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Render the stored conversation as a "Human: ... / AI: ..." transcript.

        :param keys: Optional list of keys (ignored in this implementation).
        :param kwargs: Optional keyword arguments
            - window_size: integer specifying the number of most recent conversation snippets to load.
        :return: A formatted string containing the conversation history.
        """
        window_size = kwargs.get("window_size")
        if window_size is None:
            snippets = self.list
        else:
            snippets = self.list[-window_size:]  # pylint: disable=invalid-unary-operand-type

        parts = []
        for snippet in snippets:
            parts.append(f"Human: {snippet['Human']}\n")
            parts.append(f"AI: {snippet['AI']}\n")
        return "".join(parts)

    def save(self, data: Dict[str, Any]) -> None:
        """
        Append one conversation snippet to memory.

        :param data: A dictionary holding the user input and model output under the configured keys.
        """
        snippet = collections.OrderedDict()
        snippet["Human"] = data[self.input_key]
        snippet["AI"] = data[self.output_key]
        self.list.append(snippet)

    def clear(self) -> None:
        """
        Forget the entire conversation history.
        """
        self.list = []
108
haystack/agents/memory/conversation_summary_memory.py
Normal file
108
haystack/agents/memory/conversation_summary_memory.py
Normal file
@ -0,0 +1,108 @@
|
||||
from typing import Optional, Union, Dict, Any, List
|
||||
|
||||
from haystack.agents.memory import ConversationMemory
|
||||
from haystack.nodes import PromptTemplate, PromptNode
|
||||
|
||||
|
||||
class ConversationSummaryMemory(ConversationMemory):
    """
    A memory class that stores conversation history and periodically condenses it
    into a running summary using a PromptNode.
    """

    def __init__(
        self,
        prompt_node: PromptNode,
        prompt_template: Optional[Union[str, PromptTemplate]] = None,
        input_key: str = "input",
        output_key: str = "output",
        summary_frequency: int = 3,
    ):
        """
        Initialize ConversationSummaryMemory with a PromptNode, optional prompt_template,
        input and output keys, and a summary_frequency.

        :param prompt_node: A PromptNode object for generating conversation summaries.
        :param prompt_template: Optional prompt template as a string or PromptTemplate object.
        :param input_key: input key, default is "input".
        :param output_key: output key, default is "output".
        :param summary_frequency: integer specifying how often to generate a summary (default is 3).
        """
        super().__init__(input_key, output_key)
        self.save_count = 0
        self.prompt_node = prompt_node

        # Resolve the template: explicit argument wins, then the node's default
        # template, then the predefined "conversational-summary" template.
        template = (
            prompt_template
            if prompt_template is not None
            else prompt_node.default_prompt_template or "conversational-summary"
        )
        self.template = prompt_node.get_prompt_template(template)
        self.summary_frequency = summary_frequency
        self.summary = ""

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Load conversation history as a formatted string, including the latest summary.

        :param keys: Optional list of keys (ignored in this implementation).
        :param kwargs: Optional keyword arguments
            - window_size: integer specifying the number of most recent conversation snippets to load.
        :return: A formatted string containing the conversation history with the latest summary.
        """
        if self.has_unsummarized_snippets():
            unsummarized = super().load(keys=keys, window_size=self.unsummarized_snippets())
            return f"{self.summary}\n{unsummarized}"
        else:
            return self.summary

    def summarize(self) -> str:
        """
        Generate a new summary covering the current summary and the most recent
        conversation snippets. The raw history is kept; only the summary is updated.

        :return: A string containing the generated summary.
        """
        # BUGFIX: the previous implementation called self.load(window_size=...).
        # At the moment save() triggers summarization, save_count is a multiple of
        # summary_frequency, so has_unsummarized_snippets() is False and self.load()
        # returns only the PREVIOUS summary — the freshly saved snippets were never
        # passed to the summarizer. Read the raw transcript via super().load() and
        # prepend the existing summary so no information is dropped.
        recent_snippets = super().load(window_size=self.summary_frequency)
        transcript = f"{self.summary}\n{recent_snippets}" if self.summary else recent_snippets
        pn_response = self.prompt_node.prompt(self.template, chat_transcript=transcript)
        return pn_response[0]

    def needs_summary(self) -> bool:
        """
        Determine if a new summary should be generated.

        :return: True if a new summary should be generated, otherwise False.
        """
        return self.save_count % self.summary_frequency == 0

    def unsummarized_snippets(self) -> int:
        """
        Returns how many conversation snippets have not been summarized.

        :return: The number of conversation snippets that have not been summarized.
        """
        return self.save_count % self.summary_frequency

    def has_unsummarized_snippets(self) -> bool:
        """
        Returns True if there are any conversation snippets that have not been summarized.

        :return: True if there are unsummarized snippets, otherwise False.
        """
        return self.unsummarized_snippets() != 0

    def save(self, data: Dict[str, Any]) -> None:
        """
        Save a conversation snippet to memory and update the save count.
        Generate a summary if needed.

        :param data: A dictionary containing the conversation snippet to save.
        """
        super().save(data)
        self.save_count += 1
        if self.needs_summary():
            self.summary = self.summarize()

    def clear(self) -> None:
        """
        Clear the conversation history and the summary.
        """
        super().clear()
        self.save_count = 0
        self.summary = ""
32
haystack/agents/memory/no_memory.py
Normal file
32
haystack/agents/memory/no_memory.py
Normal file
@ -0,0 +1,32 @@
|
||||
from typing import Optional, List, Any, Dict
|
||||
|
||||
from haystack.agents.memory import Memory
|
||||
|
||||
|
||||
class NoMemory(Memory):
    """
    A no-op memory: nothing is ever stored and loading yields an empty transcript.
    """

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Return the (always empty) memory contents.

        :param keys: Optional list of keys (ignored in this implementation).
        :return: An empty str.
        """
        return ""

    def save(self, data: Dict[str, Any]) -> None:
        """
        Discard the given data; nothing is stored.

        :param data: A dictionary containing the data to save (ignored in this implementation).
        """

    def clear(self) -> None:
        """
        Nothing to clear; this is a no-op.
        """
@ -434,4 +434,8 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
|
||||
"Question: {query}\n"
|
||||
"Thought: Let's think step-by-step, I first need to ",
|
||||
),
|
||||
PromptTemplate(
|
||||
name="conversational-summary",
|
||||
prompt_text="Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:",
|
||||
),
|
||||
]
|
||||
|
||||
45
test/agents/test_memory.py
Normal file
45
test/agents/test_memory.py
Normal file
@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
from typing import Dict, Any
|
||||
from haystack.agents.memory import NoMemory, ConversationMemory
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_no_memory():
    # NoMemory must report an empty transcript and accept save/clear as no-ops.
    memory = NoMemory()
    assert memory.load() == ""
    memory.save({"key": "value"})
    memory.clear()
||||
|
||||
@pytest.mark.unit
def test_conversation_memory():
    memory = ConversationMemory()
    assert memory.load() == ""

    # One snippet: rendered as a Human/AI pair.
    memory.save({"input": "Hello", "output": "Hi there"})
    assert memory.load() == "Human: Hello\nAI: Hi there\n"

    # Two snippets: full transcript, and window_size=1 shows only the latest.
    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "Human: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    assert memory.load(window_size=1) == "Human: How are you?\nAI: I'm doing well, thanks.\n"

    memory.clear()
    assert memory.load() == ""
|
||||
|
||||
@pytest.mark.unit
def test_conversation_memory_window_size():
    memory = ConversationMemory()
    assert memory.load() == ""

    memory.save({"input": "Hello", "output": "Hi there"})
    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "Human: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    # window_size=1 restricts the transcript to the most recent snippet.
    assert memory.load(window_size=1) == "Human: How are you?\nAI: I'm doing well, thanks.\n"

    # After clearing, both full and windowed loads are empty.
    memory.clear()
    assert memory.load() == ""
    assert memory.load(window_size=1) == ""
109
test/agents/test_summary_memory.py
Normal file
109
test/agents/test_summary_memory.py
Normal file
@ -0,0 +1,109 @@
|
||||
from unittest.mock import MagicMock
|
||||
from haystack.nodes import PromptNode, PromptTemplate
|
||||
import pytest
|
||||
from typing import Dict, Any
|
||||
|
||||
from haystack.agents.memory import ConversationSummaryMemory
|
||||
|
||||
|
||||
@pytest.fixture
def mocked_prompt_node():
    # A PromptNode stand-in whose prompt() always "summarizes" to a fixed sentence.
    prompt_node = MagicMock(spec=PromptNode)
    prompt_node.default_prompt_template = PromptTemplate(
        "conversational-summary", "Summarize the conversation: {chat_transcript}"
    )
    prompt_node.prompt.return_value = ["This is a summary."]
    return prompt_node
|
||||
|
||||
@pytest.mark.unit
def test_conversation_summary_memory(mocked_prompt_node):
    fake_summary = "This is a fake summary definitely."
    mocked_prompt_node.prompt.return_value = [fake_summary]
    memory = ConversationSummaryMemory(mocked_prompt_node)

    # Before the summary frequency is hit, load() shows the raw transcript
    # (prefixed by the empty summary plus a newline).
    memory.save({"input": "Hello", "output": "Hi there"})
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 1

    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 2

    # The third save triggers summarization (default summary_frequency=3).
    memory.save({"input": "What's the weather like?", "output": "It's sunny outside."})
    assert memory.load() == fake_summary
    assert not memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 0

    memory.clear()
    assert memory.load() == ""
|
||||
|
||||
@pytest.mark.unit
def test_conversation_summary_memory_lower_summary_frequency(mocked_prompt_node):
    fake_summary = "This is a fake summary definitely."
    mocked_prompt_node.prompt.return_value = [fake_summary]
    memory = ConversationSummaryMemory(mocked_prompt_node, summary_frequency=2)

    first = {"input": "Hello", "output": "Hi there"}
    second = {"input": "How are you?", "output": "I'm doing well, thanks."}
    third = {"input": "What's the weather like?", "output": "It's sunny outside."}

    memory.save(first)
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 1

    # The second save hits summary_frequency=2 and triggers summarization.
    memory.save(second)
    assert memory.load() == fake_summary
    assert not memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 0

    # A snippet saved after summarization is appended to the summary.
    memory.save(third)
    assert memory.load() == fake_summary + "\nHuman: What's the weather like?\nAI: It's sunny outside.\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 1

    memory.clear()
    assert memory.load() == ""

    # Start over after a clear: the counters behave as on a fresh instance.
    memory.save(first)
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 1

    memory.save(second)
    assert memory.load() == fake_summary
    assert not memory.has_unsummarized_snippets()
    assert memory.unsummarized_snippets() == 0
|
||||
|
||||
@pytest.mark.unit
def test_conversation_summary_memory_with_template(mocked_prompt_node):
    # An explicitly supplied template takes precedence over the node's default.
    template = PromptTemplate("conversational-summary", "Summarize the conversation: {chat_transcript}")
    memory = ConversationSummaryMemory(mocked_prompt_node, prompt_template=template)

    memory.save({"input": "Hello", "output": "Hi there"})
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\n"

    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "\nHuman: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"

    # The third save triggers summarization via the mocked PromptNode.
    memory.save({"input": "What's the weather like?", "output": "It's sunny outside."})
    assert memory.load() == "This is a summary."

    memory.clear()
    assert memory.load() == ""
@ -28,14 +28,14 @@ def get_api_key(request):
|
||||
def test_add_and_remove_template():
|
||||
with patch("haystack.nodes.prompt.prompt_node.PromptModel"):
|
||||
node = PromptNode()
|
||||
|
||||
total_count = 15
|
||||
# Verifies default
|
||||
assert len(node.get_prompt_template_names()) == 14
|
||||
assert len(node.get_prompt_template_names()) == total_count
|
||||
|
||||
# Add a fake template
|
||||
fake_template = PromptTemplate(name="fake-template", prompt_text="Fake prompt")
|
||||
node.add_prompt_template(fake_template)
|
||||
assert len(node.get_prompt_template_names()) == 15
|
||||
assert len(node.get_prompt_template_names()) == total_count + 1
|
||||
assert "fake-template" in node.get_prompt_template_names()
|
||||
|
||||
# Verify that adding the same template throws an exception
|
||||
@ -47,7 +47,7 @@ def test_add_and_remove_template():
|
||||
|
||||
# Verify template is correctly removed
|
||||
assert node.remove_prompt_template("fake-template")
|
||||
assert len(node.get_prompt_template_names()) == 14
|
||||
assert len(node.get_prompt_template_names()) == total_count
|
||||
assert "fake-template" not in node.get_prompt_template_names()
|
||||
|
||||
# Verify that removing the same template throws an exception
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user