mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-13 20:11:00 +00:00
68 lines
2.7 KiB
Python
68 lines
2.7 KiB
Python
from typing import Any, List, Mapping
|
|
|
|
from agnext.components.memory import ChatMemory
|
|
from agnext.components.models import FunctionExecutionResultMessage
|
|
|
|
from ..types import FunctionCallMessage, Message, TextMessage
|
|
|
|
|
|
class HeadAndTailChatMemory(ChatMemory[Message]):
|
|
"""A chat memory that keeps a view of the first n and last m messages,
|
|
where n is the head size and m is the tail size. The head and tail sizes
|
|
are set at initialization.
|
|
|
|
Args:
|
|
head_size (int): The size of the head.
|
|
tail_size (int): The size of the tail.
|
|
"""
|
|
|
|
def __init__(self, head_size: int, tail_size: int) -> None:
|
|
self._messages: List[Message] = []
|
|
self._head_size = head_size
|
|
self._tail_size = tail_size
|
|
|
|
async def add_message(self, message: Message) -> None:
|
|
"""Add a message to the memory."""
|
|
self._messages.append(message)
|
|
|
|
async def get_messages(self) -> List[Message]:
|
|
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
|
|
head_messages = self._messages[: self._head_size]
|
|
# Handle the last message is a function call message.
|
|
if head_messages and isinstance(head_messages[-1], FunctionCallMessage):
|
|
# Remove the last message from the head.
|
|
head_messages = head_messages[:-1]
|
|
|
|
tail_messages = self._messages[-self._tail_size :]
|
|
# Handle the first message is a function call result message.
|
|
if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage):
|
|
# Remove the first message from the tail.
|
|
tail_messages = tail_messages[1:]
|
|
|
|
num_skipped = len(self._messages) - self._head_size - self._tail_size
|
|
if num_skipped <= 0:
|
|
# If there are not enough messages to fill the head and tail,
|
|
# return all messages.
|
|
return self._messages
|
|
|
|
placeholder_messages = [TextMessage(content=f"Skipped {num_skipped} messages.", source="System")]
|
|
return head_messages + placeholder_messages + tail_messages
|
|
|
|
async def clear(self) -> None:
|
|
"""Clear the message memory."""
|
|
self._messages = []
|
|
|
|
def save_state(self) -> Mapping[str, Any]:
|
|
return {
|
|
"messages": [message for message in self._messages],
|
|
"head_size": self._head_size,
|
|
"tail_size": self._tail_size,
|
|
"placeholder_message": self._placeholder_message,
|
|
}
|
|
|
|
def load_state(self, state: Mapping[str, Any]) -> None:
|
|
self._messages = state["messages"]
|
|
self._head_size = state["head_size"]
|
|
self._tail_size = state["tail_size"]
|
|
self._placeholder_message = state["placeholder_message"]
|