mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-26 14:38:50 +00:00
Add mem0 Memory Implementation (#6510)
<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? These changes are needed to expand AutoGen's memory capabilities with a robust, production-ready integration with Mem0.ai. <!-- Please give a short summary of the change and the problem this solves. --> This PR adds a new memory component for AutoGen that integrates with Mem0.ai, providing a robust memory solution that supports both cloud and local backends. The Mem0Memory class enables agents to store and retrieve information persistently across conversation sessions. ## Key Features - Seamless integration with Mem0.ai memory system - Support for both cloud-based and local storage backends - Robust error handling with detailed logging - Full implementation of AutoGen's Memory interface - Context updating for enhanced agent conversations - Configurable search parameters for memory retrieval ## Related issue number <!-- For example: "Closes #1234" --> ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. --------- Co-authored-by: Victor Dibia <victordibia@microsoft.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Ricky Loynd <riloynd@microsoft.com>
This commit is contained in:
parent
f101469e29
commit
89927ca436
94
.github/workflows/pytest-mem0.yml
vendored
Normal file
94
.github/workflows/pytest-mem0.yml
vendored
Normal file
@ -0,0 +1,94 @@
|
||||
name: Mem0 Memory Tests
|
||||
|
||||
on:
|
||||
# Run on pushes to any branch
|
||||
push:
|
||||
# Also run on pull requests to main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
neo4j:
|
||||
image: neo4j:5.26.6
|
||||
ports:
|
||||
- 7474:7474 # HTTP
|
||||
- 7687:7687 # BOLT
|
||||
env:
|
||||
NEO4J_AUTH: neo4j/password
|
||||
NEO4J_dbms_security_procedures_unrestricted: apoc.*
|
||||
# Add this to ensure Neo4j is ready for connections quickly
|
||||
NEO4J_dbms_memory_pagecache_size: 100M
|
||||
NEO4J_dbms_memory_heap_initial__size: 100M
|
||||
NEO4J_dbms_memory_heap_max__size: 500M
|
||||
# Try a different health check approach
|
||||
options: >-
|
||||
--health-cmd "wget -O /dev/null -q http://localhost:7474 || exit 1"
|
||||
--health-interval 5s
|
||||
--health-timeout 15s
|
||||
--health-retries 10
|
||||
--health-start-period 30s
|
||||
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Wait for Neo4j
|
||||
run: |
|
||||
# Give Neo4j some extra time to start up
|
||||
sleep 10
|
||||
# Try to connect to Neo4j
|
||||
timeout 30s bash -c 'until curl -s http://localhost:7474 > /dev/null; do sleep 1; done'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
# Install core packages first (in the right order)
|
||||
cd python/packages/autogen-core
|
||||
pip install -e .
|
||||
|
||||
cd ../autogen-agentchat
|
||||
pip install -e .
|
||||
|
||||
# Now install autogen-ext with its dependencies
|
||||
cd ../autogen-ext
|
||||
pip install -e ".[dev,mem0,mem0-local]"
|
||||
|
||||
# Install test dependencies
|
||||
pip install pytest pytest-asyncio pytest-cov
|
||||
pip install python-dotenv
|
||||
|
||||
# Install dependencies for complex configuration tests
|
||||
pip install "openai>=1.0.0"
|
||||
pip install deepseek-ai
|
||||
|
||||
# Update test config to match the simplified Neo4j setup
|
||||
- name: Update Neo4j password in tests
|
||||
run: |
|
||||
echo "NEO4J_PASSWORD=password" >> $GITHUB_ENV
|
||||
|
||||
- name: Run tests with coverage
|
||||
# env:
|
||||
# MEM0_API_KEY: ${{ secrets.MEM0_API_KEY }}
|
||||
# SF_API_KEY: ${{ secrets.SF_API_KEY }}
|
||||
# DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }}
|
||||
run: |
|
||||
cd python/packages/autogen-ext
|
||||
pytest --cov=autogen_ext.memory.mem0 tests/memory/test_mem0.py -v --cov-report=xml
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
file: ./python/packages/autogen-ext/coverage.xml
|
||||
name: codecov-mem0
|
||||
fail_ci_if_error: false
|
||||
@ -536,6 +536,98 @@
|
||||
"4. Optimize embedding models for your specific domain\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mem0Memory Example\n",
|
||||
"\n",
|
||||
"`autogen_ext.memory.mem0.Mem0Memory` provides integration with `Mem0.ai`'s memory system. It supports both cloud-based and local backends, offering advanced memory capabilities for agents. The implementation handles proper retrieval and context updating, making it suitable for production environments.\n",
|
||||
"\n",
|
||||
"In the following example, we'll demonstrate how to use `Mem0Memory` to maintain persistent memories across conversations:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_agentchat.agents import AssistantAgent\n",
|
||||
"from autogen_agentchat.ui import Console\n",
|
||||
"from autogen_core.memory import MemoryContent, MemoryMimeType\n",
|
||||
"from autogen_ext.memory.mem0 import Mem0Memory\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"# Initialize Mem0 cloud memory (requires API key)\n",
|
||||
"# For local deployment, use is_cloud=False with appropriate config\n",
|
||||
"mem0_memory = Mem0Memory(\n",
|
||||
" is_cloud=True,\n",
|
||||
" limit=5, # Maximum number of memories to retrieve\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Add user preferences to memory\n",
|
||||
"await mem0_memory.add(\n",
|
||||
" MemoryContent(\n",
|
||||
" content=\"The weather should be in metric units\",\n",
|
||||
" mime_type=MemoryMimeType.TEXT,\n",
|
||||
" metadata={\"category\": \"preferences\", \"type\": \"units\"},\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"await mem0_memory.add(\n",
|
||||
" MemoryContent(\n",
|
||||
" content=\"Meal recipe must be vegan\",\n",
|
||||
" mime_type=MemoryMimeType.TEXT,\n",
|
||||
" metadata={\"category\": \"preferences\", \"type\": \"dietary\"},\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create assistant with mem0 memory\n",
|
||||
"assistant_agent = AssistantAgent(\n",
|
||||
" name=\"assistant_agent\",\n",
|
||||
" model_client=OpenAIChatCompletionClient(\n",
|
||||
" model=\"gpt-4o-2024-08-06\",\n",
|
||||
" ),\n",
|
||||
" tools=[get_weather],\n",
|
||||
" memory=[mem0_memory],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Ask about the weather\n",
|
||||
"stream = assistant_agent.run_stream(task=\"What are my dietary preferences?\")\n",
|
||||
"await Console(stream)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The example above demonstrates how Mem0Memory can be used with an assistant agent. The memory integration ensures that:\n",
|
||||
"\n",
|
||||
"1. All agent interactions are stored in Mem0 for future reference\n",
|
||||
"2. Relevant memories (like user preferences) are automatically retrieved and added to the context\n",
|
||||
"3. The agent can maintain consistent behavior based on stored memories\n",
|
||||
"\n",
|
||||
"Mem0Memory is particularly useful for:\n",
|
||||
"- Long-running agent deployments that need persistent memory\n",
|
||||
"- Applications requiring enhanced privacy controls\n",
|
||||
"- Teams wanting unified memory management across agents\n",
|
||||
"- Use cases needing advanced memory filtering and analytics\n",
|
||||
"\n",
|
||||
"Just like ChromaDBVectorMemory, you can serialize Mem0Memory configurations:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Serialize the memory configuration\n",
|
||||
"config_json = mem0_memory.dump_component().model_dump_json()\n",
|
||||
"print(f\"Memory config JSON: {config_json[:100]}...\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
||||
@ -43,6 +43,12 @@ llama-cpp = [
|
||||
|
||||
graphrag = ["graphrag>=1.0.1"]
|
||||
chromadb = ["chromadb>=1.0.0"]
|
||||
mem0 = ["mem0ai>=0.1.98"]
|
||||
mem0-local = [
|
||||
"mem0ai>=0.1.98",
|
||||
"neo4j>=5.25.0",
|
||||
"chromadb>=1.0.0"
|
||||
]
|
||||
web-surfer = [
|
||||
"autogen-agentchat==0.6.1",
|
||||
"playwright>=1.48.0",
|
||||
|
||||
382
python/packages/autogen-ext/src/autogen_ext/memory/mem0.py
Normal file
382
python/packages/autogen-ext/src/autogen_ext/memory/mem0.py
Normal file
@ -0,0 +1,382 @@
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, TypedDict, cast
|
||||
|
||||
from autogen_core import CancellationToken, Component, ComponentBase
|
||||
from autogen_core.memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult
|
||||
from autogen_core.model_context import ChatCompletionContext
|
||||
from autogen_core.models import SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
try:
|
||||
from mem0 import Memory as Memory0
|
||||
from mem0 import MemoryClient
|
||||
except ImportError as e:
|
||||
raise ImportError("`mem0ai` not installed. Please install it with `pip install mem0ai`") from e
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("chromadb").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
class Mem0MemoryConfig(BaseModel):
|
||||
"""Configuration for Mem0Memory component.
|
||||
|
||||
Attributes:
|
||||
user_id: Optional user ID for memory operations. If not provided, a UUID will be generated.
|
||||
limit: Maximum number of results to return in memory queries.
|
||||
is_cloud: Whether to use cloud Mem0 client (True) or local client (False).
|
||||
api_key: API key for cloud Mem0 client. Required if is_cloud=True.
|
||||
config: Configuration dictionary for local Mem0 client. Required if is_cloud=False.
|
||||
"""
|
||||
|
||||
user_id: Optional[str] = Field(
|
||||
default=None, description="User ID for memory operations. If not provided, a UUID will be generated."
|
||||
)
|
||||
limit: int = Field(default=10, description="Maximum number of results to return in memory queries.")
|
||||
is_cloud: bool = Field(default=True, description="Whether to use cloud Mem0 client (True) or local client (False).")
|
||||
api_key: Optional[str] = Field(
|
||||
default=None, description="API key for cloud Mem0 client. Required if is_cloud=True."
|
||||
)
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
default=None, description="Configuration dictionary for local Mem0 client. Required if is_cloud=False."
|
||||
)
|
||||
|
||||
|
||||
class MemoryResult(TypedDict, total=False):
|
||||
memory: str
|
||||
score: float
|
||||
metadata: Dict[str, Any]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
categories: List[str]
|
||||
|
||||
|
||||
# pyright: reportGeneralTypeIssues=false
|
||||
class Mem0Memory(Memory, Component[Mem0MemoryConfig], ComponentBase[Mem0MemoryConfig]):
|
||||
"""Mem0 memory implementation for AutoGen.
|
||||
|
||||
This component integrates with Mem0.ai's memory system, providing an implementation
|
||||
of AutoGen's Memory interface. It supports both cloud and local backends through the
|
||||
mem0ai Python package.
|
||||
|
||||
The memory component can store and retrieve information that agents need to remember
|
||||
across conversations. It also provides context updating for language models with
|
||||
relevant memories.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Create a cloud Mem0Memory
|
||||
memory = Mem0Memory(is_cloud=True)
|
||||
|
||||
# Add something to memory
|
||||
await memory.add(MemoryContent(content="Important information to remember"))
|
||||
|
||||
# Retrieve memories with a search query
|
||||
results = await memory.query("relevant information")
|
||||
```
|
||||
"""
|
||||
|
||||
component_type = "memory"
|
||||
component_provider_override = "autogen_ext.memory.mem0.Mem0Memory"
|
||||
component_config_schema = Mem0MemoryConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
is_cloud: bool = True,
|
||||
api_key: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize Mem0Memory.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for memory operations. If not provided, a UUID will be generated.
|
||||
limit: Maximum number of results to return in memory queries.
|
||||
is_cloud: Whether to use cloud Mem0 client (True) or local client (False).
|
||||
api_key: API key for cloud Mem0 client. It will read from the environment MEM0_API_KEY if not provided.
|
||||
config: Configuration dictionary for local Mem0 client. Required if is_cloud=False.
|
||||
|
||||
Raises:
|
||||
ValueError: If is_cloud=True and api_key is None, or if is_cloud=False and config is None.
|
||||
"""
|
||||
# Validate parameters
|
||||
if not is_cloud and config is None:
|
||||
raise ValueError("config is required when using local Mem0 client (is_cloud=False)")
|
||||
|
||||
# Initialize instance variables
|
||||
self._user_id = user_id or str(uuid.uuid4())
|
||||
self._limit = limit
|
||||
self._is_cloud = is_cloud
|
||||
self._api_key = api_key
|
||||
self._config = config
|
||||
|
||||
# Initialize client
|
||||
if self._is_cloud:
|
||||
self._client = MemoryClient(api_key=self._api_key)
|
||||
else:
|
||||
assert self._config is not None
|
||||
config_dict = self._config
|
||||
self._client = Memory0.from_config(config_dict=config_dict) # type: ignore
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
"""Get the user ID for memory operations."""
|
||||
return self._user_id
|
||||
|
||||
@property
|
||||
def limit(self) -> int:
|
||||
"""Get the maximum number of results to return in memory queries."""
|
||||
return self._limit
|
||||
|
||||
@property
|
||||
def is_cloud(self) -> bool:
|
||||
"""Check if the Mem0 client is cloud-based."""
|
||||
return self._is_cloud
|
||||
|
||||
@property
|
||||
def config(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get the configuration for the Mem0 client."""
|
||||
return self._config
|
||||
|
||||
async def add(
|
||||
self,
|
||||
content: MemoryContent,
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> None:
|
||||
"""Add content to memory.
|
||||
|
||||
Args:
|
||||
content: The memory content to add.
|
||||
cancellation_token: Optional token to cancel operation.
|
||||
|
||||
Raises:
|
||||
Exception: If there's an error adding content to mem0 memory.
|
||||
"""
|
||||
# Extract content based on mime type
|
||||
if hasattr(content, "content") and hasattr(content, "mime_type"):
|
||||
if content.mime_type in ["text/plain", "text/markdown"]:
|
||||
message = str(content.content)
|
||||
elif content.mime_type == "application/json":
|
||||
# Convert JSON content to string representation
|
||||
if isinstance(content.content, str):
|
||||
message = content.content
|
||||
else:
|
||||
# Convert dict or other JSON serializable objects to string
|
||||
import json
|
||||
|
||||
message = json.dumps(content.content)
|
||||
else:
|
||||
message = str(content.content)
|
||||
|
||||
# Extract metadata
|
||||
metadata = content.metadata or {}
|
||||
else:
|
||||
# Handle case where content is directly provided as string
|
||||
message = str(content)
|
||||
metadata = {}
|
||||
|
||||
# Check if operation is cancelled
|
||||
if cancellation_token is not None and cancellation_token.cancelled: # type: ignore
|
||||
return
|
||||
|
||||
# Add to mem0 client
|
||||
try:
|
||||
user_id = metadata.pop("user_id", self._user_id)
|
||||
# Suppress warning messages from mem0 MemoryClient
|
||||
kwargs = {} if self._client.__class__.__name__ == "Memory" else {"output_format": "v1.1"}
|
||||
with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
|
||||
self._client.add(message, user_id=user_id, metadata=metadata, **kwargs) # type: ignore
|
||||
except Exception as e:
|
||||
# Log the error but don't crash
|
||||
logger.error(f"Error adding to mem0 memory: {str(e)}")
|
||||
raise
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str | MemoryContent = "",
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
**kwargs: Any,
|
||||
) -> MemoryQueryResult:
|
||||
"""Query memory for relevant content.
|
||||
|
||||
Args:
|
||||
query: The query to search for, either as string or MemoryContent.
|
||||
cancellation_token: Optional token to cancel operation.
|
||||
**kwargs: Additional query parameters to pass to mem0.
|
||||
|
||||
Returns:
|
||||
MemoryQueryResult containing search results.
|
||||
"""
|
||||
# Extract query text
|
||||
if isinstance(query, str):
|
||||
query_text = query
|
||||
elif hasattr(query, "content"):
|
||||
query_text = str(query.content)
|
||||
else:
|
||||
query_text = str(query)
|
||||
|
||||
# Check if operation is cancelled
|
||||
if (
|
||||
cancellation_token
|
||||
and hasattr(cancellation_token, "cancelled")
|
||||
and getattr(cancellation_token, "cancelled", False)
|
||||
):
|
||||
return MemoryQueryResult(results=[])
|
||||
|
||||
try:
|
||||
limit = kwargs.pop("limit", self._limit)
|
||||
with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
|
||||
# Query mem0 client
|
||||
results = self._client.search( # type: ignore
|
||||
query_text,
|
||||
user_id=self._user_id,
|
||||
limit=limit,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Type-safe handling of results
|
||||
if isinstance(results, dict) and "results" in results:
|
||||
result_list = cast(List[MemoryResult], results["results"])
|
||||
else:
|
||||
result_list = cast(List[MemoryResult], results)
|
||||
|
||||
# Convert results to MemoryContent objects
|
||||
memory_contents: List[MemoryContent] = []
|
||||
for result in result_list:
|
||||
content_text = result.get("memory", "")
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
if "metadata" in result and result["metadata"]:
|
||||
metadata = result["metadata"]
|
||||
|
||||
# Add relevant fields to metadata
|
||||
if "score" in result:
|
||||
metadata["score"] = result["score"]
|
||||
|
||||
# For created_at
|
||||
if "created_at" in result and result.get("created_at"):
|
||||
try:
|
||||
metadata["created_at"] = datetime.fromisoformat(result["created_at"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# For updated_at
|
||||
if "updated_at" in result and result.get("updated_at"):
|
||||
try:
|
||||
metadata["updated_at"] = datetime.fromisoformat(result["updated_at"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# For categories
|
||||
if "categories" in result and result.get("categories"):
|
||||
metadata["categories"] = result["categories"]
|
||||
|
||||
# Create MemoryContent object
|
||||
memory_content = MemoryContent(
|
||||
content=content_text,
|
||||
mime_type="text/plain", # Default to text/plain
|
||||
metadata=metadata,
|
||||
)
|
||||
memory_contents.append(memory_content)
|
||||
|
||||
return MemoryQueryResult(results=memory_contents)
|
||||
|
||||
except Exception as e:
|
||||
# Log the error but return empty results
|
||||
logger.error(f"Error querying mem0 memory: {str(e)}")
|
||||
return MemoryQueryResult(results=[])
|
||||
|
||||
async def update_context(
|
||||
self,
|
||||
model_context: ChatCompletionContext,
|
||||
) -> UpdateContextResult:
|
||||
"""Update the model context with relevant memories.
|
||||
|
||||
This method retrieves the conversation history from the model context,
|
||||
uses the last message as a query to find relevant memories, and then
|
||||
adds those memories to the context as a system message.
|
||||
|
||||
Args:
|
||||
model_context: The model context to update.
|
||||
|
||||
Returns:
|
||||
UpdateContextResult containing memories added to the context.
|
||||
"""
|
||||
# Get messages from context
|
||||
messages = await model_context.get_messages()
|
||||
if not messages:
|
||||
return UpdateContextResult(memories=MemoryQueryResult(results=[]))
|
||||
|
||||
# Use the last message as query
|
||||
last_message = messages[-1]
|
||||
query_text = last_message.content if isinstance(last_message.content, str) else str(last_message)
|
||||
|
||||
# Query memory
|
||||
query_results = await self.query(query_text, limit=self._limit)
|
||||
|
||||
# If we have results, add them to the context
|
||||
if query_results.results:
|
||||
# Format memories as numbered list
|
||||
memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)]
|
||||
memory_context = "\nRelevant memories:\n" + "\n".join(memory_strings)
|
||||
|
||||
# Add as system message
|
||||
await model_context.add_message(SystemMessage(content=memory_context))
|
||||
|
||||
return UpdateContextResult(memories=query_results)
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all content from memory for the current user.
|
||||
|
||||
Raises:
|
||||
Exception: If there's an error clearing mem0 memory.
|
||||
"""
|
||||
try:
|
||||
self._client.delete_all(user_id=self._user_id) # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing mem0 memory: {str(e)}")
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources if needed.
|
||||
|
||||
This is a no-op for Mem0 clients as they don't require explicit cleanup.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: Mem0MemoryConfig) -> Self:
|
||||
"""Create instance from configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration for Mem0Memory component.
|
||||
|
||||
Returns:
|
||||
A new Mem0Memory instance.
|
||||
"""
|
||||
return cls(
|
||||
user_id=config.user_id,
|
||||
limit=config.limit,
|
||||
is_cloud=config.is_cloud,
|
||||
api_key=config.api_key,
|
||||
config=config.config,
|
||||
)
|
||||
|
||||
def _to_config(self) -> Mem0MemoryConfig:
|
||||
"""Convert instance to configuration.
|
||||
|
||||
Returns:
|
||||
Configuration representing this Mem0Memory instance.
|
||||
"""
|
||||
return Mem0MemoryConfig(
|
||||
user_id=self._user_id,
|
||||
limit=self._limit,
|
||||
is_cloud=self._is_cloud,
|
||||
api_key=self._api_key,
|
||||
config=self._config,
|
||||
)
|
||||
522
python/packages/autogen-ext/tests/memory/test_mem0.py
Normal file
522
python/packages/autogen-ext/tests/memory/test_mem0.py
Normal file
@ -0,0 +1,522 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from autogen_core.memory import MemoryContent, MemoryMimeType
|
||||
from autogen_core.model_context import BufferedChatCompletionContext
|
||||
from autogen_core.models import SystemMessage, UserMessage
|
||||
from autogen_ext.memory.mem0 import Mem0Memory, Mem0MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Skip tests if required environment variables are not set
|
||||
mem0_api_key = os.environ.get("MEM0_API_KEY")
|
||||
requires_mem0_api = pytest.mark.skipif(mem0_api_key is None, reason="MEM0_API_KEY environment variable not set")
|
||||
|
||||
# Skip tests if mem0ai is not installed
|
||||
mem0 = pytest.importorskip("mem0")
|
||||
|
||||
# Define local configuration at the top of the module
|
||||
FULL_LOCAL_CONFIG: Dict[str, Any] = {
|
||||
"history_db_path": ":memory:", # Use in-memory DB for tests
|
||||
"graph_store": {
|
||||
"provider": "mock_graph",
|
||||
"config": {"url": "mock://localhost:7687", "username": "mock", "password": "mock_password"},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "mock_embedder",
|
||||
"config": {
|
||||
"model": "mock-embedding-model",
|
||||
"embedding_dims": 1024,
|
||||
"api_key": "mock-api-key",
|
||||
},
|
||||
},
|
||||
"vector_store": {"provider": "mock_vector", "config": {"path": ":memory:", "collection_name": "test_memories"}},
|
||||
"llm": {
|
||||
"provider": "mock_llm",
|
||||
"config": {
|
||||
"model": "mock-chat-model",
|
||||
"api_key": "mock-api-key",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def full_local_config() -> Dict[str, Any]:
|
||||
"""Return the local configuration for testing."""
|
||||
return FULL_LOCAL_CONFIG
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cloud_config() -> Mem0MemoryConfig:
|
||||
"""Create cloud configuration with real API key."""
|
||||
api_key = os.environ.get("MEM0_API_KEY")
|
||||
return Mem0MemoryConfig(user_id="test-user", limit=3, is_cloud=True, api_key=api_key)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_config() -> Mem0MemoryConfig:
|
||||
"""Create local configuration for testing."""
|
||||
return Mem0MemoryConfig(user_id="test-user", limit=3, is_cloud=False, config={"path": ":memory:"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("autogen_ext.memory.mem0.Memory0")
|
||||
async def test_basic_workflow(mock_mem0_class: MagicMock, local_config: Mem0MemoryConfig) -> None:
|
||||
"""Test basic memory operations."""
|
||||
# Setup mock
|
||||
mock_mem0 = MagicMock()
|
||||
mock_mem0_class.from_config.return_value = mock_mem0
|
||||
|
||||
# Mock search results
|
||||
mock_mem0.search.return_value = [
|
||||
{
|
||||
"memory": "Paris is known for the Eiffel Tower and amazing cuisine.",
|
||||
"score": 0.95,
|
||||
"metadata": {"category": "city", "country": "France"},
|
||||
}
|
||||
]
|
||||
|
||||
memory = Mem0Memory(
|
||||
user_id=local_config.user_id,
|
||||
limit=local_config.limit,
|
||||
is_cloud=local_config.is_cloud,
|
||||
api_key=local_config.api_key,
|
||||
config=local_config.config,
|
||||
)
|
||||
|
||||
# Add content to memory
|
||||
await memory.add(
|
||||
MemoryContent(
|
||||
content="Paris is known for the Eiffel Tower and amazing cuisine.",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"category": "city", "country": "France"},
|
||||
)
|
||||
)
|
||||
|
||||
# Verify add was called correctly
|
||||
mock_mem0.add.assert_called_once()
|
||||
call_args = mock_mem0.add.call_args[0]
|
||||
assert call_args[0] == "Paris is known for the Eiffel Tower and amazing cuisine."
|
||||
call_kwargs = mock_mem0.add.call_args[1]
|
||||
assert call_kwargs["metadata"] == {"category": "city", "country": "France"}
|
||||
|
||||
# Query memory
|
||||
results = await memory.query("Tell me about Paris")
|
||||
|
||||
# Verify search was called correctly
|
||||
mock_mem0.search.assert_called_once()
|
||||
search_args = mock_mem0.search.call_args
|
||||
assert search_args[0][0] == "Tell me about Paris"
|
||||
assert search_args[1]["user_id"] == "test-user"
|
||||
assert search_args[1]["limit"] == 3
|
||||
|
||||
# Verify results
|
||||
assert len(results.results) == 1
|
||||
assert "Paris" in str(results.results[0].content)
|
||||
res_metadata = results.results[0].metadata
|
||||
assert res_metadata is not None and res_metadata.get("score") == 0.95
|
||||
assert res_metadata is not None and res_metadata.get("country") == "France"
|
||||
|
||||
# Test clear (only do this once)
|
||||
await memory.clear()
|
||||
mock_mem0.delete_all.assert_called_once_with(user_id="test-user")
|
||||
|
||||
# Cleanup
|
||||
await memory.close()
|
||||
|
||||
|
||||
@requires_mem0_api
|
||||
@pytest.mark.asyncio
|
||||
@patch("autogen_ext.memory.mem0.MemoryClient") # Patch MemoryClient instead of Memory0
|
||||
async def test_basic_workflow_with_cloud(mock_memory_client_class: MagicMock, cloud_config: Mem0MemoryConfig) -> None:
|
||||
"""Test basic memory operations with cloud client (mocked instead of real API)."""
|
||||
# Setup mock
|
||||
mock_client = MagicMock()
|
||||
mock_memory_client_class.return_value = mock_client
|
||||
|
||||
# Mock search results
|
||||
mock_client.search.return_value = [
|
||||
{
|
||||
"memory": "Test memory content for cloud",
|
||||
"score": 0.98,
|
||||
"metadata": {"test": True, "source": "cloud"},
|
||||
}
|
||||
]
|
||||
|
||||
memory = Mem0Memory(
|
||||
user_id=cloud_config.user_id,
|
||||
limit=cloud_config.limit,
|
||||
is_cloud=cloud_config.is_cloud,
|
||||
api_key=cloud_config.api_key,
|
||||
config=cloud_config.config,
|
||||
)
|
||||
|
||||
# Generate a unique test content string
|
||||
test_content = f"Test memory content {uuid.uuid4()}"
|
||||
|
||||
# Add content to memory
|
||||
await memory.add(
|
||||
MemoryContent(
|
||||
content=test_content,
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"test": True, "timestamp": datetime.now().isoformat()},
|
||||
)
|
||||
)
|
||||
|
||||
# Verify add was called correctly
|
||||
mock_client.add.assert_called_once()
|
||||
call_args = mock_client.add.call_args
|
||||
assert test_content in str(call_args[0][0]) # Check that content was passed
|
||||
assert call_args[1]["user_id"] == cloud_config.user_id
|
||||
assert call_args[1]["metadata"]["test"] is True
|
||||
|
||||
# Query memory
|
||||
results = await memory.query(test_content)
|
||||
|
||||
# Verify search was called correctly
|
||||
mock_client.search.assert_called_once()
|
||||
search_args = mock_client.search.call_args
|
||||
assert test_content in search_args[0][0]
|
||||
assert search_args[1]["user_id"] == cloud_config.user_id
|
||||
|
||||
# Verify results
|
||||
assert len(results.results) == 1
|
||||
assert "Test memory content for cloud" in str(results.results[0].content)
|
||||
assert results.results[0].metadata is not None
|
||||
assert results.results[0].metadata.get("score") == 0.98
|
||||
|
||||
# Test clear
|
||||
await memory.clear()
|
||||
mock_client.delete_all.assert_called_once_with(user_id=cloud_config.user_id)
|
||||
|
||||
# Cleanup
|
||||
await memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("autogen_ext.memory.mem0.Memory0")
|
||||
async def test_metadata_handling(mock_mem0_class: MagicMock, local_config: Mem0MemoryConfig) -> None:
|
||||
"""Test metadata handling."""
|
||||
# Setup mock
|
||||
mock_mem0 = MagicMock()
|
||||
mock_mem0_class.from_config.return_value = mock_mem0
|
||||
|
||||
# Setup mock search results with rich metadata
|
||||
mock_mem0.search.return_value = [
|
||||
{
|
||||
"memory": "Test content with metadata",
|
||||
"score": 0.9,
|
||||
"metadata": {"test_category": "test", "test_priority": 1, "test_weight": 0.5, "test_verified": True},
|
||||
"created_at": "2023-01-01T12:00:00",
|
||||
"updated_at": "2023-01-02T12:00:00",
|
||||
"categories": ["test", "example"],
|
||||
}
|
||||
]
|
||||
|
||||
memory = Mem0Memory(
|
||||
user_id=local_config.user_id,
|
||||
limit=local_config.limit,
|
||||
is_cloud=local_config.is_cloud,
|
||||
api_key=local_config.api_key,
|
||||
config=local_config.config,
|
||||
)
|
||||
|
||||
# Add content with metadata
|
||||
test_content = "Test content with specific metadata"
|
||||
content = MemoryContent(
|
||||
content=test_content,
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"test_category": "test", "test_priority": 1, "test_weight": 0.5, "test_verified": True},
|
||||
)
|
||||
await memory.add(content)
|
||||
|
||||
# Verify metadata was passed correctly
|
||||
add_kwargs = mock_mem0.add.call_args[1]
|
||||
assert add_kwargs["metadata"]["test_category"] == "test"
|
||||
assert add_kwargs["metadata"]["test_priority"] == 1
|
||||
assert add_kwargs["metadata"]["test_weight"] == 0.5
|
||||
assert add_kwargs["metadata"]["test_verified"] is True
|
||||
|
||||
# Query and check returned metadata
|
||||
results = await memory.query(test_content)
|
||||
assert len(results.results) == 1
|
||||
result = results.results[0]
|
||||
|
||||
# Verify metadata in results
|
||||
assert result.metadata is not None and result.metadata.get("test_category") == "test"
|
||||
assert result.metadata is not None and result.metadata.get("test_priority") == 1
|
||||
assert result.metadata is not None and isinstance(result.metadata.get("test_weight"), float)
|
||||
assert result.metadata is not None and result.metadata.get("test_verified") is True
|
||||
assert result.metadata is not None and "created_at" in result.metadata
|
||||
assert result.metadata is not None and "updated_at" in result.metadata
|
||||
assert result.metadata is not None and result.metadata.get("categories") == ["test", "example"]
|
||||
|
||||
# Cleanup
|
||||
await memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("autogen_ext.memory.mem0.Memory0")
|
||||
async def test_update_context(mock_mem0_class: MagicMock, local_config: Mem0MemoryConfig) -> None:
|
||||
"""Test updating model context with retrieved memories."""
|
||||
# Setup mock
|
||||
mock_mem0 = MagicMock()
|
||||
mock_mem0_class.from_config.return_value = mock_mem0
|
||||
|
||||
# Setup mock search results
|
||||
mock_mem0.search.return_value = [
|
||||
{"memory": "Jupiter is the largest planet in our solar system.", "score": 0.9},
|
||||
{"memory": "Jupiter has many moons, including Ganymede, Europa, and Io.", "score": 0.8},
|
||||
]
|
||||
|
||||
memory = Mem0Memory(
|
||||
user_id=local_config.user_id,
|
||||
limit=local_config.limit,
|
||||
is_cloud=local_config.is_cloud,
|
||||
api_key=local_config.api_key,
|
||||
config=local_config.config,
|
||||
)
|
||||
|
||||
# Create a model context with a message
|
||||
context = BufferedChatCompletionContext(buffer_size=5)
|
||||
await context.add_message(UserMessage(content="Tell me about Jupiter", source="user"))
|
||||
|
||||
# Update context with memory
|
||||
result = await memory.update_context(context)
|
||||
|
||||
# Verify results
|
||||
assert len(result.memories.results) == 2
|
||||
assert "Jupiter" in str(result.memories.results[0].content)
|
||||
|
||||
# Verify search was called with correct query
|
||||
mock_mem0.search.assert_called_once()
|
||||
search_args = mock_mem0.search.call_args
|
||||
assert "Jupiter" in search_args[0][0]
|
||||
|
||||
# Verify context was updated with a system message
|
||||
messages = await context.get_messages()
|
||||
assert len(messages) == 2 # Original message + system message with memories
|
||||
|
||||
# Verify system message content
|
||||
system_message = messages[1]
|
||||
assert isinstance(system_message, SystemMessage)
|
||||
assert "Jupiter is the largest planet" in system_message.content
|
||||
assert "Jupiter has many moons" in system_message.content
|
||||
|
||||
# Cleanup
|
||||
await memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("autogen_ext.memory.mem0.MemoryClient") # Patch for cloud mode
|
||||
async def test_component_serialization(mock_memory_client_class: MagicMock) -> None:
|
||||
"""Test serialization and deserialization of the component."""
|
||||
# Setup mock
|
||||
mock_client = MagicMock()
|
||||
mock_memory_client_class.return_value = mock_client
|
||||
|
||||
# Create configuration
|
||||
user_id = str(uuid.uuid4())
|
||||
config = Mem0MemoryConfig(
|
||||
user_id=user_id,
|
||||
limit=5,
|
||||
is_cloud=True,
|
||||
)
|
||||
|
||||
# Create memory instance
|
||||
memory = Mem0Memory(
|
||||
user_id=config.user_id,
|
||||
limit=config.limit,
|
||||
is_cloud=config.is_cloud,
|
||||
api_key=config.api_key,
|
||||
config=config.config,
|
||||
)
|
||||
|
||||
# Dump config
|
||||
memory_config = memory.dump_component()
|
||||
|
||||
# Verify dumped config
|
||||
assert memory_config.config["user_id"] == user_id
|
||||
assert memory_config.config["limit"] == 5
|
||||
assert memory_config.config["is_cloud"] is True
|
||||
|
||||
# Load from config
|
||||
loaded_memory = Mem0Memory(
|
||||
user_id=config.user_id,
|
||||
limit=config.limit,
|
||||
is_cloud=config.is_cloud,
|
||||
api_key=config.api_key,
|
||||
config=config.config,
|
||||
)
|
||||
|
||||
# Verify loaded instance
|
||||
assert isinstance(loaded_memory, Mem0Memory)
|
||||
assert loaded_memory.user_id == user_id
|
||||
assert loaded_memory.limit == 5
|
||||
assert loaded_memory.is_cloud is True
|
||||
assert loaded_memory.config is None
|
||||
|
||||
# Cleanup
|
||||
await memory.close()
|
||||
await loaded_memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("autogen_ext.memory.mem0.Memory0")
|
||||
async def test_result_format_handling(mock_mem0_class: MagicMock, local_config: Mem0MemoryConfig) -> None:
|
||||
"""Test handling of different result formats."""
|
||||
# Setup mock
|
||||
mock_mem0 = MagicMock()
|
||||
mock_mem0_class.from_config.return_value = mock_mem0
|
||||
|
||||
# Test dictionary format with "results" key
|
||||
mock_mem0.search.return_value = {
|
||||
"results": [
|
||||
{"memory": "Dictionary format result 1", "score": 0.9},
|
||||
{"memory": "Dictionary format result 2", "score": 0.8},
|
||||
]
|
||||
}
|
||||
|
||||
memory = Mem0Memory(
|
||||
user_id=local_config.user_id,
|
||||
limit=local_config.limit,
|
||||
is_cloud=local_config.is_cloud,
|
||||
api_key=local_config.api_key,
|
||||
config=local_config.config,
|
||||
)
|
||||
|
||||
# Query with dictionary format
|
||||
results_dict = await memory.query("test query")
|
||||
|
||||
# Verify results were extracted correctly
|
||||
assert len(results_dict.results) == 2
|
||||
assert "Dictionary format result 1" in str(results_dict.results[0].content)
|
||||
|
||||
# Test list format
|
||||
mock_mem0.search.return_value = [
|
||||
{"memory": "List format result 1", "score": 0.9},
|
||||
{"memory": "List format result 2", "score": 0.8},
|
||||
]
|
||||
|
||||
# Query with list format
|
||||
results_list = await memory.query("test query")
|
||||
|
||||
# Verify results were processed correctly
|
||||
assert len(results_list.results) == 2
|
||||
assert "List format result 1" in str(results_list.results[0].content)
|
||||
|
||||
# Cleanup
|
||||
await memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("autogen_ext.memory.mem0.Memory0")
|
||||
async def test_init_with_local_config(mock_mem0_class: MagicMock, full_local_config: Dict[str, Any]) -> None:
|
||||
"""Test initializing memory with local configuration."""
|
||||
# Setup mock
|
||||
mock_mem0 = MagicMock()
|
||||
mock_mem0_class.from_config.return_value = mock_mem0
|
||||
|
||||
# Initialize memory with local config
|
||||
memory = Mem0Memory(user_id="test-local-config-user", limit=10, is_cloud=False, config=full_local_config)
|
||||
|
||||
# Verify configuration was passed correctly
|
||||
mock_mem0_class.from_config.assert_called_once()
|
||||
|
||||
# Verify memory instance properties (use type: ignore or add public properties)
|
||||
assert memory._user_id == "test-local-config-user" # type: ignore
|
||||
assert memory._limit == 10 # type: ignore
|
||||
assert memory._is_cloud is False # type: ignore
|
||||
assert memory._config == full_local_config # type: ignore
|
||||
|
||||
# Test serialization with local config
|
||||
memory_config = memory.dump_component()
|
||||
|
||||
# Verify serialized config
|
||||
assert memory_config.config["user_id"] == "test-local-config-user"
|
||||
assert memory_config.config["is_cloud"] is False
|
||||
|
||||
# Cleanup
|
||||
await memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("autogen_ext.memory.mem0.Memory0") # Patches the underlying mem0.Memory class
|
||||
async def test_local_config_with_memory_operations(
|
||||
mock_mem0_class: MagicMock,
|
||||
full_local_config: Dict[str, Any], # full_local_config fixture provides the mock config
|
||||
) -> None:
|
||||
"""Test memory operations with local configuration."""
|
||||
# Setup mock for the instance that will be created by Mem0Memory
|
||||
mock_mem0_instance = MagicMock()
|
||||
mock_mem0_class.from_config.return_value = mock_mem0_instance
|
||||
|
||||
# Mock search results from the mem0 instance
|
||||
mock_mem0_instance.search.return_value = [
|
||||
{
|
||||
"memory": "Test local config memory content",
|
||||
"score": 0.92,
|
||||
"metadata": {"config_type": "local", "test_case": "advanced"},
|
||||
}
|
||||
]
|
||||
|
||||
# Initialize Mem0Memory with is_cloud=False and the full_local_config
|
||||
memory = Mem0Memory(user_id="test-local-config-user", limit=10, is_cloud=False, config=full_local_config)
|
||||
|
||||
# Verify that mem0.Memory.from_config was called with the provided config
|
||||
mock_mem0_class.from_config.assert_called_once_with(config_dict=full_local_config)
|
||||
|
||||
# Add memory content
|
||||
test_content_str = "Testing local configuration memory operations"
|
||||
await memory.add(
|
||||
MemoryContent(
|
||||
content=test_content_str,
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"config_type": "local", "test_case": "advanced"},
|
||||
)
|
||||
)
|
||||
|
||||
# Verify add was called on the mock_mem0_instance
|
||||
mock_mem0_instance.add.assert_called_once()
|
||||
|
||||
# Query memory
|
||||
results = await memory.query("local configuration test")
|
||||
|
||||
# Verify search was called on the mock_mem0_instance
|
||||
mock_mem0_instance.search.assert_called_once_with(
|
||||
"local configuration test", user_id="test-local-config-user", limit=10
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(results.results) == 1
|
||||
assert "Test local config memory content" in str(results.results[0].content)
|
||||
res_metadata = results.results[0].metadata
|
||||
assert res_metadata is not None and res_metadata.get("score") == 0.92
|
||||
assert res_metadata is not None and res_metadata.get("config_type") == "local"
|
||||
|
||||
# Test serialization with local config
|
||||
memory_config = memory.dump_component()
|
||||
|
||||
# Verify serialized config
|
||||
assert memory_config.config["user_id"] == "test-local-config-user"
|
||||
assert memory_config.config["is_cloud"] is False
|
||||
assert "config" in memory_config.config
|
||||
assert memory_config.config["config"]["history_db_path"] == ":memory:"
|
||||
|
||||
# Test clear
|
||||
await memory.clear()
|
||||
mock_mem0_instance.delete_all.assert_called_once_with(user_id="test-local-config-user")
|
||||
|
||||
# Cleanup
|
||||
await memory.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-xvs", __file__])
|
||||
5536
python/uv.lock
generated
5536
python/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user