From 67ebeeda0e9acdaa8b8b45dd5c38df1385b8795b Mon Sep 17 00:00:00 2001 From: Tejas Dharani Date: Fri, 13 Jun 2025 21:36:15 +0530 Subject: [PATCH] Feature/chromadb embedding functions #6267 (#6648) ## Why are these changes needed? This PR adds support for configurable embedding functions in ChromaDBVectorMemory, addressing the need for users to customize how embeddings are generated for vector similarity search. Currently, ChromaDB memory is limited to default embedding functions, which restricts flexibility for different use cases that may require specific embedding models or custom embedding logic. The implementation allows users to: - Use different SentenceTransformer models for domain-specific embeddings - Integrate with OpenAI's embedding API for consistent embedding generation - Define custom embedding functions for specialized requirements - Maintain backward compatibility with existing default behavior ## Related issue number Closes #6267 ## Checks - [x] I've included any doc changes needed for . See to build and test documentation locally. - [x] I've added tests corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Victor Dibia Co-authored-by: Victor Dibia --- .../agentchat-user-guide/memory.ipynb | 1002 +++++++++-------- .../autogen_ext/memory/chromadb/__init__.py | 21 + .../memory/chromadb/_chroma_configs.py | 148 +++ .../{chromadb.py => chromadb/_chromadb.py} | 131 ++- .../tests/memory/test_chroma_memory.py | 202 +++- 5 files changed, 1029 insertions(+), 475 deletions(-) create mode 100644 python/packages/autogen-ext/src/autogen_ext/memory/chromadb/__init__.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/memory/chromadb/_chroma_configs.py rename python/packages/autogen-ext/src/autogen_ext/memory/{chromadb.py => chromadb/_chromadb.py} (78%) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/memory.ipynb index b70c60b61..cf642b861 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/memory.ipynb @@ -1,442 +1,566 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Memory and RAG\n", - "\n", - "There are several use cases where it is valuable to maintain a _store_ of useful facts that can be intelligently added to the context of the agent just before a specific step. The typically use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", - "\n", - "\n", - "AgentChat provides a {py:class}`~autogen_core.memory.Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `update_context`, `add`, `clear`, and `close`. \n", - "\n", - "- `add`: add new entries to the memory store\n", - "- `query`: retrieve relevant information from the memory store \n", - "- `update_context`: mutate an agent's internal `model_context` by adding the retrieved information (used in the {py:class}`~autogen_agentchat.agents.AssistantAgent` class) \n", - "- `clear`: clear all entries from the memory store\n", - "- `close`: clean up any resources used by the memory store \n", - "\n", - "\n", - "## ListMemory Example\n", - "\n", - "{py:class}~autogen_core.memory.ListMemory is provided as an example implementation of the {py:class}~autogen_core.memory.Memory protocol. It is a simple list-based memory implementation that maintains memories in chronological order, appending the most recent memories to the model's context. The implementation is designed to be straightforward and predictable, making it easy to understand and debug.\n", - "In the following example, we will use ListMemory to maintain a memory bank of user preferences and demonstrate how it can be used to provide consistent context for agent responses over time." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.ui import Console\n", - "from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType\n", - "from autogen_ext.models.openai import OpenAIChatCompletionClient" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Initialize user memory\n", - "user_memory = ListMemory()\n", - "\n", - "# Add user preferences to memory\n", - "await user_memory.add(MemoryContent(content=\"The weather should be in metric units\", mime_type=MemoryMimeType.TEXT))\n", - "\n", - "await user_memory.add(MemoryContent(content=\"Meal recipe must be vegan\", mime_type=MemoryMimeType.TEXT))\n", - "\n", - "\n", - "async def get_weather(city: str, units: str = \"imperial\") -> str:\n", - " if units == \"imperial\":\n", - " return f\"The weather in {city} is 73 °F and Sunny.\"\n", - " elif units == \"metric\":\n", - " return f\"The weather in {city} is 23 °C and Sunny.\"\n", - " else:\n", - " return f\"Sorry, I don't know the weather in {city}.\"\n", - "\n", - "\n", - "assistant_agent = AssistantAgent(\n", - " name=\"assistant_agent\",\n", - " model_client=OpenAIChatCompletionClient(\n", - " model=\"gpt-4o-2024-08-06\",\n", - " ),\n", - " tools=[get_weather],\n", - " memory=[user_memory],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Run the agent with a task.\n", - "stream = assistant_agent.run_stream(task=\"What is the weather in New York?\")\n", - "await Console(stream)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can inspect that the `assistant_agent` model_context is actually updated with the retrieved memory entries. The `transform` method is used to format the retrieved memory entries into a string that can be used by the agent. In this case, we simply concatenate the content of each memory entry into a single string." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "await assistant_agent._model_context.get_messages()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We see above that the weather is returned in Centigrade as stated in the user preferences. \n", - "\n", - "Similarly, assuming we ask a separate question about generating a meal plan, the agent is able to retrieve relevant information from the memory store and provide a personalized (vegan) response." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "stream = assistant_agent.run_stream(task=\"Write brief meal recipe with broth\")\n", - "await Console(stream)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Custom Memory Stores (Vector DBs, etc.)\n", - "\n", - "You can build on the `Memory` protocol to implement more complex memory stores. For example, you could implement a custom memory store that uses a vector database to store and retrieve information, or a memory store that uses a machine learning model to generate personalized responses based on the user's preferences etc.\n", - "\n", - "Specifically, you will need to overload the `add`, `query` and `update_context` methods to implement the desired functionality and pass the memory store to your agent.\n", - "\n", - "\n", - "Currently the following example memory stores are available as part of the {py:class}`~autogen_ext` extensions package. \n", - "\n", - "- `autogen_ext.memory.chromadb.ChromaDBVectorMemory`: A memory store that uses a vector database to store and retrieve information. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from pathlib import Path\n", - "\n", - "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.ui import Console\n", - "from autogen_core.memory import MemoryContent, MemoryMimeType\n", - "from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig\n", - "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", - "\n", - "# Initialize ChromaDB memory with custom config\n", - "chroma_user_memory = ChromaDBVectorMemory(\n", - " config=PersistentChromaDBVectorMemoryConfig(\n", - " collection_name=\"preferences\",\n", - " persistence_path=os.path.join(str(Path.home()), \".chromadb_autogen\"),\n", - " k=2, # Return top k results\n", - " score_threshold=0.4, # Minimum similarity score\n", - " )\n", - ")\n", - "# a HttpChromaDBVectorMemoryConfig is also supported for connecting to a remote ChromaDB server\n", - "\n", - "# Add user preferences to memory\n", - "await chroma_user_memory.add(\n", - " MemoryContent(\n", - " content=\"The weather should be in metric units\",\n", - " mime_type=MemoryMimeType.TEXT,\n", - " metadata={\"category\": \"preferences\", \"type\": \"units\"},\n", - " )\n", - ")\n", - "\n", - "await chroma_user_memory.add(\n", - " MemoryContent(\n", - " content=\"Meal recipe must be vegan\",\n", - " mime_type=MemoryMimeType.TEXT,\n", - " metadata={\"category\": \"preferences\", \"type\": \"dietary\"},\n", - " )\n", - ")\n", - "\n", - "model_client = OpenAIChatCompletionClient(\n", - " model=\"gpt-4o\",\n", - ")\n", - "\n", - "# Create assistant agent with ChromaDB memory\n", - "assistant_agent = AssistantAgent(\n", - " name=\"assistant_agent\",\n", - " model_client=model_client,\n", - " tools=[get_weather],\n", - " memory=[chroma_user_memory],\n", - ")\n", - "\n", - "stream = assistant_agent.run_stream(task=\"What is the weather in New York?\")\n", - "await Console(stream)\n", - "\n", - "await model_client.close()\n", - "await chroma_user_memory.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that you can also serialize the ChromaDBVectorMemory and save it to disk." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "chroma_user_memory.dump_component().model_dump_json()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## RAG Agent: Putting It All Together\n", - "\n", - "The RAG (Retrieval Augmented Generation) pattern which is common in building AI systems encompasses two distinct phases:\n", - "\n", - "1. **Indexing**: Loading documents, chunking them, and storing them in a vector database\n", - "2. **Retrieval**: Finding and using relevant chunks during conversation runtime\n", - "\n", - "In our previous examples, we manually added items to memory and passed them to our agents. In practice, the indexing process is usually automated and based on much larger document sources like product documentation, internal files, or knowledge bases.\n", - "\n", - "> Note: The quality of a RAG system is dependent on the quality of the chunking and retrieval process (models, embeddings, etc.). You may need to experiement with more advanced chunking and retrieval models to get the best results.\n", - "\n", - "### Building a Simple RAG Agent\n", - "\n", - "To begin, let's create a simple document indexer that we will used to load documents, chunk them, and store them in a `ChromaDBVectorMemory` memory store. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import re\n", - "from typing import List\n", - "\n", - "import aiofiles\n", - "import aiohttp\n", - "from autogen_core.memory import Memory, MemoryContent, MemoryMimeType\n", - "\n", - "\n", - "class SimpleDocumentIndexer:\n", - " \"\"\"Basic document indexer for AutoGen Memory.\"\"\"\n", - "\n", - " def __init__(self, memory: Memory, chunk_size: int = 1500) -> None:\n", - " self.memory = memory\n", - " self.chunk_size = chunk_size\n", - "\n", - " async def _fetch_content(self, source: str) -> str:\n", - " \"\"\"Fetch content from URL or file.\"\"\"\n", - " if source.startswith((\"http://\", \"https://\")):\n", - " async with aiohttp.ClientSession() as session:\n", - " async with session.get(source) as response:\n", - " return await response.text()\n", - " else:\n", - " async with aiofiles.open(source, \"r\", encoding=\"utf-8\") as f:\n", - " return await f.read()\n", - "\n", - " def _strip_html(self, text: str) -> str:\n", - " \"\"\"Remove HTML tags and normalize whitespace.\"\"\"\n", - " text = re.sub(r\"<[^>]*>\", \" \", text)\n", - " text = re.sub(r\"\\s+\", \" \", text)\n", - " return text.strip()\n", - "\n", - " def _split_text(self, text: str) -> List[str]:\n", - " \"\"\"Split text into fixed-size chunks.\"\"\"\n", - " chunks: list[str] = []\n", - " # Just split text into fixed-size chunks\n", - " for i in range(0, len(text), self.chunk_size):\n", - " chunk = text[i : i + self.chunk_size]\n", - " chunks.append(chunk.strip())\n", - " return chunks\n", - "\n", - " async def index_documents(self, sources: List[str]) -> int:\n", - " \"\"\"Index documents into memory.\"\"\"\n", - " total_chunks = 0\n", - "\n", - " for source in sources:\n", - " try:\n", - " content = await self._fetch_content(source)\n", - "\n", - " # Strip HTML if content appears to be HTML\n", - " if \"<\" in content and \">\" in content:\n", - " content = self._strip_html(content)\n", - "\n", - " chunks = self._split_text(content)\n", - "\n", - " for i, chunk in enumerate(chunks):\n", - " await self.memory.add(\n", - " MemoryContent(\n", - " content=chunk, mime_type=MemoryMimeType.TEXT, metadata={\"source\": source, \"chunk_index\": i}\n", - " )\n", - " )\n", - "\n", - " total_chunks += len(chunks)\n", - "\n", - " except Exception as e:\n", - " print(f\"Error indexing {source}: {str(e)}\")\n", - "\n", - " return total_chunks" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - " \n", - "Now let's use our indexer with ChromaDBVectorMemory to build a complete RAG agent:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Indexed 72 chunks from 4 AutoGen documents\n" - ] - } - ], - "source": [ - "import os\n", - "from pathlib import Path\n", - "\n", - "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.ui import Console\n", - "from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig\n", - "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", - "\n", - "# Initialize vector memory\n", - "\n", - "rag_memory = ChromaDBVectorMemory(\n", - " config=PersistentChromaDBVectorMemoryConfig(\n", - " collection_name=\"autogen_docs\",\n", - " persistence_path=os.path.join(str(Path.home()), \".chromadb_autogen\"),\n", - " k=3, # Return top 3 results\n", - " score_threshold=0.4, # Minimum similarity score\n", - " )\n", - ")\n", - "\n", - "await rag_memory.clear() # Clear existing memory\n", - "\n", - "\n", - "# Index AutoGen documentation\n", - "async def index_autogen_docs() -> None:\n", - " indexer = SimpleDocumentIndexer(memory=rag_memory)\n", - " sources = [\n", - " \"https://raw.githubusercontent.com/microsoft/autogen/main/README.md\",\n", - " \"https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/agents.html\",\n", - " \"https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/teams.html\",\n", - " \"https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/termination.html\",\n", - " ]\n", - " chunks: int = await indexer.index_documents(sources)\n", - " print(f\"Indexed {chunks} chunks from {len(sources)} AutoGen documents\")\n", - "\n", - "\n", - "await index_autogen_docs()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------- user ----------\n", - "What is AgentChat?\n", - "Query results: results=[MemoryContent(content='ng OpenAI\\'s GPT-4o model. See [other supported models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html). ```python import asyncio from autogen_agentchat.agents import AssistantAgent from autogen_ext.models.openai import OpenAIChatCompletionClient async def main() -> None: model_client = OpenAIChatCompletionClient(model=\"gpt-4o\") agent = AssistantAgent(\"assistant\", model_client=model_client) print(await agent.run(task=\"Say \\'Hello World!\\'\")) await model_client.close() asyncio.run(main()) ``` ### Web Browsing Agent Team Create a group chat team with a web surfer agent and a user proxy agent for web browsing tasks. You need to install [playwright](https://playwright.dev/python/docs/library). ```python # pip install -U autogen-agentchat autogen-ext[openai,web-surfer] # playwright install import asyncio from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.conditions import TextMentionTermination from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.ui import Console from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.agents.web_surfer import MultimodalWebSurfer async def main() -> None: model_client = OpenAIChatCompletionClient(model=\"gpt-4o\") # The web surfer will open a Chromium browser window to perform web browsing tasks. web_surfer = MultimodalWebSurfer(\"web_surfer\", model_client, headless=False, animate_actions=True) # The user proxy agent is used to ge', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 1, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://raw.githubusercontent.com/microsoft/autogen/main/README.md', 'score': 0.48810458183288574, 'id': '16088e03-0153-4da3-9dec-643b39c549f5'}), MemoryContent(content='els_usage=None content='AutoGen is a programming framework for building multi-agent applications.' type='ToolCallSummaryMessage' The call to the on_messages() method returns a Response that contains the agent’s final response in the chat_message attribute, as well as a list of inner messages in the inner_messages attribute, which stores the agent’s “thought process” that led to the final response. Note It is important to note that on_messages() will update the internal state of the agent – it will add the messages to the agent’s history. So you should call this method with new messages. You should not repeatedly call this method with the same messages or the complete history. Note Unlike in v0.2 AgentChat, the tools are executed by the same agent directly within the same call to on_messages() . By default, the agent will return the result of the tool call as the final response. You can also call the run() method, which is a convenience method that calls on_messages() . It follows the same interface as Teams and returns a TaskResult object. Multi-Modal Input # The AssistantAgent can handle multi-modal input by providing the input as a MultiModalMessage . from io import BytesIO import PIL import requests from autogen_agentchat.messages import MultiModalMessage from autogen_core import Image # Create a multi-modal message with random image and text. pil_image = PIL . Image . open ( BytesIO ( requests . get ( "https://picsum.photos/300/200" ) . content )', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 3, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/agents.html', 'score': 0.4665141701698303, 'id': '3d603b62-7cab-4f74-b671-586fe36306f2'}), MemoryContent(content='AgentChat Termination Termination # In the previous section, we explored how to define agents, and organize them into teams that can solve tasks. However, a run can go on forever, and in many cases, we need to know when to stop them. This is the role of the termination condition. AgentChat supports several termination condition by providing a base TerminationCondition class and several implementations that inherit from it. A termination condition is a callable that takes a sequence of BaseAgentEvent or BaseChatMessage objects since the last time the condition was called , and returns a StopMessage if the conversation should be terminated, or None otherwise. Once a termination condition has been reached, it must be reset by calling reset() before it can be used again. Some important things to note about termination conditions: They are stateful but reset automatically after each run ( run() or run_stream() ) is finished. They can be combined using the AND and OR operators. Note For group chat teams (i.e., RoundRobinGroupChat , SelectorGroupChat , and Swarm ), the termination condition is called after each agent responds. While a response may contain multiple inner messages, the team calls its termination condition just once for all the messages from a single response. So the condition is called with the “delta sequence” of messages since the last time it was called. Built-In Termination Conditions: MaxMessageTermination : Stops after a specified number of messages have been produced,', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 1, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/termination.html', 'score': 0.461774212772051, 'id': '699ef490-d108-4cd3-b629-c1198d6b78ba'})]\n", - "---------- rag_assistant ----------\n", - "[MemoryContent(content='ng OpenAI\\'s GPT-4o model. See [other supported models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html). ```python import asyncio from autogen_agentchat.agents import AssistantAgent from autogen_ext.models.openai import OpenAIChatCompletionClient async def main() -> None: model_client = OpenAIChatCompletionClient(model=\"gpt-4o\") agent = AssistantAgent(\"assistant\", model_client=model_client) print(await agent.run(task=\"Say \\'Hello World!\\'\")) await model_client.close() asyncio.run(main()) ``` ### Web Browsing Agent Team Create a group chat team with a web surfer agent and a user proxy agent for web browsing tasks. You need to install [playwright](https://playwright.dev/python/docs/library). ```python # pip install -U autogen-agentchat autogen-ext[openai,web-surfer] # playwright install import asyncio from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.conditions import TextMentionTermination from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.ui import Console from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.agents.web_surfer import MultimodalWebSurfer async def main() -> None: model_client = OpenAIChatCompletionClient(model=\"gpt-4o\") # The web surfer will open a Chromium browser window to perform web browsing tasks. web_surfer = MultimodalWebSurfer(\"web_surfer\", model_client, headless=False, animate_actions=True) # The user proxy agent is used to ge', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 1, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://raw.githubusercontent.com/microsoft/autogen/main/README.md', 'score': 0.48810458183288574, 'id': '16088e03-0153-4da3-9dec-643b39c549f5'}), MemoryContent(content='els_usage=None content='AutoGen is a programming framework for building multi-agent applications.' type='ToolCallSummaryMessage' The call to the on_messages() method returns a Response that contains the agent’s final response in the chat_message attribute, as well as a list of inner messages in the inner_messages attribute, which stores the agent’s “thought process” that led to the final response. Note It is important to note that on_messages() will update the internal state of the agent – it will add the messages to the agent’s history. So you should call this method with new messages. You should not repeatedly call this method with the same messages or the complete history. Note Unlike in v0.2 AgentChat, the tools are executed by the same agent directly within the same call to on_messages() . By default, the agent will return the result of the tool call as the final response. You can also call the run() method, which is a convenience method that calls on_messages() . It follows the same interface as Teams and returns a TaskResult object. Multi-Modal Input # The AssistantAgent can handle multi-modal input by providing the input as a MultiModalMessage . from io import BytesIO import PIL import requests from autogen_agentchat.messages import MultiModalMessage from autogen_core import Image # Create a multi-modal message with random image and text. pil_image = PIL . Image . open ( BytesIO ( requests . get ( "https://picsum.photos/300/200" ) . content )', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 3, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/agents.html', 'score': 0.4665141701698303, 'id': '3d603b62-7cab-4f74-b671-586fe36306f2'}), MemoryContent(content='AgentChat Termination Termination # In the previous section, we explored how to define agents, and organize them into teams that can solve tasks. However, a run can go on forever, and in many cases, we need to know when to stop them. This is the role of the termination condition. AgentChat supports several termination condition by providing a base TerminationCondition class and several implementations that inherit from it. A termination condition is a callable that takes a sequenceBaseChatMessageent or BaseChatMessage objects since the last time the condition was called , and returns a StopMessage if the conversation should be terminated, or None otherwise. Once a termination condition has been reached, it must be reset by calling reset() before it can be used again. Some important things to note about termination conditions: They are stateful but reset automatically after each run ( run() or run_stream() ) is finished. They can be combined using the AND and OR operators. Note For group chat teams (i.e., RoundRobinGroupChat , SelectorGroupChat , and Swarm ), the termination condition is called after each agent responds. While a response may contain multiple inner messages, the team calls its termination condition just once for all the messages from a single response. So the condition is called with the “delta sequence” of messages since the last time it was called. Built-In Termination Conditions: MaxMessageTermination : Stops after a specified number of messages have been produced,', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 1, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/termination.html', 'score': 0.461774212772051, 'id': '699ef490-d108-4cd3-b629-c1198d6b78ba'})]\n", - "---------- rag_assistant ----------\n", - "AgentChat is part of the AutoGen framework, a programming environment for building multi-agent applications. In AgentChat, agents can interact with each other and with users to perform various tasks, including web browsing and engaging in dialogue. It utilizes models from OpenAI for chat completions and supports multi-modal input, which means agents can handle inputs that include both text and images. Additionally, AgentChat provides mechanisms to define termination conditions to control when a conversation or task should be concluded, ensuring that the agent interactions are efficient and goal-oriented. TERMINATE\n" - ] - } - ], - "source": [ - "# Create our RAG assistant agent\n", - "rag_assistant = AssistantAgent(\n", - " name=\"rag_assistant\", model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"), memory=[rag_memory]\n", - ")\n", - "\n", - "# Ask questions about AutoGen\n", - "stream = rag_assistant.run_stream(task=\"What is AgentChat?\")\n", - "await Console(stream)\n", - "\n", - "# Remember to close the memory when done\n", - "await rag_memory.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This implementation provides a RAG agent that can answer questions based on AutoGen documentation. When a question is asked, the Memory system retrieves relevant chunks and adds them to the context, enabling the assistant to generate informed responses.\n", - "\n", - "For production systems, you might want to:\n", - "1. Implement more sophisticated chunking strategies\n", - "2. Add metadata filtering capabilities\n", - "3. Customize the retrieval scoring\n", - "4. Optimize embedding models for your specific domain\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Memory and RAG\n", + "\n", + "There are several use cases where it is valuable to maintain a _store_ of useful facts that can be intelligently added to the context of the agent just before a specific step. The typically use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", + "\n", + "\n", + "AgentChat provides a {py:class}`~autogen_core.memory.Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `update_context`, `add`, `clear`, and `close`. \n", + "\n", + "- `add`: add new entries to the memory store\n", + "- `query`: retrieve relevant information from the memory store \n", + "- `update_context`: mutate an agent's internal `model_context` by adding the retrieved information (used in the {py:class}`~autogen_agentchat.agents.AssistantAgent` class) \n", + "- `clear`: clear all entries from the memory store\n", + "- `close`: clean up any resources used by the memory store \n", + "\n", + "\n", + "## ListMemory Example\n", + "\n", + "{py:class}~autogen_core.memory.ListMemory is provided as an example implementation of the {py:class}~autogen_core.memory.Memory protocol. It is a simple list-based memory implementation that maintains memories in chronological order, appending the most recent memories to the model's context. The implementation is designed to be straightforward and predictable, making it easy to understand and debug.\n", + "In the following example, we will use ListMemory to maintain a memory bank of user preferences and demonstrate how it can be used to provide consistent context for agent responses over time." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from autogen_agentchat.agents import AssistantAgent\n", + "from autogen_agentchat.ui import Console\n", + "from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType\n", + "from autogen_ext.models.openai import OpenAIChatCompletionClient" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize user memory\n", + "user_memory = ListMemory()\n", + "\n", + "# Add user preferences to memory\n", + "await user_memory.add(MemoryContent(content=\"The weather should be in metric units\", mime_type=MemoryMimeType.TEXT))\n", + "\n", + "await user_memory.add(MemoryContent(content=\"Meal recipe must be vegan\", mime_type=MemoryMimeType.TEXT))\n", + "\n", + "\n", + "async def get_weather(city: str, units: str = \"imperial\") -> str:\n", + " if units == \"imperial\":\n", + " return f\"The weather in {city} is 73 °F and Sunny.\"\n", + " elif units == \"metric\":\n", + " return f\"The weather in {city} is 23 °C and Sunny.\"\n", + " else:\n", + " return f\"Sorry, I don't know the weather in {city}.\"\n", + "\n", + "\n", + "assistant_agent = AssistantAgent(\n", + " name=\"assistant_agent\",\n", + " model_client=OpenAIChatCompletionClient(\n", + " model=\"gpt-4o-2024-08-06\",\n", + " ),\n", + " tools=[get_weather],\n", + " memory=[user_memory],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- TextMessage (user) ----------\n", + "What is the weather in New York?\n", + "---------- MemoryQueryEvent (assistant_agent) ----------\n", + "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None)]\n", + "---------- ToolCallRequestEvent (assistant_agent) ----------\n", + "[FunctionCall(id='call_apWw5JOedVvqsPfXWV7c5Uiw', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "---------- ToolCallExecutionEvent (assistant_agent) ----------\n", + "[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_apWw5JOedVvqsPfXWV7c5Uiw', is_error=False)]\n", + "---------- ToolCallSummaryMessage (assistant_agent) ----------\n", + "The weather in New York is 23 °C and Sunny.\n" + ] }, - "nbformat": 4, - "nbformat_minor": 2 + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 33, 492791, tzinfo=datetime.timezone.utc), content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 33, 494162, tzinfo=datetime.timezone.utc), content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=19), metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 34, 892272, tzinfo=datetime.timezone.utc), content=[FunctionCall(id='call_apWw5JOedVvqsPfXWV7c5Uiw', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 34, 894081, tzinfo=datetime.timezone.utc), content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_apWw5JOedVvqsPfXWV7c5Uiw', is_error=False)], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 34, 895054, tzinfo=datetime.timezone.utc), content='The weather in New York is 23 °C and Sunny.', type='ToolCallSummaryMessage', tool_calls=[FunctionCall(id='call_apWw5JOedVvqsPfXWV7c5Uiw', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], results=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_apWw5JOedVvqsPfXWV7c5Uiw', is_error=False)])], stop_reason=None)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run the agent with a task.\n", + "stream = assistant_agent.run_stream(task=\"What is the weather in New York?\")\n", + "await Console(stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can inspect that the `assistant_agent` model_context is actually updated with the retrieved memory entries. The `transform` method is used to format the retrieved memory entries into a string that can be used by the agent. In this case, we simply concatenate the content of each memory entry into a single string." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n", + " SystemMessage(content='\\nRelevant memory content (in chronological order):\\n1. The weather should be in metric units\\n2. Meal recipe must be vegan\\n', type='SystemMessage'),\n", + " AssistantMessage(content=[FunctionCall(id='call_apWw5JOedVvqsPfXWV7c5Uiw', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], thought=None, source='assistant_agent', type='AssistantMessage'),\n", + " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_apWw5JOedVvqsPfXWV7c5Uiw', is_error=False)], type='FunctionExecutionResultMessage')]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "await assistant_agent._model_context.get_messages()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see above that the weather is returned in Centigrade as stated in the user preferences. \n", + "\n", + "Similarly, assuming we ask a separate question about generating a meal plan, the agent is able to retrieve relevant information from the memory store and provide a personalized (vegan) response." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- TextMessage (user) ----------\n", + "Write brief meal recipe with broth\n", + "---------- MemoryQueryEvent (assistant_agent) ----------\n", + "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None)]\n", + "---------- TextMessage (assistant_agent) ----------\n", + "Here's another vegan broth-based recipe:\n", + "\n", + "**Vegan Miso Soup**\n", + "\n", + "**Ingredients:**\n", + "- 4 cups vegetable broth\n", + "- 3 tablespoons white miso paste\n", + "- 1 block firm tofu, cubed\n", + "- 1 cup mushrooms, sliced (shiitake or any variety you prefer)\n", + "- 2 green onions, chopped\n", + "- 1 tablespoon soy sauce (optional)\n", + "- 1/2 cup seaweed (such as wakame)\n", + "- 1 tablespoon sesame oil\n", + "- 1 tablespoon grated ginger\n", + "- Salt to taste\n", + "\n", + "**Instructions:**\n", + "1. In a pot, heat the sesame oil over medium heat.\n", + "2. Add the grated ginger and sauté for about a minute until fragrant.\n", + "3. Pour in the vegetable broth and bring it to a simmer.\n", + "4. Add the miso paste, stirring until fully dissolved.\n", + "5. Add the tofu cubes, mushrooms, and seaweed to the broth and cook for about 5 minutes.\n", + "6. Stir in soy sauce if using, and add salt to taste.\n", + "7. Garnish with chopped green onions before serving.\n", + "\n", + "Enjoy your delicious and nutritious vegan miso soup! TERMINATE\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 47, 19, 247083, tzinfo=datetime.timezone.utc), content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 47, 19, 248736, tzinfo=datetime.timezone.utc), content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=528, completion_tokens=233), metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 47, 26, 130554, tzinfo=datetime.timezone.utc), content=\"Here's another vegan broth-based recipe:\\n\\n**Vegan Miso Soup**\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 3 tablespoons white miso paste\\n- 1 block firm tofu, cubed\\n- 1 cup mushrooms, sliced (shiitake or any variety you prefer)\\n- 2 green onions, chopped\\n- 1 tablespoon soy sauce (optional)\\n- 1/2 cup seaweed (such as wakame)\\n- 1 tablespoon sesame oil\\n- 1 tablespoon grated ginger\\n- Salt to taste\\n\\n**Instructions:**\\n1. In a pot, heat the sesame oil over medium heat.\\n2. Add the grated ginger and sauté for about a minute until fragrant.\\n3. Pour in the vegetable broth and bring it to a simmer.\\n4. Add the miso paste, stirring until fully dissolved.\\n5. Add the tofu cubes, mushrooms, and seaweed to the broth and cook for about 5 minutes.\\n6. Stir in soy sauce if using, and add salt to taste.\\n7. Garnish with chopped green onions before serving.\\n\\nEnjoy your delicious and nutritious vegan miso soup! TERMINATE\", type='TextMessage')], stop_reason=None)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stream = assistant_agent.run_stream(task=\"Write brief meal recipe with broth\")\n", + "await Console(stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Memory Stores (Vector DBs, etc.)\n", + "\n", + "You can build on the `Memory` protocol to implement more complex memory stores. For example, you could implement a custom memory store that uses a vector database to store and retrieve information, or a memory store that uses a machine learning model to generate personalized responses based on the user's preferences etc.\n", + "\n", + "Specifically, you will need to overload the `add`, `query` and `update_context` methods to implement the desired functionality and pass the memory store to your agent.\n", + "\n", + "\n", + "Currently the following example memory stores are available as part of the {py:class}`~autogen_ext` extensions package. \n", + "\n", + "- `autogen_ext.memory.chromadb.ChromaDBVectorMemory`: A memory store that uses a vector database to store and retrieve information. \n", + "\n", + "- `autogen_ext.memory.chromadb.SentenceTransformerEmbeddingFunctionConfig`: A configuration class for the SentenceTransformer embedding function used by the `ChromaDBVectorMemory` store. Note that other embedding functions such as `autogen_ext.memory.openai.OpenAIEmbeddingFunctionConfig` can also be used with the `ChromaDBVectorMemory` store.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- TextMessage (user) ----------\n", + "What is the weather in New York?\n", + "---------- MemoryQueryEvent (assistant_agent) ----------\n", + "[MemoryContent(content='The weather should be in metric units', mime_type='MemoryMimeType.TEXT', metadata={'type': 'units', 'mime_type': 'MemoryMimeType.TEXT', 'category': 'preferences', 'score': 0.4342840313911438, 'id': 'd7ed6e42-0bf5-4ee8-b5b5-fbe06f583477'})]\n", + "---------- ToolCallRequestEvent (assistant_agent) ----------\n", + "[FunctionCall(id='call_ufpz7LGcn19ZroowyEraj9bd', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "---------- ToolCallExecutionEvent (assistant_agent) ----------\n", + "[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_ufpz7LGcn19ZroowyEraj9bd', is_error=False)]\n", + "---------- ToolCallSummaryMessage (assistant_agent) ----------\n", + "The weather in New York is 23 °C and Sunny.\n" + ] + } + ], + "source": [ + "import tempfile\n", + "\n", + "from autogen_agentchat.agents import AssistantAgent\n", + "from autogen_agentchat.ui import Console\n", + "from autogen_core.memory import MemoryContent, MemoryMimeType\n", + "from autogen_ext.memory.chromadb import (\n", + " ChromaDBVectorMemory,\n", + " PersistentChromaDBVectorMemoryConfig,\n", + " SentenceTransformerEmbeddingFunctionConfig,\n", + ")\n", + "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", + "\n", + "# Use a temporary directory for ChromaDB persistence\n", + "with tempfile.TemporaryDirectory() as tmpdir:\n", + " chroma_user_memory = ChromaDBVectorMemory(\n", + " config=PersistentChromaDBVectorMemoryConfig(\n", + " collection_name=\"preferences\",\n", + " persistence_path=tmpdir, # Use the temp directory here\n", + " k=2, # Return top k results\n", + " score_threshold=0.4, # Minimum similarity score\n", + " embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(\n", + " model_name=\"all-MiniLM-L6-v2\" # Use default model for testing\n", + " ),\n", + " )\n", + " )\n", + " # Add user preferences to memory\n", + " await chroma_user_memory.add(\n", + " MemoryContent(\n", + " content=\"The weather should be in metric units\",\n", + " mime_type=MemoryMimeType.TEXT,\n", + " metadata={\"category\": \"preferences\", \"type\": \"units\"},\n", + " )\n", + " )\n", + "\n", + " await chroma_user_memory.add(\n", + " MemoryContent(\n", + " content=\"Meal recipe must be vegan\",\n", + " mime_type=MemoryMimeType.TEXT,\n", + " metadata={\"category\": \"preferences\", \"type\": \"dietary\"},\n", + " )\n", + " )\n", + "\n", + " model_client = OpenAIChatCompletionClient(\n", + " model=\"gpt-4o\",\n", + " )\n", + "\n", + " # Create assistant agent with ChromaDB memory\n", + " assistant_agent = AssistantAgent(\n", + " name=\"assistant_agent\",\n", + " model_client=model_client,\n", + " tools=[get_weather],\n", + " memory=[chroma_user_memory],\n", + " )\n", + "\n", + " stream = assistant_agent.run_stream(task=\"What is the weather in New York?\")\n", + " await Console(stream)\n", + "\n", + " await model_client.close()\n", + " await chroma_user_memory.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that you can also serialize the ChromaDBVectorMemory and save it to disk." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'{\"provider\":\"autogen_ext.memory.chromadb.ChromaDBVectorMemory\",\"component_type\":\"memory\",\"version\":1,\"component_version\":1,\"description\":\"Store and retrieve memory using vector similarity search powered by ChromaDB.\",\"label\":\"ChromaDBVectorMemory\",\"config\":{\"client_type\":\"persistent\",\"collection_name\":\"preferences\",\"distance_metric\":\"cosine\",\"k\":2,\"score_threshold\":0.4,\"allow_reset\":false,\"tenant\":\"default_tenant\",\"database\":\"default_database\",\"embedding_function_config\":{\"function_type\":\"sentence_transformer\",\"model_name\":\"all-MiniLM-L6-v2\"},\"persistence_path\":\"/var/folders/wg/hgs_dt8n5lbd3gx3pq7k6lym0000gn/T/tmp9qcaqchy\"}}'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chroma_user_memory.dump_component().model_dump_json()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RAG Agent: Putting It All Together\n", + "\n", + "The RAG (Retrieval Augmented Generation) pattern which is common in building AI systems encompasses two distinct phases:\n", + "\n", + "1. **Indexing**: Loading documents, chunking them, and storing them in a vector database\n", + "2. **Retrieval**: Finding and using relevant chunks during conversation runtime\n", + "\n", + "In our previous examples, we manually added items to memory and passed them to our agents. In practice, the indexing process is usually automated and based on much larger document sources like product documentation, internal files, or knowledge bases.\n", + "\n", + "> Note: The quality of a RAG system is dependent on the quality of the chunking and retrieval process (models, embeddings, etc.). You may need to experiement with more advanced chunking and retrieval models to get the best results.\n", + "\n", + "### Building a Simple RAG Agent\n", + "\n", + "To begin, let's create a simple document indexer that we will used to load documents, chunk them, and store them in a `ChromaDBVectorMemory` memory store. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "from typing import List\n", + "\n", + "import aiofiles\n", + "import aiohttp\n", + "from autogen_core.memory import Memory, MemoryContent, MemoryMimeType\n", + "\n", + "\n", + "class SimpleDocumentIndexer:\n", + " \"\"\"Basic document indexer for AutoGen Memory.\"\"\"\n", + "\n", + " def __init__(self, memory: Memory, chunk_size: int = 1500) -> None:\n", + " self.memory = memory\n", + " self.chunk_size = chunk_size\n", + "\n", + " async def _fetch_content(self, source: str) -> str:\n", + " \"\"\"Fetch content from URL or file.\"\"\"\n", + " if source.startswith((\"http://\", \"https://\")):\n", + " async with aiohttp.ClientSession() as session:\n", + " async with session.get(source) as response:\n", + " return await response.text()\n", + " else:\n", + " async with aiofiles.open(source, \"r\", encoding=\"utf-8\") as f:\n", + " return await f.read()\n", + "\n", + " def _strip_html(self, text: str) -> str:\n", + " \"\"\"Remove HTML tags and normalize whitespace.\"\"\"\n", + " text = re.sub(r\"<[^>]*>\", \" \", text)\n", + " text = re.sub(r\"\\s+\", \" \", text)\n", + " return text.strip()\n", + "\n", + " def _split_text(self, text: str) -> List[str]:\n", + " \"\"\"Split text into fixed-size chunks.\"\"\"\n", + " chunks: list[str] = []\n", + " # Just split text into fixed-size chunks\n", + " for i in range(0, len(text), self.chunk_size):\n", + " chunk = text[i : i + self.chunk_size]\n", + " chunks.append(chunk.strip())\n", + " return chunks\n", + "\n", + " async def index_documents(self, sources: List[str]) -> int:\n", + " \"\"\"Index documents into memory.\"\"\"\n", + " total_chunks = 0\n", + "\n", + " for source in sources:\n", + " try:\n", + " content = await self._fetch_content(source)\n", + "\n", + " # Strip HTML if content appears to be HTML\n", + " if \"<\" in content and \">\" in content:\n", + " content = self._strip_html(content)\n", + "\n", + " chunks = self._split_text(content)\n", + "\n", + " for i, chunk in enumerate(chunks):\n", + " await self.memory.add(\n", + " MemoryContent(\n", + " content=chunk, mime_type=MemoryMimeType.TEXT, metadata={\"source\": source, \"chunk_index\": i}\n", + " )\n", + " )\n", + "\n", + " total_chunks += len(chunks)\n", + "\n", + " except Exception as e:\n", + " print(f\"Error indexing {source}: {str(e)}\")\n", + "\n", + " return total_chunks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \n", + "Now let's use our indexer with ChromaDBVectorMemory to build a complete RAG agent:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Indexed 72 chunks from 4 AutoGen documents\n" + ] + } + ], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "from autogen_agentchat.agents import AssistantAgent\n", + "from autogen_agentchat.ui import Console\n", + "from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig\n", + "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", + "\n", + "# Initialize vector memory\n", + "\n", + "rag_memory = ChromaDBVectorMemory(\n", + " config=PersistentChromaDBVectorMemoryConfig(\n", + " collection_name=\"autogen_docs\",\n", + " persistence_path=os.path.join(str(Path.home()), \".chromadb_autogen\"),\n", + " k=3, # Return top 3 results\n", + " score_threshold=0.4, # Minimum similarity score\n", + " )\n", + ")\n", + "\n", + "await rag_memory.clear() # Clear existing memory\n", + "\n", + "\n", + "# Index AutoGen documentation\n", + "async def index_autogen_docs() -> None:\n", + " indexer = SimpleDocumentIndexer(memory=rag_memory)\n", + " sources = [\n", + " \"https://raw.githubusercontent.com/microsoft/autogen/main/README.md\",\n", + " \"https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/agents.html\",\n", + " \"https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/teams.html\",\n", + " \"https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/termination.html\",\n", + " ]\n", + " chunks: int = await indexer.index_documents(sources)\n", + " print(f\"Indexed {chunks} chunks from {len(sources)} AutoGen documents\")\n", + "\n", + "\n", + "await index_autogen_docs()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "What is AgentChat?\n", + "Query results: results=[MemoryContent(content='ng OpenAI\\'s GPT-4o model. See [other supported models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html). ```python import asyncio from autogen_agentchat.agents import AssistantAgent from autogen_ext.models.openai import OpenAIChatCompletionClient async def main() -> None: model_client = OpenAIChatCompletionClient(model=\"gpt-4o\") agent = AssistantAgent(\"assistant\", model_client=model_client) print(await agent.run(task=\"Say \\'Hello World!\\'\")) await model_client.close() asyncio.run(main()) ``` ### Web Browsing Agent Team Create a group chat team with a web surfer agent and a user proxy agent for web browsing tasks. You need to install [playwright](https://playwright.dev/python/docs/library). ```python # pip install -U autogen-agentchat autogen-ext[openai,web-surfer] # playwright install import asyncio from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.conditions import TextMentionTermination from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.ui import Console from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.agents.web_surfer import MultimodalWebSurfer async def main() -> None: model_client = OpenAIChatCompletionClient(model=\"gpt-4o\") # The web surfer will open a Chromium browser window to perform web browsing tasks. web_surfer = MultimodalWebSurfer(\"web_surfer\", model_client, headless=False, animate_actions=True) # The user proxy agent is used to ge', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 1, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://raw.githubusercontent.com/microsoft/autogen/main/README.md', 'score': 0.48810458183288574, 'id': '16088e03-0153-4da3-9dec-643b39c549f5'}), MemoryContent(content='els_usage=None content='AutoGen is a programming framework for building multi-agent applications.' type='ToolCallSummaryMessage' The call to the on_messages() method returns a Response that contains the agent’s final response in the chat_message attribute, as well as a list of inner messages in the inner_messages attribute, which stores the agent’s “thought process” that led to the final response. Note It is important to note that on_messages() will update the internal state of the agent – it will add the messages to the agent’s history. So you should call this method with new messages. You should not repeatedly call this method with the same messages or the complete history. Note Unlike in v0.2 AgentChat, the tools are executed by the same agent directly within the same call to on_messages() . By default, the agent will return the result of the tool call as the final response. You can also call the run() method, which is a convenience method that calls on_messages() . It follows the same interface as Teams and returns a TaskResult object. Multi-Modal Input # The AssistantAgent can handle multi-modal input by providing the input as a MultiModalMessage . from io import BytesIO import PIL import requests from autogen_agentchat.messages import MultiModalMessage from autogen_core import Image # Create a multi-modal message with random image and text. pil_image = PIL . Image . open ( BytesIO ( requests . get ( "https://picsum.photos/300/200" ) . content )', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 3, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/agents.html', 'score': 0.4665141701698303, 'id': '3d603b62-7cab-4f74-b671-586fe36306f2'}), MemoryContent(content='AgentChat Termination Termination # In the previous section, we explored how to define agents, and organize them into teams that can solve tasks. However, a run can go on forever, and in many cases, we need to know when to stop them. This is the role of the termination condition. AgentChat supports several termination condition by providing a base TerminationCondition class and several implementations that inherit from it. A termination condition is a callable that takes a sequence of BaseAgentEvent or BaseChatMessage objects since the last time the condition was called , and returns a StopMessage if the conversation should be terminated, or None otherwise. Once a termination condition has been reached, it must be reset by calling reset() before it can be used again. Some important things to note about termination conditions: They are stateful but reset automatically after each run ( run() or run_stream() ) is finished. They can be combined using the AND and OR operators. Note For group chat teams (i.e., RoundRobinGroupChat , SelectorGroupChat , and Swarm ), the termination condition is called after each agent responds. While a response may contain multiple inner messages, the team calls its termination condition just once for all the messages from a single response. So the condition is called with the “delta sequence” of messages since the last time it was called. Built-In Termination Conditions: MaxMessageTermination : Stops after a specified number of messages have been produced,', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 1, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/termination.html', 'score': 0.461774212772051, 'id': '699ef490-d108-4cd3-b629-c1198d6b78ba'})]\n", + "---------- rag_assistant ----------\n", + "[MemoryContent(content='ng OpenAI\\'s GPT-4o model. See [other supported models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html). ```python import asyncio from autogen_agentchat.agents import AssistantAgent from autogen_ext.models.openai import OpenAIChatCompletionClient async def main() -> None: model_client = OpenAIChatCompletionClient(model=\"gpt-4o\") agent = AssistantAgent(\"assistant\", model_client=model_client) print(await agent.run(task=\"Say \\'Hello World!\\'\")) await model_client.close() asyncio.run(main()) ``` ### Web Browsing Agent Team Create a group chat team with a web surfer agent and a user proxy agent for web browsing tasks. You need to install [playwright](https://playwright.dev/python/docs/library). ```python # pip install -U autogen-agentchat autogen-ext[openai,web-surfer] # playwright install import asyncio from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.conditions import TextMentionTermination from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.ui import Console from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.agents.web_surfer import MultimodalWebSurfer async def main() -> None: model_client = OpenAIChatCompletionClient(model=\"gpt-4o\") # The web surfer will open a Chromium browser window to perform web browsing tasks. web_surfer = MultimodalWebSurfer(\"web_surfer\", model_client, headless=False, animate_actions=True) # The user proxy agent is used to ge', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 1, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://raw.githubusercontent.com/microsoft/autogen/main/README.md', 'score': 0.48810458183288574, 'id': '16088e03-0153-4da3-9dec-643b39c549f5'}), MemoryContent(content='els_usage=None content='AutoGen is a programming framework for building multi-agent applications.' type='ToolCallSummaryMessage' The call to the on_messages() method returns a Response that contains the agent’s final response in the chat_message attribute, as well as a list of inner messages in the inner_messages attribute, which stores the agent’s “thought process” that led to the final response. Note It is important to note that on_messages() will update the internal state of the agent – it will add the messages to the agent’s history. So you should call this method with new messages. You should not repeatedly call this method with the same messages or the complete history. Note Unlike in v0.2 AgentChat, the tools are executed by the same agent directly within the same call to on_messages() . By default, the agent will return the result of the tool call as the final response. You can also call the run() method, which is a convenience method that calls on_messages() . It follows the same interface as Teams and returns a TaskResult object. Multi-Modal Input # The AssistantAgent can handle multi-modal input by providing the input as a MultiModalMessage . from io import BytesIO import PIL import requests from autogen_agentchat.messages import MultiModalMessage from autogen_core import Image # Create a multi-modal message with random image and text. pil_image = PIL . Image . open ( BytesIO ( requests . get ( "https://picsum.photos/300/200" ) . content )', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 3, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/agents.html', 'score': 0.4665141701698303, 'id': '3d603b62-7cab-4f74-b671-586fe36306f2'}), MemoryContent(content='AgentChat Termination Termination # In the previous section, we explored how to define agents, and organize them into teams that can solve tasks. However, a run can go on forever, and in many cases, we need to know when to stop them. This is the role of the termination condition. AgentChat supports several termination condition by providing a base TerminationCondition class and several implementations that inherit from it. A termination condition is a callable that takes a sequenceBaseChatMessageent or BaseChatMessage objects since the last time the condition was called , and returns a StopMessage if the conversation should be terminated, or None otherwise. Once a termination condition has been reached, it must be reset by calling reset() before it can be used again. Some important things to note about termination conditions: They are stateful but reset automatically after each run ( run() or run_stream() ) is finished. They can be combined using the AND and OR operators. Note For group chat teams (i.e., RoundRobinGroupChat , SelectorGroupChat , and Swarm ), the termination condition is called after each agent responds. While a response may contain multiple inner messages, the team calls its termination condition just once for all the messages from a single response. So the condition is called with the “delta sequence” of messages since the last time it was called. Built-In Termination Conditions: MaxMessageTermination : Stops after a specified number of messages have been produced,', mime_type='MemoryMimeType.TEXT', metadata={'chunk_index': 1, 'mime_type': 'MemoryMimeType.TEXT', 'source': 'https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/termination.html', 'score': 0.461774212772051, 'id': '699ef490-d108-4cd3-b629-c1198d6b78ba'})]\n", + "---------- rag_assistant ----------\n", + "AgentChat is part of the AutoGen framework, a programming environment for building multi-agent applications. In AgentChat, agents can interact with each other and with users to perform various tasks, including web browsing and engaging in dialogue. It utilizes models from OpenAI for chat completions and supports multi-modal input, which means agents can handle inputs that include both text and images. Additionally, AgentChat provides mechanisms to define termination conditions to control when a conversation or task should be concluded, ensuring that the agent interactions are efficient and goal-oriented. TERMINATE\n" + ] + } + ], + "source": [ + "# Create our RAG assistant agent\n", + "rag_assistant = AssistantAgent(\n", + " name=\"rag_assistant\", model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"), memory=[rag_memory]\n", + ")\n", + "\n", + "# Ask questions about AutoGen\n", + "stream = rag_assistant.run_stream(task=\"What is AgentChat?\")\n", + "await Console(stream)\n", + "\n", + "# Remember to close the memory when done\n", + "await rag_memory.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This implementation provides a RAG agent that can answer questions based on AutoGen documentation. When a question is asked, the Memory system retrieves relevant chunks and adds them to the context, enabling the assistant to generate informed responses.\n", + "\n", + "For production systems, you might want to:\n", + "1. Implement more sophisticated chunking strategies\n", + "2. Add metadata filtering capabilities\n", + "3. Customize the retrieval scoring\n", + "4. Optimize embedding models for your specific domain\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/python/packages/autogen-ext/src/autogen_ext/memory/chromadb/__init__.py b/python/packages/autogen-ext/src/autogen_ext/memory/chromadb/__init__.py new file mode 100644 index 000000000..1d6ad04a0 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/memory/chromadb/__init__.py @@ -0,0 +1,21 @@ +from ._chroma_configs import ( + ChromaDBVectorMemoryConfig, + CustomEmbeddingFunctionConfig, + DefaultEmbeddingFunctionConfig, + HttpChromaDBVectorMemoryConfig, + OpenAIEmbeddingFunctionConfig, + PersistentChromaDBVectorMemoryConfig, + SentenceTransformerEmbeddingFunctionConfig, +) +from ._chromadb import ChromaDBVectorMemory + +__all__ = [ + "ChromaDBVectorMemory", + "ChromaDBVectorMemoryConfig", + "PersistentChromaDBVectorMemoryConfig", + "HttpChromaDBVectorMemoryConfig", + "DefaultEmbeddingFunctionConfig", + "SentenceTransformerEmbeddingFunctionConfig", + "OpenAIEmbeddingFunctionConfig", + "CustomEmbeddingFunctionConfig", +] diff --git a/python/packages/autogen-ext/src/autogen_ext/memory/chromadb/_chroma_configs.py b/python/packages/autogen-ext/src/autogen_ext/memory/chromadb/_chroma_configs.py new file mode 100644 index 000000000..77e084104 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/memory/chromadb/_chroma_configs.py @@ -0,0 +1,148 @@ +"""Configuration classes for ChromaDB vector memory.""" + +from typing import Any, Callable, Dict, Literal, Union + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + + +class DefaultEmbeddingFunctionConfig(BaseModel): + """Configuration for the default ChromaDB embedding function. + + Uses ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2). + + .. versionadded:: v0.4.1 + Support for custom embedding functions in ChromaDB memory. + """ + + function_type: Literal["default"] = "default" + + +class SentenceTransformerEmbeddingFunctionConfig(BaseModel): + """Configuration for SentenceTransformer embedding functions. + + Allows specifying a custom SentenceTransformer model for embeddings. + + .. versionadded:: v0.4.1 + Support for custom embedding functions in ChromaDB memory. + + Args: + model_name (str): Name of the SentenceTransformer model to use. + Defaults to "all-MiniLM-L6-v2". + + Example: + .. code-block:: python + + config = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2") + """ + + function_type: Literal["sentence_transformer"] = "sentence_transformer" + model_name: str = Field(default="all-MiniLM-L6-v2", description="SentenceTransformer model name to use") + + +class OpenAIEmbeddingFunctionConfig(BaseModel): + """Configuration for OpenAI embedding functions. + + Uses OpenAI's embedding API for generating embeddings. + + .. versionadded:: v0.4.1 + Support for custom embedding functions in ChromaDB memory. + + Args: + api_key (str): OpenAI API key. If empty, will attempt to use environment variable. + model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002". + + Example: + .. code-block:: python + + config = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small") + """ + + function_type: Literal["openai"] = "openai" + api_key: str = Field(default="", description="OpenAI API key") + model_name: str = Field(default="text-embedding-ada-002", description="OpenAI embedding model name") + + +class CustomEmbeddingFunctionConfig(BaseModel): + """Configuration for custom embedding functions. + + Allows using a custom function that returns a ChromaDB-compatible embedding function. + + .. versionadded:: v0.4.1 + Support for custom embedding functions in ChromaDB memory. + + .. warning:: + Configurations containing custom functions are not serializable. + + Args: + function (Callable): Function that returns a ChromaDB-compatible embedding function. + params (Dict[str, Any]): Parameters to pass to the function. + + Example: + .. code-block:: python + + def create_my_embedder(param1="default"): + # Return a ChromaDB-compatible embedding function + class MyCustomEmbeddingFunction(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + # Custom embedding logic here + return embeddings + + return MyCustomEmbeddingFunction(param1) + + + config = CustomEmbeddingFunctionConfig(function=create_my_embedder, params={"param1": "custom_value"}) + """ + + function_type: Literal["custom"] = "custom" + function: Callable[..., Any] = Field(description="Function that returns an embedding function") + params: Dict[str, Any] = Field(default_factory=dict, description="Parameters to pass to the function") + + +# Tagged union type for embedding function configurations +EmbeddingFunctionConfig = Annotated[ + Union[ + DefaultEmbeddingFunctionConfig, + SentenceTransformerEmbeddingFunctionConfig, + OpenAIEmbeddingFunctionConfig, + CustomEmbeddingFunctionConfig, + ], + Field(discriminator="function_type"), +] + + +class ChromaDBVectorMemoryConfig(BaseModel): + """Base configuration for ChromaDB-based memory implementation. + + .. versionchanged:: v0.4.1 + Added support for custom embedding functions via embedding_function_config. + """ + + client_type: Literal["persistent", "http"] + collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection") + distance_metric: str = Field(default="cosine", description="Distance metric for similarity search") + k: int = Field(default=3, description="Number of results to return in queries") + score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold") + allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client") + tenant: str = Field(default="default_tenant", description="Tenant to use") + database: str = Field(default="default_database", description="Database to use") + embedding_function_config: EmbeddingFunctionConfig = Field( + default_factory=DefaultEmbeddingFunctionConfig, description="Configuration for the embedding function" + ) + + +class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig): + """Configuration for persistent ChromaDB memory.""" + + client_type: Literal["persistent", "http"] = "persistent" + persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage") + + +class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig): + """Configuration for HTTP ChromaDB memory.""" + + client_type: Literal["persistent", "http"] = "http" + host: str = Field(default="localhost", description="Host of the remote server") + port: int = Field(default=8000, description="Port of the remote server") + ssl: bool = Field(default=False, description="Whether to use HTTPS") + headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server") diff --git a/python/packages/autogen-ext/src/autogen_ext/memory/chromadb.py b/python/packages/autogen-ext/src/autogen_ext/memory/chromadb/_chromadb.py similarity index 78% rename from python/packages/autogen-ext/src/autogen_ext/memory/chromadb.py rename to python/packages/autogen-ext/src/autogen_ext/memory/chromadb/_chromadb.py index 3732544fa..ac3f531eb 100644 --- a/python/packages/autogen-ext/src/autogen_ext/memory/chromadb.py +++ b/python/packages/autogen-ext/src/autogen_ext/memory/chromadb/_chromadb.py @@ -1,6 +1,6 @@ import logging import uuid -from typing import Any, Dict, List, Literal +from typing import Any, List from autogen_core import CancellationToken, Component, Image from autogen_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult @@ -9,9 +9,18 @@ from autogen_core.models import SystemMessage from chromadb import HttpClient, PersistentClient from chromadb.api.models.Collection import Collection from chromadb.api.types import Document, Metadata -from pydantic import BaseModel, Field from typing_extensions import Self +from ._chroma_configs import ( + ChromaDBVectorMemoryConfig, + CustomEmbeddingFunctionConfig, + DefaultEmbeddingFunctionConfig, + HttpChromaDBVectorMemoryConfig, + OpenAIEmbeddingFunctionConfig, + PersistentChromaDBVectorMemoryConfig, + SentenceTransformerEmbeddingFunctionConfig, +) + logger = logging.getLogger(__name__) @@ -23,36 +32,6 @@ except ImportError as e: ) from e -class ChromaDBVectorMemoryConfig(BaseModel): - """Base configuration for ChromaDB-based memory implementation.""" - - client_type: Literal["persistent", "http"] - collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection") - distance_metric: str = Field(default="cosine", description="Distance metric for similarity search") - k: int = Field(default=3, description="Number of results to return in queries") - score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold") - allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client") - tenant: str = Field(default="default_tenant", description="Tenant to use") - database: str = Field(default="default_database", description="Database to use") - - -class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig): - """Configuration for persistent ChromaDB memory.""" - - client_type: Literal["persistent", "http"] = "persistent" - persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage") - - -class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig): - """Configuration for HTTP ChromaDB memory.""" - - client_type: Literal["persistent", "http"] = "http" - host: str = Field(default="localhost", description="Host of the remote server") - port: int = Field(default=8000, description="Port of the remote server") - ssl: bool = Field(default=False, description="Whether to use HTTPS") - headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server") - - class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]): """ Store and retrieve memory using vector similarity search powered by ChromaDB. @@ -86,10 +65,15 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]): from pathlib import Path from autogen_agentchat.agents import AssistantAgent from autogen_core.memory import MemoryContent, MemoryMimeType - from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig + from autogen_ext.memory.chromadb import ( + ChromaDBVectorMemory, + PersistentChromaDBVectorMemoryConfig, + SentenceTransformerEmbeddingFunctionConfig, + OpenAIEmbeddingFunctionConfig, + ) from autogen_ext.models.openai import OpenAIChatCompletionClient - # Initialize ChromaDB memory with custom config + # Initialize ChromaDB memory with default embedding function memory = ChromaDBVectorMemory( config=PersistentChromaDBVectorMemoryConfig( collection_name="user_preferences", @@ -99,6 +83,28 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]): ) ) + # Using a custom SentenceTransformer model + memory_custom_st = ChromaDBVectorMemory( + config=PersistentChromaDBVectorMemoryConfig( + collection_name="multilingual_memory", + persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"), + embedding_function_config=SentenceTransformerEmbeddingFunctionConfig( + model_name="paraphrase-multilingual-mpnet-base-v2" + ), + ) + ) + + # Using OpenAI embeddings + memory_openai = ChromaDBVectorMemory( + config=PersistentChromaDBVectorMemoryConfig( + collection_name="openai_memory", + persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"), + embedding_function_config=OpenAIEmbeddingFunctionConfig( + api_key="sk-...", model_name="text-embedding-3-small" + ), + ) + ) + # Add user preferences to memory await memory.add( MemoryContent( @@ -138,6 +144,55 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]): """Get the name of the ChromaDB collection.""" return self._config.collection_name + def _create_embedding_function(self) -> Any: + """Create an embedding function based on the configuration. + + Returns: + A ChromaDB-compatible embedding function. + + Raises: + ValueError: If the embedding function type is unsupported. + ImportError: If required dependencies are not installed. + """ + try: + from chromadb.utils import embedding_functions + except ImportError as e: + raise ImportError( + "ChromaDB embedding functions not available. Ensure chromadb is properly installed." + ) from e + + config = self._config.embedding_function_config + + if isinstance(config, DefaultEmbeddingFunctionConfig): + return embedding_functions.DefaultEmbeddingFunction() + + elif isinstance(config, SentenceTransformerEmbeddingFunctionConfig): + try: + return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=config.model_name) + except Exception as e: + raise ImportError( + f"Failed to create SentenceTransformer embedding function with model '{config.model_name}'. " + f"Ensure sentence-transformers is installed and the model is available. Error: {e}" + ) from e + + elif isinstance(config, OpenAIEmbeddingFunctionConfig): + try: + return embedding_functions.OpenAIEmbeddingFunction(api_key=config.api_key, model_name=config.model_name) + except Exception as e: + raise ImportError( + f"Failed to create OpenAI embedding function with model '{config.model_name}'. " + f"Ensure openai is installed and API key is valid. Error: {e}" + ) from e + + elif isinstance(config, CustomEmbeddingFunctionConfig): + try: + return config.function(**config.params) + except Exception as e: + raise ValueError(f"Failed to create custom embedding function. Error: {e}") from e + + else: + raise ValueError(f"Unsupported embedding function config type: {type(config)}") + def _ensure_initialized(self) -> None: """Ensure ChromaDB client and collection are initialized.""" if self._client is None: @@ -171,8 +226,14 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]): if self._collection is None: try: + # Create embedding function + embedding_function = self._create_embedding_function() + + # Create or get collection with embedding function self._collection = self._client.get_or_create_collection( - name=self._config.collection_name, metadata={"distance_metric": self._config.distance_metric} + name=self._config.collection_name, + metadata={"distance_metric": self._config.distance_metric}, + embedding_function=embedding_function, ) except Exception as e: logger.error(f"Failed to get/create collection: {e}") diff --git a/python/packages/autogen-ext/tests/memory/test_chroma_memory.py b/python/packages/autogen-ext/tests/memory/test_chroma_memory.py index 163b77f2c..f62c91c7d 100644 --- a/python/packages/autogen-ext/tests/memory/test_chroma_memory.py +++ b/python/packages/autogen-ext/tests/memory/test_chroma_memory.py @@ -4,7 +4,21 @@ import pytest from autogen_core.memory import MemoryContent, MemoryMimeType from autogen_core.model_context import BufferedChatCompletionContext from autogen_core.models import UserMessage -from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig +from autogen_ext.memory.chromadb import ( + ChromaDBVectorMemory, + CustomEmbeddingFunctionConfig, + DefaultEmbeddingFunctionConfig, + HttpChromaDBVectorMemoryConfig, + OpenAIEmbeddingFunctionConfig, + PersistentChromaDBVectorMemoryConfig, + SentenceTransformerEmbeddingFunctionConfig, +) + +# Skip all tests if ChromaDB is not available +try: + import chromadb # pyright: ignore[reportUnusedImport] +except ImportError: + pytest.skip("ChromaDB not available", allow_module_level=True) @pytest.fixture @@ -240,3 +254,189 @@ async def test_component_serialization(base_config: PersistentChromaDBVectorMemo await memory.close() await loaded_memory.close() + + +@pytest.mark.asyncio +def test_http_config(tmp_path: Path) -> None: + """Test HTTP ChromaDB configuration.""" + config = HttpChromaDBVectorMemoryConfig( + collection_name="test_http", + host="localhost", + port=8000, + ssl=False, + headers={"Authorization": "Bearer test-token"}, + ) + + assert config.client_type == "http" + assert config.host == "localhost" + assert config.port == 8000 + assert config.ssl is False + assert config.headers == {"Authorization": "Bearer test-token"} + + +# ============================================================================ +# Embedding Function Configuration Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_default_embedding_function(tmp_path: Path) -> None: + """Test ChromaDB memory with default embedding function.""" + config = PersistentChromaDBVectorMemoryConfig( + collection_name="test_default_embedding", + allow_reset=True, + persistence_path=str(tmp_path / "chroma_db_default"), + embedding_function_config=DefaultEmbeddingFunctionConfig(), + ) + + memory = ChromaDBVectorMemory(config=config) + await memory.clear() + + # Add test content + await memory.add( + MemoryContent( + content="Default embedding function test content", + mime_type=MemoryMimeType.TEXT, + metadata={"test": "default_embedding"}, + ) + ) + + # Query and verify + results = await memory.query("default embedding test") + assert len(results.results) > 0 + assert any("Default embedding" in str(r.content) for r in results.results) + + await memory.close() + + +@pytest.mark.asyncio +async def test_sentence_transformer_embedding_function(tmp_path: Path) -> None: + """Test ChromaDB memory with SentenceTransformer embedding function.""" + config = PersistentChromaDBVectorMemoryConfig( + collection_name="test_st_embedding", + allow_reset=True, + persistence_path=str(tmp_path / "chroma_db_st"), + embedding_function_config=SentenceTransformerEmbeddingFunctionConfig( + model_name="all-MiniLM-L6-v2" # Use default model for testing + ), + ) + + memory = ChromaDBVectorMemory(config=config) + await memory.clear() + + # Add test content + await memory.add( + MemoryContent( + content="SentenceTransformer embedding function test content", + mime_type=MemoryMimeType.TEXT, + metadata={"test": "sentence_transformer"}, + ) + ) + + # Query and verify + results = await memory.query("SentenceTransformer embedding test") + assert len(results.results) > 0 + assert any("SentenceTransformer" in str(r.content) for r in results.results) + + await memory.close() + + +@pytest.mark.asyncio +async def test_custom_embedding_function(tmp_path: Path) -> None: + """Test ChromaDB memory with custom embedding function.""" + from collections.abc import Sequence + + class MockEmbeddingFunction: + def __call__(self, input: Sequence[str]) -> list[list[float]]: + # Return a batch of embeddings (list of lists) + return [[0.0] * 384 for _ in input] + + config = PersistentChromaDBVectorMemoryConfig( + collection_name="test_custom_embedding", + allow_reset=True, + persistence_path=str(tmp_path / "chroma_db_custom"), + embedding_function_config=CustomEmbeddingFunctionConfig(function=MockEmbeddingFunction, params={}), + ) + memory = ChromaDBVectorMemory(config=config) + await memory.clear() + await memory.add( + MemoryContent( + content="Custom embedding function test content", + mime_type=MemoryMimeType.TEXT, + metadata={"test": "custom_embedding"}, + ) + ) + results = await memory.query("custom embedding test") + assert len(results.results) > 0 + assert any("Custom embedding" in str(r.content) for r in results.results) + await memory.close() + + +@pytest.mark.asyncio +async def test_openai_embedding_function(tmp_path: Path) -> None: + """Test OpenAI embedding function configuration (without actual API call).""" + config = PersistentChromaDBVectorMemoryConfig( + collection_name="test_openai_embedding", + allow_reset=True, + persistence_path=str(tmp_path / "chroma_db_openai"), + embedding_function_config=OpenAIEmbeddingFunctionConfig( + api_key="test-key", model_name="text-embedding-3-small" + ), + ) + + # Just test that the config is valid - don't actually try to use OpenAI API + assert config.embedding_function_config.function_type == "openai" + assert config.embedding_function_config.api_key == "test-key" + assert config.embedding_function_config.model_name == "text-embedding-3-small" + + +@pytest.mark.asyncio +async def test_embedding_function_error_handling(tmp_path: Path) -> None: + """Test error handling for embedding function configurations.""" + + def failing_embedding_function() -> None: + """A function that raises an error.""" + raise ValueError("Test embedding function error") + + config = PersistentChromaDBVectorMemoryConfig( + collection_name="test_error_embedding", + allow_reset=True, + persistence_path=str(tmp_path / "chroma_db_error"), + embedding_function_config=CustomEmbeddingFunctionConfig(function=failing_embedding_function, params={}), + ) + + memory = ChromaDBVectorMemory(config=config) + + # Should raise an error when trying to initialize + with pytest.raises((ValueError, Exception)): # Catch ValueError or any other exception + await memory.add(MemoryContent(content="This should fail", mime_type=MemoryMimeType.TEXT)) + + await memory.close() + + +def test_embedding_function_config_validation() -> None: + """Test validation of embedding function configurations.""" + + # Test default config + default_config = DefaultEmbeddingFunctionConfig() + assert default_config.function_type == "default" + + # Test SentenceTransformer config + st_config = SentenceTransformerEmbeddingFunctionConfig(model_name="test-model") + assert st_config.function_type == "sentence_transformer" + assert st_config.model_name == "test-model" + + # Test OpenAI config + openai_config = OpenAIEmbeddingFunctionConfig(api_key="test-key", model_name="test-model") + assert openai_config.function_type == "openai" + assert openai_config.api_key == "test-key" + assert openai_config.model_name == "test-model" + + # Test custom config + def dummy_function() -> None: + return None + + custom_config = CustomEmbeddingFunctionConfig(function=dummy_function, params={"test": "value"}) + assert custom_config.function_type == "custom" + assert custom_config.function == dummy_function + assert custom_config.params == {"test": "value"}