Feature/chromadb embedding functions #6267 (#6648)

## Why are these changes needed?

This PR adds support for configurable embedding functions in
ChromaDBVectorMemory, addressing the need for users to customize how
embeddings are generated for vector similarity search. Currently,
ChromaDB memory is limited to default embedding functions, which
restricts flexibility for different use cases that may require specific
embedding models or custom embedding logic.

The implementation allows users to:
- Use different SentenceTransformer models for domain-specific
embeddings
- Integrate with OpenAI's embedding API for consistent embedding
generation
- Define custom embedding functions for specialized requirements
- Maintain backward compatibility with existing default behavior

## Related issue number

Closes #6267

## Checks

- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests corresponding to the changes introduced in this
PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: Victor Dibia <victordibia@microsoft.com>
Co-authored-by: Victor Dibia <victor.dibia@gmail.com>
This commit is contained in:
Tejas Dharani 2025-06-13 21:36:15 +05:30 committed by GitHub
parent 150ea0192d
commit 67ebeeda0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1029 additions and 475 deletions

View File

@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -38,7 +38,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -72,9 +72,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- TextMessage (user) ----------\n",
"What is the weather in New York?\n",
"---------- MemoryQueryEvent (assistant_agent) ----------\n",
"[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None)]\n",
"---------- ToolCallRequestEvent (assistant_agent) ----------\n",
"[FunctionCall(id='call_apWw5JOedVvqsPfXWV7c5Uiw', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n",
"---------- ToolCallExecutionEvent (assistant_agent) ----------\n",
"[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_apWw5JOedVvqsPfXWV7c5Uiw', is_error=False)]\n",
"---------- ToolCallSummaryMessage (assistant_agent) ----------\n",
"The weather in New York is 23 °C and Sunny.\n"
]
},
{
"data": {
"text/plain": [
"TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 33, 492791, tzinfo=datetime.timezone.utc), content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 33, 494162, tzinfo=datetime.timezone.utc), content=[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=19), metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 34, 892272, tzinfo=datetime.timezone.utc), content=[FunctionCall(id='call_apWw5JOedVvqsPfXWV7c5Uiw', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 34, 894081, tzinfo=datetime.timezone.utc), content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_apWw5JOedVvqsPfXWV7c5Uiw', is_error=False)], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 46, 34, 895054, tzinfo=datetime.timezone.utc), content='The weather in New York is 23 °C and Sunny.', type='ToolCallSummaryMessage', tool_calls=[FunctionCall(id='call_apWw5JOedVvqsPfXWV7c5Uiw', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], results=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_apWw5JOedVvqsPfXWV7c5Uiw', 
is_error=False)])], stop_reason=None)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the agent with a task.\n",
"stream = assistant_agent.run_stream(task=\"What is the weather in New York?\")\n",
@ -90,9 +117,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n",
" SystemMessage(content='\\nRelevant memory content (in chronological order):\\n1. The weather should be in metric units\\n2. Meal recipe must be vegan\\n', type='SystemMessage'),\n",
" AssistantMessage(content=[FunctionCall(id='call_apWw5JOedVvqsPfXWV7c5Uiw', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], thought=None, source='assistant_agent', type='AssistantMessage'),\n",
" FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_apWw5JOedVvqsPfXWV7c5Uiw', is_error=False)], type='FunctionExecutionResultMessage')]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await assistant_agent._model_context.get_messages()"
]
@ -108,9 +149,57 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- TextMessage (user) ----------\n",
"Write brief meal recipe with broth\n",
"---------- MemoryQueryEvent (assistant_agent) ----------\n",
"[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None)]\n",
"---------- TextMessage (assistant_agent) ----------\n",
"Here's another vegan broth-based recipe:\n",
"\n",
"**Vegan Miso Soup**\n",
"\n",
"**Ingredients:**\n",
"- 4 cups vegetable broth\n",
"- 3 tablespoons white miso paste\n",
"- 1 block firm tofu, cubed\n",
"- 1 cup mushrooms, sliced (shiitake or any variety you prefer)\n",
"- 2 green onions, chopped\n",
"- 1 tablespoon soy sauce (optional)\n",
"- 1/2 cup seaweed (such as wakame)\n",
"- 1 tablespoon sesame oil\n",
"- 1 tablespoon grated ginger\n",
"- Salt to taste\n",
"\n",
"**Instructions:**\n",
"1. In a pot, heat the sesame oil over medium heat.\n",
"2. Add the grated ginger and sauté for about a minute until fragrant.\n",
"3. Pour in the vegetable broth and bring it to a simmer.\n",
"4. Add the miso paste, stirring until fully dissolved.\n",
"5. Add the tofu cubes, mushrooms, and seaweed to the broth and cook for about 5 minutes.\n",
"6. Stir in soy sauce if using, and add salt to taste.\n",
"7. Garnish with chopped green onions before serving.\n",
"\n",
"Enjoy your delicious and nutritious vegan miso soup! TERMINATE\n"
]
},
{
"data": {
"text/plain": [
"TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 47, 19, 247083, tzinfo=datetime.timezone.utc), content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 47, 19, 248736, tzinfo=datetime.timezone.utc), content=[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=528, completion_tokens=233), metadata={}, created_at=datetime.datetime(2025, 6, 12, 17, 47, 26, 130554, tzinfo=datetime.timezone.utc), content=\"Here's another vegan broth-based recipe:\\n\\n**Vegan Miso Soup**\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 3 tablespoons white miso paste\\n- 1 block firm tofu, cubed\\n- 1 cup mushrooms, sliced (shiitake or any variety you prefer)\\n- 2 green onions, chopped\\n- 1 tablespoon soy sauce (optional)\\n- 1/2 cup seaweed (such as wakame)\\n- 1 tablespoon sesame oil\\n- 1 tablespoon grated ginger\\n- Salt to taste\\n\\n**Instructions:**\\n1. In a pot, heat the sesame oil over medium heat.\\n2. Add the grated ginger and sauté for about a minute until fragrant.\\n3. Pour in the vegetable broth and bring it to a simmer.\\n4. Add the miso paste, stirring until fully dissolved.\\n5. Add the tofu cubes, mushrooms, and seaweed to the broth and cook for about 5 minutes.\\n6. Stir in soy sauce if using, and add salt to taste.\\n7. Garnish with chopped green onions before serving.\\n\\nEnjoy your delicious and nutritious vegan miso soup! TERMINATE\", type='TextMessage')], stop_reason=None)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stream = assistant_agent.run_stream(task=\"Write brief meal recipe with broth\")\n",
"await Console(stream)"
@ -129,35 +218,59 @@
"\n",
"Currently the following example memory stores are available as part of the {py:class}`~autogen_ext` extensions package. \n",
"\n",
"- `autogen_ext.memory.chromadb.ChromaDBVectorMemory`: A memory store that uses a vector database to store and retrieve information. "
"- `autogen_ext.memory.chromadb.ChromaDBVectorMemory`: A memory store that uses a vector database to store and retrieve information. \n",
"\n",
"- `autogen_ext.memory.chromadb.SentenceTransformerEmbeddingFunctionConfig`: A configuration class for the SentenceTransformer embedding function used by the `ChromaDBVectorMemory` store. Note that other embedding functions such as `autogen_ext.memory.chromadb.OpenAIEmbeddingFunctionConfig` can also be used with the `ChromaDBVectorMemory` store.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- TextMessage (user) ----------\n",
"What is the weather in New York?\n",
"---------- MemoryQueryEvent (assistant_agent) ----------\n",
"[MemoryContent(content='The weather should be in metric units', mime_type='MemoryMimeType.TEXT', metadata={'type': 'units', 'mime_type': 'MemoryMimeType.TEXT', 'category': 'preferences', 'score': 0.4342840313911438, 'id': 'd7ed6e42-0bf5-4ee8-b5b5-fbe06f583477'})]\n",
"---------- ToolCallRequestEvent (assistant_agent) ----------\n",
"[FunctionCall(id='call_ufpz7LGcn19ZroowyEraj9bd', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n",
"---------- ToolCallExecutionEvent (assistant_agent) ----------\n",
"[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_ufpz7LGcn19ZroowyEraj9bd', is_error=False)]\n",
"---------- ToolCallSummaryMessage (assistant_agent) ----------\n",
"The weather in New York is 23 °C and Sunny.\n"
]
}
],
"source": [
"import os\n",
"from pathlib import Path\n",
"import tempfile\n",
"\n",
"from autogen_agentchat.agents import AssistantAgent\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_core.memory import MemoryContent, MemoryMimeType\n",
"from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig\n",
"from autogen_ext.memory.chromadb import (\n",
" ChromaDBVectorMemory,\n",
" PersistentChromaDBVectorMemoryConfig,\n",
" SentenceTransformerEmbeddingFunctionConfig,\n",
")\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"# Initialize ChromaDB memory with custom config\n",
"# Use a temporary directory for ChromaDB persistence\n",
"with tempfile.TemporaryDirectory() as tmpdir:\n",
" chroma_user_memory = ChromaDBVectorMemory(\n",
" config=PersistentChromaDBVectorMemoryConfig(\n",
" collection_name=\"preferences\",\n",
" persistence_path=os.path.join(str(Path.home()), \".chromadb_autogen\"),\n",
" persistence_path=tmpdir, # Use the temp directory here\n",
" k=2, # Return top k results\n",
" score_threshold=0.4, # Minimum similarity score\n",
" embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(\n",
" model_name=\"all-MiniLM-L6-v2\" # Use default model for testing\n",
" ),\n",
" )\n",
" )\n",
"# a HttpChromaDBVectorMemoryConfig is also supported for connecting to a remote ChromaDB server\n",
"\n",
" # Add user preferences to memory\n",
" await chroma_user_memory.add(\n",
" MemoryContent(\n",
@ -203,9 +316,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'{\"provider\":\"autogen_ext.memory.chromadb.ChromaDBVectorMemory\",\"component_type\":\"memory\",\"version\":1,\"component_version\":1,\"description\":\"Store and retrieve memory using vector similarity search powered by ChromaDB.\",\"label\":\"ChromaDBVectorMemory\",\"config\":{\"client_type\":\"persistent\",\"collection_name\":\"preferences\",\"distance_metric\":\"cosine\",\"k\":2,\"score_threshold\":0.4,\"allow_reset\":false,\"tenant\":\"default_tenant\",\"database\":\"default_database\",\"embedding_function_config\":{\"function_type\":\"sentence_transformer\",\"model_name\":\"all-MiniLM-L6-v2\"},\"persistence_path\":\"/var/folders/wg/hgs_dt8n5lbd3gx3pq7k6lym0000gn/T/tmp9qcaqchy\"}}'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chroma_user_memory.dump_component().model_dump_json()"
]
@ -434,7 +558,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.11.13"
}
},
"nbformat": 4,

View File

@ -0,0 +1,21 @@
"""Public surface of the ChromaDB vector memory package.

Re-exports the memory implementation and all embedding-function
configuration classes so callers can import everything from
``autogen_ext.memory.chromadb``.
"""

from ._chroma_configs import (
    ChromaDBVectorMemoryConfig,
    CustomEmbeddingFunctionConfig,
    DefaultEmbeddingFunctionConfig,
    HttpChromaDBVectorMemoryConfig,
    OpenAIEmbeddingFunctionConfig,
    PersistentChromaDBVectorMemoryConfig,
    SentenceTransformerEmbeddingFunctionConfig,
)
from ._chromadb import ChromaDBVectorMemory

# Explicit public API: the memory store plus every embedding-function config.
__all__ = [
    "ChromaDBVectorMemory",
    "ChromaDBVectorMemoryConfig",
    "PersistentChromaDBVectorMemoryConfig",
    "HttpChromaDBVectorMemoryConfig",
    "DefaultEmbeddingFunctionConfig",
    "SentenceTransformerEmbeddingFunctionConfig",
    "OpenAIEmbeddingFunctionConfig",
    "CustomEmbeddingFunctionConfig",
]

View File

@ -0,0 +1,148 @@
"""Configuration classes for ChromaDB vector memory."""
from typing import Any, Callable, Dict, Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class DefaultEmbeddingFunctionConfig(BaseModel):
    """Configuration for the default ChromaDB embedding function.

    Uses ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2).

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.
    """

    # Discriminator value for the EmbeddingFunctionConfig tagged union.
    function_type: Literal["default"] = "default"
class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
    """Configuration for SentenceTransformer embedding functions.

    Allows specifying a custom SentenceTransformer model for embeddings.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        model_name (str): Name of the SentenceTransformer model to use.
            Defaults to "all-MiniLM-L6-v2".

    Example:
        .. code-block:: python

            config = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2")
    """

    # Discriminator value for the EmbeddingFunctionConfig tagged union.
    function_type: Literal["sentence_transformer"] = "sentence_transformer"
    model_name: str = Field(default="all-MiniLM-L6-v2", description="SentenceTransformer model name to use")
class OpenAIEmbeddingFunctionConfig(BaseModel):
    """Configuration for OpenAI embedding functions.

    Uses OpenAI's embedding API for generating embeddings.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        api_key (str): OpenAI API key. If empty, will attempt to use environment variable.
        model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002".

    Example:
        .. code-block:: python

            config = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small")
    """

    # Discriminator value for the EmbeddingFunctionConfig tagged union.
    function_type: Literal["openai"] = "openai"
    # NOTE(review): the default "" is passed through as-is; whether an empty key
    # triggers env-var fallback depends on chromadb's OpenAIEmbeddingFunction — confirm.
    api_key: str = Field(default="", description="OpenAI API key")
    model_name: str = Field(default="text-embedding-ada-002", description="OpenAI embedding model name")
class CustomEmbeddingFunctionConfig(BaseModel):
    """Configuration for custom embedding functions.

    Allows using a custom function that returns a ChromaDB-compatible embedding function.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    .. warning::
        Configurations containing custom functions are not serializable.

    Args:
        function (Callable): Function that returns a ChromaDB-compatible embedding function.
        params (Dict[str, Any]): Parameters to pass to the function.

    Example:
        .. code-block:: python

            def create_my_embedder(param1="default"):
                # Return a ChromaDB-compatible embedding function
                class MyCustomEmbeddingFunction(EmbeddingFunction):
                    def __call__(self, input: Documents) -> Embeddings:
                        # Custom embedding logic here
                        return embeddings

                return MyCustomEmbeddingFunction(param1)


            config = CustomEmbeddingFunctionConfig(function=create_my_embedder, params={"param1": "custom_value"})
    """

    # Discriminator value for the EmbeddingFunctionConfig tagged union.
    function_type: Literal["custom"] = "custom"
    # Factory, not the embedding function itself: it is called with **params at collection setup.
    function: Callable[..., Any] = Field(description="Function that returns an embedding function")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters to pass to the function")
# Tagged union type for embedding function configurations.
# Pydantic picks the concrete config class by the "function_type" discriminator field.
EmbeddingFunctionConfig = Annotated[
    Union[
        DefaultEmbeddingFunctionConfig,
        SentenceTransformerEmbeddingFunctionConfig,
        OpenAIEmbeddingFunctionConfig,
        CustomEmbeddingFunctionConfig,
    ],
    Field(discriminator="function_type"),
]
class ChromaDBVectorMemoryConfig(BaseModel):
    """Base configuration for ChromaDB-based memory implementation.

    .. versionchanged:: v0.4.1
       Added support for custom embedding functions via embedding_function_config.
    """

    # "persistent" uses a local on-disk client; "http" connects to a remote server.
    client_type: Literal["persistent", "http"]
    collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection")
    distance_metric: str = Field(default="cosine", description="Distance metric for similarity search")
    k: int = Field(default=3, description="Number of results to return in queries")
    score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold")
    allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client")
    tenant: str = Field(default="default_tenant", description="Tenant to use")
    database: str = Field(default="default_database", description="Database to use")
    # Falls back to ChromaDB's built-in embedding function when not provided.
    embedding_function_config: EmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig, description="Configuration for the embedding function"
    )
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for persistent ChromaDB memory (local on-disk storage)."""

    client_type: Literal["persistent", "http"] = "persistent"
    persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage")
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for HTTP ChromaDB memory (remote ChromaDB server)."""

    client_type: Literal["persistent", "http"] = "http"
    host: str = Field(default="localhost", description="Host of the remote server")
    port: int = Field(default=8000, description="Port of the remote server")
    ssl: bool = Field(default=False, description="Whether to use HTTPS")
    headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server")

View File

@ -1,6 +1,6 @@
import logging
import uuid
from typing import Any, Dict, List, Literal
from typing import Any, List
from autogen_core import CancellationToken, Component, Image
from autogen_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
@ -9,9 +9,18 @@ from autogen_core.models import SystemMessage
from chromadb import HttpClient, PersistentClient
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Document, Metadata
from pydantic import BaseModel, Field
from typing_extensions import Self
from ._chroma_configs import (
ChromaDBVectorMemoryConfig,
CustomEmbeddingFunctionConfig,
DefaultEmbeddingFunctionConfig,
HttpChromaDBVectorMemoryConfig,
OpenAIEmbeddingFunctionConfig,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
)
logger = logging.getLogger(__name__)
@ -23,36 +32,6 @@ except ImportError as e:
) from e
class ChromaDBVectorMemoryConfig(BaseModel):
    """Base configuration for ChromaDB-based memory implementation."""

    # "persistent" uses a local on-disk client; "http" connects to a remote server.
    client_type: Literal["persistent", "http"]
    collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection")
    distance_metric: str = Field(default="cosine", description="Distance metric for similarity search")
    k: int = Field(default=3, description="Number of results to return in queries")
    score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold")
    allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client")
    tenant: str = Field(default="default_tenant", description="Tenant to use")
    database: str = Field(default="default_database", description="Database to use")
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for persistent ChromaDB memory (local on-disk storage)."""

    client_type: Literal["persistent", "http"] = "persistent"
    persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage")
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for HTTP ChromaDB memory (remote ChromaDB server)."""

    client_type: Literal["persistent", "http"] = "http"
    host: str = Field(default="localhost", description="Host of the remote server")
    port: int = Field(default=8000, description="Port of the remote server")
    ssl: bool = Field(default=False, description="Whether to use HTTPS")
    headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server")
class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
"""
Store and retrieve memory using vector similarity search powered by ChromaDB.
@ -86,10 +65,15 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
from pathlib import Path
from autogen_agentchat.agents import AssistantAgent
from autogen_core.memory import MemoryContent, MemoryMimeType
from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig
from autogen_ext.memory.chromadb import (
ChromaDBVectorMemory,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
OpenAIEmbeddingFunctionConfig,
)
from autogen_ext.models.openai import OpenAIChatCompletionClient
# Initialize ChromaDB memory with custom config
# Initialize ChromaDB memory with default embedding function
memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="user_preferences",
@ -99,6 +83,28 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
)
)
# Using a custom SentenceTransformer model
memory_custom_st = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="multilingual_memory",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
model_name="paraphrase-multilingual-mpnet-base-v2"
),
)
)
# Using OpenAI embeddings
memory_openai = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="openai_memory",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
embedding_function_config=OpenAIEmbeddingFunctionConfig(
api_key="sk-...", model_name="text-embedding-3-small"
),
)
)
# Add user preferences to memory
await memory.add(
MemoryContent(
@ -138,6 +144,55 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
"""Get the name of the ChromaDB collection."""
return self._config.collection_name
def _create_embedding_function(self) -> Any:
    """Build the ChromaDB embedding function described by the memory config.

    Returns:
        A ChromaDB-compatible embedding function instance.

    Raises:
        ValueError: If the embedding function type is unsupported.
        ImportError: If required dependencies are not installed.
    """
    # Imported lazily so the failure mode is a clear ImportError at use time.
    try:
        from chromadb.utils import embedding_functions
    except ImportError as import_err:
        raise ImportError(
            "ChromaDB embedding functions not available. Ensure chromadb is properly installed."
        ) from import_err

    ef_config = self._config.embedding_function_config

    # Dispatch on the concrete config class from the tagged union.
    if isinstance(ef_config, DefaultEmbeddingFunctionConfig):
        return embedding_functions.DefaultEmbeddingFunction()

    if isinstance(ef_config, SentenceTransformerEmbeddingFunctionConfig):
        try:
            return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=ef_config.model_name)
        except Exception as err:
            raise ImportError(
                f"Failed to create SentenceTransformer embedding function with model '{ef_config.model_name}'. "
                f"Ensure sentence-transformers is installed and the model is available. Error: {err}"
            ) from err

    if isinstance(ef_config, OpenAIEmbeddingFunctionConfig):
        try:
            return embedding_functions.OpenAIEmbeddingFunction(api_key=ef_config.api_key, model_name=ef_config.model_name)
        except Exception as err:
            raise ImportError(
                f"Failed to create OpenAI embedding function with model '{ef_config.model_name}'. "
                f"Ensure openai is installed and API key is valid. Error: {err}"
            ) from err

    if isinstance(ef_config, CustomEmbeddingFunctionConfig):
        # The config holds a factory; call it with its params to get the function.
        try:
            return ef_config.function(**ef_config.params)
        except Exception as err:
            raise ValueError(f"Failed to create custom embedding function. Error: {err}") from err

    raise ValueError(f"Unsupported embedding function config type: {type(ef_config)}")
def _ensure_initialized(self) -> None:
"""Ensure ChromaDB client and collection are initialized."""
if self._client is None:
@ -171,8 +226,14 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
if self._collection is None:
try:
# Create embedding function
embedding_function = self._create_embedding_function()
# Create or get collection with embedding function
self._collection = self._client.get_or_create_collection(
name=self._config.collection_name, metadata={"distance_metric": self._config.distance_metric}
name=self._config.collection_name,
metadata={"distance_metric": self._config.distance_metric},
embedding_function=embedding_function,
)
except Exception as e:
logger.error(f"Failed to get/create collection: {e}")

View File

@ -4,7 +4,21 @@ import pytest
from autogen_core.memory import MemoryContent, MemoryMimeType
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import UserMessage
from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig
from autogen_ext.memory.chromadb import (
ChromaDBVectorMemory,
CustomEmbeddingFunctionConfig,
DefaultEmbeddingFunctionConfig,
HttpChromaDBVectorMemoryConfig,
OpenAIEmbeddingFunctionConfig,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
)
# Skip all tests if ChromaDB is not available
try:
import chromadb # pyright: ignore[reportUnusedImport]
except ImportError:
pytest.skip("ChromaDB not available", allow_module_level=True)
@pytest.fixture
@ -240,3 +254,189 @@ async def test_component_serialization(base_config: PersistentChromaDBVectorMemo
await memory.close()
await loaded_memory.close()
def test_http_config() -> None:
    """Test HTTP ChromaDB configuration.

    Fix: the original carried ``@pytest.mark.asyncio`` on a plain (non-async)
    ``def``; pytest-asyncio expects a coroutine function, so the marker was
    incorrect. The unused ``tmp_path`` fixture is dropped as well — the test
    only validates config fields and touches no filesystem.
    """
    config = HttpChromaDBVectorMemoryConfig(
        collection_name="test_http",
        host="localhost",
        port=8000,
        ssl=False,
        headers={"Authorization": "Bearer test-token"},
    )
    assert config.client_type == "http"
    assert config.host == "localhost"
    assert config.port == 8000
    assert config.ssl is False
    assert config.headers == {"Authorization": "Bearer test-token"}
# ============================================================================
# Embedding Function Configuration Tests
# ============================================================================
@pytest.mark.asyncio
async def test_default_embedding_function(tmp_path: Path) -> None:
    """Test ChromaDB memory with default embedding function."""
    store = ChromaDBVectorMemory(
        config=PersistentChromaDBVectorMemoryConfig(
            collection_name="test_default_embedding",
            allow_reset=True,
            persistence_path=str(tmp_path / "chroma_db_default"),
            embedding_function_config=DefaultEmbeddingFunctionConfig(),
        )
    )
    await store.clear()

    # Insert one item so the query below has something to match.
    await store.add(
        MemoryContent(
            content="Default embedding function test content",
            mime_type=MemoryMimeType.TEXT,
            metadata={"test": "default_embedding"},
        )
    )

    # The stored entry should come back for a related query.
    hits = await store.query("default embedding test")
    assert len(hits.results) > 0
    assert any("Default embedding" in str(hit.content) for hit in hits.results)

    await store.close()
@pytest.mark.asyncio
async def test_sentence_transformer_embedding_function(tmp_path: Path) -> None:
    """Test ChromaDB memory with SentenceTransformer embedding function."""
    st_config = SentenceTransformerEmbeddingFunctionConfig(
        model_name="all-MiniLM-L6-v2"  # Use default model for testing
    )
    store = ChromaDBVectorMemory(
        config=PersistentChromaDBVectorMemoryConfig(
            collection_name="test_st_embedding",
            allow_reset=True,
            persistence_path=str(tmp_path / "chroma_db_st"),
            embedding_function_config=st_config,
        )
    )
    await store.clear()

    # Insert one item so the query below has something to match.
    await store.add(
        MemoryContent(
            content="SentenceTransformer embedding function test content",
            mime_type=MemoryMimeType.TEXT,
            metadata={"test": "sentence_transformer"},
        )
    )

    # The stored entry should come back for a related query.
    hits = await store.query("SentenceTransformer embedding test")
    assert len(hits.results) > 0
    assert any("SentenceTransformer" in str(hit.content) for hit in hits.results)

    await store.close()
@pytest.mark.asyncio
async def test_custom_embedding_function(tmp_path: Path) -> None:
    """Test ChromaDB memory with custom embedding function."""
    from collections.abc import Sequence

    class ZeroVectorEmbedder:
        # Minimal ChromaDB-compatible callable: one fixed-size vector per document.
        def __call__(self, input: Sequence[str]) -> list[list[float]]:
            return [[0.0] * 384 for _ in input]

    store = ChromaDBVectorMemory(
        config=PersistentChromaDBVectorMemoryConfig(
            collection_name="test_custom_embedding",
            allow_reset=True,
            persistence_path=str(tmp_path / "chroma_db_custom"),
            # The class itself is the factory; instantiating it yields the embedder.
            embedding_function_config=CustomEmbeddingFunctionConfig(function=ZeroVectorEmbedder, params={}),
        )
    )
    await store.clear()

    await store.add(
        MemoryContent(
            content="Custom embedding function test content",
            mime_type=MemoryMimeType.TEXT,
            metadata={"test": "custom_embedding"},
        )
    )

    hits = await store.query("custom embedding test")
    assert len(hits.results) > 0
    assert any("Custom embedding" in str(hit.content) for hit in hits.results)

    await store.close()
def test_openai_embedding_function(tmp_path: Path) -> None:
    """Test OpenAI embedding function configuration (without actual API call).

    Fixes: the original was ``async`` (with the asyncio marker) despite
    awaiting nothing, and read attributes off the ``embedding_function_config``
    tagged union without narrowing, which static type checkers flag. An
    ``isinstance`` assertion narrows the union before attribute access.
    """
    config = PersistentChromaDBVectorMemoryConfig(
        collection_name="test_openai_embedding",
        allow_reset=True,
        persistence_path=str(tmp_path / "chroma_db_openai"),
        embedding_function_config=OpenAIEmbeddingFunctionConfig(
            api_key="test-key", model_name="text-embedding-3-small"
        ),
    )
    # Just test that the config is valid - don't actually try to use OpenAI API
    ef_config = config.embedding_function_config
    assert isinstance(ef_config, OpenAIEmbeddingFunctionConfig)
    assert ef_config.function_type == "openai"
    assert ef_config.api_key == "test-key"
    assert ef_config.model_name == "text-embedding-3-small"
@pytest.mark.asyncio
async def test_embedding_function_error_handling(tmp_path: Path) -> None:
    """Test error handling for embedding function configurations.

    Fix: the original used ``pytest.raises((ValueError, Exception))`` — the
    tuple is redundant because ``Exception`` already subsumes ``ValueError``.
    """

    def failing_embedding_function() -> None:
        """A factory that raises instead of returning an embedding function."""
        raise ValueError("Test embedding function error")

    config = PersistentChromaDBVectorMemoryConfig(
        collection_name="test_error_embedding",
        allow_reset=True,
        persistence_path=str(tmp_path / "chroma_db_error"),
        embedding_function_config=CustomEmbeddingFunctionConfig(function=failing_embedding_function, params={}),
    )
    memory = ChromaDBVectorMemory(config=config)

    # The failing factory should surface as an exception when the collection
    # is first initialized by add(). The exact type depends on how far the
    # initialization path wraps it, so assert the broad Exception.
    with pytest.raises(Exception):
        await memory.add(MemoryContent(content="This should fail", mime_type=MemoryMimeType.TEXT))

    await memory.close()
def test_embedding_function_config_validation() -> None:
    """Test validation of embedding function configurations."""
    # Default config: only the discriminator field.
    assert DefaultEmbeddingFunctionConfig().function_type == "default"

    # SentenceTransformer config keeps the supplied model name.
    st_cfg = SentenceTransformerEmbeddingFunctionConfig(model_name="test-model")
    assert st_cfg.function_type == "sentence_transformer"
    assert st_cfg.model_name == "test-model"

    # OpenAI config keeps both credentials and model name.
    oa_cfg = OpenAIEmbeddingFunctionConfig(api_key="test-key", model_name="test-model")
    assert oa_cfg.function_type == "openai"
    assert oa_cfg.api_key == "test-key"
    assert oa_cfg.model_name == "test-model"

    # Custom config stores the factory and its params untouched.
    def dummy_function() -> None:
        return None

    custom_cfg = CustomEmbeddingFunctionConfig(function=dummy_function, params={"test": "value"})
    assert custom_cfg.function_type == "custom"
    assert custom_cfg.function == dummy_function
    assert custom_cfg.params == {"test": "value"}