import collections
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, OrderedDict, Union


class Memory(ABC):
    """
    Abstract base class for memory management in an Agent.
    """

    @abstractmethod
    def load(self, keys: Optional[List[str]] = None, **kwargs) -> Any:
        """
        Load the context of this model run from memory.

        :param keys: Optional list of keys to specify the data to load.
        :return: The loaded data.
        """

    @abstractmethod
    def save(self, data: Dict[str, Any]) -> None:
        """
        Save the context of this model run to memory.

        :param data: A dictionary containing the data to save.
        """

    @abstractmethod
    def clear(self) -> None:
        """
        Clear memory contents.
        """


class ConversationMemory(Memory):
    """
    A memory class that stores conversation history as an ordered list of
    Human/AI snippets and renders them as a plain-text transcript.
    """

    def __init__(self, input_key: str = "input", output_key: str = "output"):
        """
        Initialize ConversationMemory with input and output keys.

        :param input_key: The key under which ``save`` finds the user input.
        :param output_key: The key under which ``save`` finds the model output.
        """
        # One OrderedDict per exchange, with fixed keys "Human" and "AI".
        self.list: List[OrderedDict] = []
        self.input_key = input_key
        self.output_key = output_key

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Load conversation history as a formatted string.

        :param keys: Optional list of keys (ignored in this implementation).
        :param kwargs: Optional keyword arguments
            - window_size: positive integer specifying the number of most
              recent conversation snippets to load. Non-positive values are
              not supported (0 would return the full history).
        :return: A formatted string containing the conversation history.
        """
        window_size = kwargs.get("window_size", None)
        # None means "whole history"; otherwise keep only the newest snippets.
        chat_list = self.list if window_size is None else self.list[-window_size:]

        parts: List[str] = []
        for chat_snippet in chat_list:
            parts.append(f"Human: {chat_snippet['Human']}\n")
            parts.append(f"AI: {chat_snippet['AI']}\n")
        return "".join(parts)

    def save(self, data: Dict[str, Any]) -> None:
        """
        Save a conversation snippet to memory.

        :param data: A dictionary holding the user input under
            ``self.input_key`` and the model output under ``self.output_key``.
        """
        chat_snippet = collections.OrderedDict()
        chat_snippet["Human"] = data[self.input_key]
        chat_snippet["AI"] = data[self.output_key]
        self.list.append(chat_snippet)

    def clear(self) -> None:
        """
        Clear the conversation history.
        """
        self.list = []


class ConversationSummaryMemory(ConversationMemory):
    """
    A memory class that stores conversation history and periodically condenses
    it into a running summary generated by a PromptNode.
    """

    def __init__(
        self,
        prompt_node: "PromptNode",
        prompt_template: Optional[Union[str, "PromptTemplate"]] = None,
        input_key: str = "input",
        output_key: str = "output",
        summary_frequency: int = 3,
    ):
        """
        Initialize ConversationSummaryMemory with a PromptNode, optional
        prompt_template, input and output keys, and a summary_frequency.

        :param prompt_node: A PromptNode object for generating conversation summaries.
        :param prompt_template: Optional prompt template as a string or PromptTemplate
            object; defaults to the node's default template or "conversational-summary".
        :param input_key: input key, default is "input".
        :param output_key: output key, default is "output".
        :param summary_frequency: generate a summary every N saved snippets
            (default is 3). Must be at least 1.
        :raises ValueError: If summary_frequency is smaller than 1 (a value of 0
            would otherwise cause a ZeroDivisionError in needs_summary()).
        """
        super().__init__(input_key, output_key)
        if summary_frequency < 1:
            raise ValueError("summary_frequency must be a positive integer")
        self.save_count = 0
        self.prompt_node = prompt_node

        template = (
            prompt_template
            if prompt_template is not None
            else prompt_node.default_prompt_template or "conversational-summary"
        )
        self.template = prompt_node.get_prompt_template(template)
        self.summary_frequency = summary_frequency
        self.summary = ""

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Load conversation history as a formatted string, including the latest summary.

        :param keys: Optional list of keys (ignored in this implementation).
        :param kwargs: Optional keyword arguments
            - window_size: integer specifying the number of most recent conversation snippets to load.
        :return: A formatted string containing the conversation history with the latest summary.
        """
        if self.has_unsummarized_snippets():
            unsummarized = super().load(keys=keys, window_size=self.unsummarized_snippets())
            return f"{self.summary}\n{unsummarized}"
        else:
            return self.summary

    def summarize(self) -> str:
        """
        Generate a new summary from the existing summary plus the most recent
        conversation snippets.

        Bug fix: the transcript is built with ``super().load`` on purpose.
        The overridden ``load`` returns only ``self.summary`` once every
        snippet has been summarized, which is exactly the state at the moment
        this method runs — so the PromptNode previously never received the
        conversation it was asked to condense.

        :return: A string containing the generated summary.
        """
        recent_snippets = super().load(window_size=self.summary_frequency)
        # Include the previous summary so earlier context is not lost.
        transcript = f"{self.summary}\n{recent_snippets}" if self.summary else recent_snippets
        pn_response = self.prompt_node.prompt(self.template, chat_transcript=transcript)
        return pn_response[0]

    def needs_summary(self) -> bool:
        """
        Determine if a new summary should be generated.

        :return: True if a new summary should be generated, otherwise False.
        """
        return self.save_count % self.summary_frequency == 0

    def unsummarized_snippets(self) -> int:
        """
        Returns how many conversation snippets have not been summarized.

        :return: The number of conversation snippets that have not been summarized.
        """
        return self.save_count % self.summary_frequency

    def has_unsummarized_snippets(self) -> bool:
        """
        Returns True if there are any conversation snippets that have not been summarized.

        :return: True if there are unsummarized snippets, otherwise False.
        """
        return self.unsummarized_snippets() != 0

    def save(self, data: Dict[str, Any]) -> None:
        """
        Save a conversation snippet to memory and update the save count.
        Generate a summary if needed.

        :param data: A dictionary containing the conversation snippet to save.
        """
        super().save(data)
        self.save_count += 1
        if self.needs_summary():
            self.summary = self.summarize()

    def clear(self) -> None:
        """
        Clear the conversation history, the save count, and the summary.
        """
        super().clear()
        self.save_count = 0
        self.summary = ""
from typing import Any, Dict, List, Optional


class NoMemory(Memory):
    """
    A memory class that doesn't store any data: saves are discarded and loads
    always return an empty transcript.
    """

    def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
        """
        Return an empty transcript; nothing is ever stored.

        :param keys: Optional list of keys (ignored in this implementation).
        :return: An empty string.
        """
        # Docstring previously claimed an "empty dictionary" was returned;
        # the method returns an empty string, matching Memory.load usage.
        return ""

    def save(self, data: Dict[str, Any]) -> None:
        """
        Discard the given data; this memory never stores anything.

        :param data: A dictionary containing the data to save (ignored in this implementation).
        """

    def clear(self) -> None:
        """
        No-op; there is nothing to clear.
        """
import pytest
from unittest.mock import MagicMock

from haystack.agents.memory import NoMemory, ConversationMemory, ConversationSummaryMemory
from haystack.nodes import PromptNode, PromptTemplate


@pytest.mark.unit
def test_conversation_memory():
    """A ConversationMemory records snippets and renders them as a transcript."""
    memory = ConversationMemory()
    assert memory.load() == ""

    memory.save({"input": "Hello", "output": "Hi there"})
    assert memory.load() == "Human: Hello\nAI: Hi there\n"

    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "Human: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    assert memory.load(window_size=1) == "Human: How are you?\nAI: I'm doing well, thanks.\n"

    memory.clear()
    assert memory.load() == ""


@pytest.mark.unit
def test_conversation_memory_window_size():
    """window_size limits the transcript to the most recent snippets, even after clearing."""
    memory = ConversationMemory()
    assert memory.load() == ""

    memory.save({"input": "Hello", "output": "Hi there"})
    memory.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert memory.load() == "Human: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    assert memory.load(window_size=1) == "Human: How are you?\nAI: I'm doing well, thanks.\n"

    # clear the memory
    memory.clear()
    assert memory.load() == ""
    assert memory.load(window_size=1) == ""


@pytest.fixture
def mocked_prompt_node():
    """A PromptNode mock whose prompt() answer can be overridden per test."""
    prompt_node = MagicMock(spec=PromptNode)
    prompt_node.default_prompt_template = PromptTemplate(
        "conversational-summary", "Summarize the conversation: {chat_transcript}"
    )
    prompt_node.prompt.return_value = ["This is a summary."]
    return prompt_node


@pytest.mark.unit
def test_conversation_summary_memory(mocked_prompt_node):
    summary = "This is a fake summary definitely."
    mocked_prompt_node.prompt.return_value = [summary]
    summary_mem = ConversationSummaryMemory(mocked_prompt_node)

    # Before the first summary, load() shows the raw snippets.
    summary_mem.save({"input": "Hello", "output": "Hi there"})
    assert summary_mem.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert summary_mem.has_unsummarized_snippets()
    assert summary_mem.unsummarized_snippets() == 1

    summary_mem.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert summary_mem.load() == "\nHuman: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
    assert summary_mem.has_unsummarized_snippets()
    assert summary_mem.unsummarized_snippets() == 2

    # The third save reaches the default summary_frequency and triggers summarization.
    summary_mem.save({"input": "What's the weather like?", "output": "It's sunny outside."})
    assert summary_mem.load() == summary
    assert not summary_mem.has_unsummarized_snippets()
    assert summary_mem.unsummarized_snippets() == 0

    summary_mem.clear()
    assert summary_mem.load() == ""


@pytest.mark.unit
def test_conversation_summary_memory_lower_summary_frequency(mocked_prompt_node):
    summary = "This is a fake summary definitely."
    mocked_prompt_node.prompt.return_value = [summary]
    summary_mem = ConversationSummaryMemory(mocked_prompt_node, summary_frequency=2)

    summary_mem.save({"input": "Hello", "output": "Hi there"})
    assert summary_mem.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert summary_mem.has_unsummarized_snippets()
    assert summary_mem.unsummarized_snippets() == 1

    # The second save triggers summarization at frequency 2.
    summary_mem.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert summary_mem.load() == summary
    assert not summary_mem.has_unsummarized_snippets()
    assert summary_mem.unsummarized_snippets() == 0

    summary_mem.save({"input": "What's the weather like?", "output": "It's sunny outside."})
    assert summary_mem.load() == summary + "\nHuman: What's the weather like?\nAI: It's sunny outside.\n"
    assert summary_mem.has_unsummarized_snippets()
    assert summary_mem.unsummarized_snippets() == 1

    summary_mem.clear()
    assert summary_mem.load() == ""

    # Start over from an empty memory.
    summary_mem.save({"input": "Hello", "output": "Hi there"})
    assert summary_mem.load() == "\nHuman: Hello\nAI: Hi there\n"
    assert summary_mem.has_unsummarized_snippets()
    assert summary_mem.unsummarized_snippets() == 1

    summary_mem.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert summary_mem.load() == summary
    assert not summary_mem.has_unsummarized_snippets()
    assert summary_mem.unsummarized_snippets() == 0


@pytest.mark.unit
def test_conversation_summary_memory_with_template(mocked_prompt_node):
    template = PromptTemplate("conversational-summary", "Summarize the conversation: {chat_transcript}")
    summary_mem = ConversationSummaryMemory(mocked_prompt_node, prompt_template=template)

    summary_mem.save({"input": "Hello", "output": "Hi there"})
    assert summary_mem.load() == "\nHuman: Hello\nAI: Hi there\n"

    summary_mem.save({"input": "How are you?", "output": "I'm doing well, thanks."})
    assert summary_mem.load() == "\nHuman: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"

    summary_mem.save({"input": "What's the weather like?", "output": "It's sunny outside."})
    assert summary_mem.load() == "This is a summary."

    summary_mem.clear()
    assert summary_mem.load() == ""