mirror of
https://github.com/microsoft/autogen.git
synced 2025-10-27 07:49:00 +00:00
## Why are these changes needed? This PR adds support for configurable embedding functions in ChromaDBVectorMemory, addressing the need for users to customize how embeddings are generated for vector similarity search. Currently, ChromaDB memory is limited to default embedding functions, which restricts flexibility for different use cases that may require specific embedding models or custom embedding logic. The implementation allows users to: - Use different SentenceTransformer models for domain-specific embeddings - Integrate with OpenAI's embedding API for consistent embedding generation - Define custom embedding functions for specialized requirements - Maintain backward compatibility with existing default behavior ## Related issue number Closes #6267 ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Victor Dibia <victordibia@microsoft.com> Co-authored-by: Victor Dibia <victor.dibia@gmail.com>
This commit is contained in:
parent
150ea0192d
commit
67ebeeda0e
File diff suppressed because one or more lines are too long
@ -0,0 +1,21 @@
|
||||
from ._chroma_configs import (
|
||||
ChromaDBVectorMemoryConfig,
|
||||
CustomEmbeddingFunctionConfig,
|
||||
DefaultEmbeddingFunctionConfig,
|
||||
HttpChromaDBVectorMemoryConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
)
|
||||
from ._chromadb import ChromaDBVectorMemory
|
||||
|
||||
__all__ = [
|
||||
"ChromaDBVectorMemory",
|
||||
"ChromaDBVectorMemoryConfig",
|
||||
"PersistentChromaDBVectorMemoryConfig",
|
||||
"HttpChromaDBVectorMemoryConfig",
|
||||
"DefaultEmbeddingFunctionConfig",
|
||||
"SentenceTransformerEmbeddingFunctionConfig",
|
||||
"OpenAIEmbeddingFunctionConfig",
|
||||
"CustomEmbeddingFunctionConfig",
|
||||
]
|
||||
@ -0,0 +1,148 @@
|
||||
"""Configuration classes for ChromaDB vector memory."""
|
||||
|
||||
from typing import Any, Callable, Dict, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class DefaultEmbeddingFunctionConfig(BaseModel):
|
||||
"""Configuration for the default ChromaDB embedding function.
|
||||
|
||||
Uses ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2).
|
||||
|
||||
.. versionadded:: v0.4.1
|
||||
Support for custom embedding functions in ChromaDB memory.
|
||||
"""
|
||||
|
||||
function_type: Literal["default"] = "default"
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
|
||||
"""Configuration for SentenceTransformer embedding functions.
|
||||
|
||||
Allows specifying a custom SentenceTransformer model for embeddings.
|
||||
|
||||
.. versionadded:: v0.4.1
|
||||
Support for custom embedding functions in ChromaDB memory.
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the SentenceTransformer model to use.
|
||||
Defaults to "all-MiniLM-L6-v2".
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
config = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2")
|
||||
"""
|
||||
|
||||
function_type: Literal["sentence_transformer"] = "sentence_transformer"
|
||||
model_name: str = Field(default="all-MiniLM-L6-v2", description="SentenceTransformer model name to use")
|
||||
|
||||
|
||||
class OpenAIEmbeddingFunctionConfig(BaseModel):
|
||||
"""Configuration for OpenAI embedding functions.
|
||||
|
||||
Uses OpenAI's embedding API for generating embeddings.
|
||||
|
||||
.. versionadded:: v0.4.1
|
||||
Support for custom embedding functions in ChromaDB memory.
|
||||
|
||||
Args:
|
||||
api_key (str): OpenAI API key. If empty, will attempt to use environment variable.
|
||||
model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002".
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
config = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small")
|
||||
"""
|
||||
|
||||
function_type: Literal["openai"] = "openai"
|
||||
api_key: str = Field(default="", description="OpenAI API key")
|
||||
model_name: str = Field(default="text-embedding-ada-002", description="OpenAI embedding model name")
|
||||
|
||||
|
||||
class CustomEmbeddingFunctionConfig(BaseModel):
|
||||
"""Configuration for custom embedding functions.
|
||||
|
||||
Allows using a custom function that returns a ChromaDB-compatible embedding function.
|
||||
|
||||
.. versionadded:: v0.4.1
|
||||
Support for custom embedding functions in ChromaDB memory.
|
||||
|
||||
.. warning::
|
||||
Configurations containing custom functions are not serializable.
|
||||
|
||||
Args:
|
||||
function (Callable): Function that returns a ChromaDB-compatible embedding function.
|
||||
params (Dict[str, Any]): Parameters to pass to the function.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
def create_my_embedder(param1="default"):
|
||||
# Return a ChromaDB-compatible embedding function
|
||||
class MyCustomEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
# Custom embedding logic here
|
||||
return embeddings
|
||||
|
||||
return MyCustomEmbeddingFunction(param1)
|
||||
|
||||
|
||||
config = CustomEmbeddingFunctionConfig(function=create_my_embedder, params={"param1": "custom_value"})
|
||||
"""
|
||||
|
||||
function_type: Literal["custom"] = "custom"
|
||||
function: Callable[..., Any] = Field(description="Function that returns an embedding function")
|
||||
params: Dict[str, Any] = Field(default_factory=dict, description="Parameters to pass to the function")
|
||||
|
||||
|
||||
# Tagged union type for embedding function configurations
|
||||
EmbeddingFunctionConfig = Annotated[
|
||||
Union[
|
||||
DefaultEmbeddingFunctionConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
CustomEmbeddingFunctionConfig,
|
||||
],
|
||||
Field(discriminator="function_type"),
|
||||
]
|
||||
|
||||
|
||||
class ChromaDBVectorMemoryConfig(BaseModel):
|
||||
"""Base configuration for ChromaDB-based memory implementation.
|
||||
|
||||
.. versionchanged:: v0.4.1
|
||||
Added support for custom embedding functions via embedding_function_config.
|
||||
"""
|
||||
|
||||
client_type: Literal["persistent", "http"]
|
||||
collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection")
|
||||
distance_metric: str = Field(default="cosine", description="Distance metric for similarity search")
|
||||
k: int = Field(default=3, description="Number of results to return in queries")
|
||||
score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold")
|
||||
allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client")
|
||||
tenant: str = Field(default="default_tenant", description="Tenant to use")
|
||||
database: str = Field(default="default_database", description="Database to use")
|
||||
embedding_function_config: EmbeddingFunctionConfig = Field(
|
||||
default_factory=DefaultEmbeddingFunctionConfig, description="Configuration for the embedding function"
|
||||
)
|
||||
|
||||
|
||||
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
|
||||
"""Configuration for persistent ChromaDB memory."""
|
||||
|
||||
client_type: Literal["persistent", "http"] = "persistent"
|
||||
persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage")
|
||||
|
||||
|
||||
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
|
||||
"""Configuration for HTTP ChromaDB memory."""
|
||||
|
||||
client_type: Literal["persistent", "http"] = "http"
|
||||
host: str = Field(default="localhost", description="Host of the remote server")
|
||||
port: int = Field(default=8000, description="Port of the remote server")
|
||||
ssl: bool = Field(default=False, description="Whether to use HTTPS")
|
||||
headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server")
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Literal
|
||||
from typing import Any, List
|
||||
|
||||
from autogen_core import CancellationToken, Component, Image
|
||||
from autogen_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
@ -9,9 +9,18 @@ from autogen_core.models import SystemMessage
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import Document, Metadata
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._chroma_configs import (
|
||||
ChromaDBVectorMemoryConfig,
|
||||
CustomEmbeddingFunctionConfig,
|
||||
DefaultEmbeddingFunctionConfig,
|
||||
HttpChromaDBVectorMemoryConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -23,36 +32,6 @@ except ImportError as e:
|
||||
) from e
|
||||
|
||||
|
||||
class ChromaDBVectorMemoryConfig(BaseModel):
|
||||
"""Base configuration for ChromaDB-based memory implementation."""
|
||||
|
||||
client_type: Literal["persistent", "http"]
|
||||
collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection")
|
||||
distance_metric: str = Field(default="cosine", description="Distance metric for similarity search")
|
||||
k: int = Field(default=3, description="Number of results to return in queries")
|
||||
score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold")
|
||||
allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client")
|
||||
tenant: str = Field(default="default_tenant", description="Tenant to use")
|
||||
database: str = Field(default="default_database", description="Database to use")
|
||||
|
||||
|
||||
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
|
||||
"""Configuration for persistent ChromaDB memory."""
|
||||
|
||||
client_type: Literal["persistent", "http"] = "persistent"
|
||||
persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage")
|
||||
|
||||
|
||||
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
|
||||
"""Configuration for HTTP ChromaDB memory."""
|
||||
|
||||
client_type: Literal["persistent", "http"] = "http"
|
||||
host: str = Field(default="localhost", description="Host of the remote server")
|
||||
port: int = Field(default=8000, description="Port of the remote server")
|
||||
ssl: bool = Field(default=False, description="Whether to use HTTPS")
|
||||
headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server")
|
||||
|
||||
|
||||
class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
|
||||
"""
|
||||
Store and retrieve memory using vector similarity search powered by ChromaDB.
|
||||
@ -86,10 +65,15 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
|
||||
from pathlib import Path
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_core.memory import MemoryContent, MemoryMimeType
|
||||
from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig
|
||||
from autogen_ext.memory.chromadb import (
|
||||
ChromaDBVectorMemory,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
)
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
# Initialize ChromaDB memory with custom config
|
||||
# Initialize ChromaDB memory with default embedding function
|
||||
memory = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="user_preferences",
|
||||
@ -99,6 +83,28 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
|
||||
)
|
||||
)
|
||||
|
||||
# Using a custom SentenceTransformer model
|
||||
memory_custom_st = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="multilingual_memory",
|
||||
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
|
||||
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
|
||||
model_name="paraphrase-multilingual-mpnet-base-v2"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Using OpenAI embeddings
|
||||
memory_openai = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="openai_memory",
|
||||
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
|
||||
embedding_function_config=OpenAIEmbeddingFunctionConfig(
|
||||
api_key="sk-...", model_name="text-embedding-3-small"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Add user preferences to memory
|
||||
await memory.add(
|
||||
MemoryContent(
|
||||
@ -138,6 +144,55 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
|
||||
"""Get the name of the ChromaDB collection."""
|
||||
return self._config.collection_name
|
||||
|
||||
def _create_embedding_function(self) -> Any:
|
||||
"""Create an embedding function based on the configuration.
|
||||
|
||||
Returns:
|
||||
A ChromaDB-compatible embedding function.
|
||||
|
||||
Raises:
|
||||
ValueError: If the embedding function type is unsupported.
|
||||
ImportError: If required dependencies are not installed.
|
||||
"""
|
||||
try:
|
||||
from chromadb.utils import embedding_functions
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ChromaDB embedding functions not available. Ensure chromadb is properly installed."
|
||||
) from e
|
||||
|
||||
config = self._config.embedding_function_config
|
||||
|
||||
if isinstance(config, DefaultEmbeddingFunctionConfig):
|
||||
return embedding_functions.DefaultEmbeddingFunction()
|
||||
|
||||
elif isinstance(config, SentenceTransformerEmbeddingFunctionConfig):
|
||||
try:
|
||||
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=config.model_name)
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
f"Failed to create SentenceTransformer embedding function with model '{config.model_name}'. "
|
||||
f"Ensure sentence-transformers is installed and the model is available. Error: {e}"
|
||||
) from e
|
||||
|
||||
elif isinstance(config, OpenAIEmbeddingFunctionConfig):
|
||||
try:
|
||||
return embedding_functions.OpenAIEmbeddingFunction(api_key=config.api_key, model_name=config.model_name)
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
f"Failed to create OpenAI embedding function with model '{config.model_name}'. "
|
||||
f"Ensure openai is installed and API key is valid. Error: {e}"
|
||||
) from e
|
||||
|
||||
elif isinstance(config, CustomEmbeddingFunctionConfig):
|
||||
try:
|
||||
return config.function(**config.params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create custom embedding function. Error: {e}") from e
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding function config type: {type(config)}")
|
||||
|
||||
def _ensure_initialized(self) -> None:
|
||||
"""Ensure ChromaDB client and collection are initialized."""
|
||||
if self._client is None:
|
||||
@ -171,8 +226,14 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
|
||||
|
||||
if self._collection is None:
|
||||
try:
|
||||
# Create embedding function
|
||||
embedding_function = self._create_embedding_function()
|
||||
|
||||
# Create or get collection with embedding function
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self._config.collection_name, metadata={"distance_metric": self._config.distance_metric}
|
||||
name=self._config.collection_name,
|
||||
metadata={"distance_metric": self._config.distance_metric},
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get/create collection: {e}")
|
||||
@ -4,7 +4,21 @@ import pytest
|
||||
from autogen_core.memory import MemoryContent, MemoryMimeType
|
||||
from autogen_core.model_context import BufferedChatCompletionContext
|
||||
from autogen_core.models import UserMessage
|
||||
from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig
|
||||
from autogen_ext.memory.chromadb import (
|
||||
ChromaDBVectorMemory,
|
||||
CustomEmbeddingFunctionConfig,
|
||||
DefaultEmbeddingFunctionConfig,
|
||||
HttpChromaDBVectorMemoryConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
)
|
||||
|
||||
# Skip all tests if ChromaDB is not available
|
||||
try:
|
||||
import chromadb # pyright: ignore[reportUnusedImport]
|
||||
except ImportError:
|
||||
pytest.skip("ChromaDB not available", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -240,3 +254,189 @@ async def test_component_serialization(base_config: PersistentChromaDBVectorMemo
|
||||
|
||||
await memory.close()
|
||||
await loaded_memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_http_config(tmp_path: Path) -> None:
|
||||
"""Test HTTP ChromaDB configuration."""
|
||||
config = HttpChromaDBVectorMemoryConfig(
|
||||
collection_name="test_http",
|
||||
host="localhost",
|
||||
port=8000,
|
||||
ssl=False,
|
||||
headers={"Authorization": "Bearer test-token"},
|
||||
)
|
||||
|
||||
assert config.client_type == "http"
|
||||
assert config.host == "localhost"
|
||||
assert config.port == 8000
|
||||
assert config.ssl is False
|
||||
assert config.headers == {"Authorization": "Bearer test-token"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Embedding Function Configuration Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_embedding_function(tmp_path: Path) -> None:
|
||||
"""Test ChromaDB memory with default embedding function."""
|
||||
config = PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="test_default_embedding",
|
||||
allow_reset=True,
|
||||
persistence_path=str(tmp_path / "chroma_db_default"),
|
||||
embedding_function_config=DefaultEmbeddingFunctionConfig(),
|
||||
)
|
||||
|
||||
memory = ChromaDBVectorMemory(config=config)
|
||||
await memory.clear()
|
||||
|
||||
# Add test content
|
||||
await memory.add(
|
||||
MemoryContent(
|
||||
content="Default embedding function test content",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"test": "default_embedding"},
|
||||
)
|
||||
)
|
||||
|
||||
# Query and verify
|
||||
results = await memory.query("default embedding test")
|
||||
assert len(results.results) > 0
|
||||
assert any("Default embedding" in str(r.content) for r in results.results)
|
||||
|
||||
await memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sentence_transformer_embedding_function(tmp_path: Path) -> None:
|
||||
"""Test ChromaDB memory with SentenceTransformer embedding function."""
|
||||
config = PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="test_st_embedding",
|
||||
allow_reset=True,
|
||||
persistence_path=str(tmp_path / "chroma_db_st"),
|
||||
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
|
||||
model_name="all-MiniLM-L6-v2" # Use default model for testing
|
||||
),
|
||||
)
|
||||
|
||||
memory = ChromaDBVectorMemory(config=config)
|
||||
await memory.clear()
|
||||
|
||||
# Add test content
|
||||
await memory.add(
|
||||
MemoryContent(
|
||||
content="SentenceTransformer embedding function test content",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"test": "sentence_transformer"},
|
||||
)
|
||||
)
|
||||
|
||||
# Query and verify
|
||||
results = await memory.query("SentenceTransformer embedding test")
|
||||
assert len(results.results) > 0
|
||||
assert any("SentenceTransformer" in str(r.content) for r in results.results)
|
||||
|
||||
await memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_embedding_function(tmp_path: Path) -> None:
|
||||
"""Test ChromaDB memory with custom embedding function."""
|
||||
from collections.abc import Sequence
|
||||
|
||||
class MockEmbeddingFunction:
|
||||
def __call__(self, input: Sequence[str]) -> list[list[float]]:
|
||||
# Return a batch of embeddings (list of lists)
|
||||
return [[0.0] * 384 for _ in input]
|
||||
|
||||
config = PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="test_custom_embedding",
|
||||
allow_reset=True,
|
||||
persistence_path=str(tmp_path / "chroma_db_custom"),
|
||||
embedding_function_config=CustomEmbeddingFunctionConfig(function=MockEmbeddingFunction, params={}),
|
||||
)
|
||||
memory = ChromaDBVectorMemory(config=config)
|
||||
await memory.clear()
|
||||
await memory.add(
|
||||
MemoryContent(
|
||||
content="Custom embedding function test content",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"test": "custom_embedding"},
|
||||
)
|
||||
)
|
||||
results = await memory.query("custom embedding test")
|
||||
assert len(results.results) > 0
|
||||
assert any("Custom embedding" in str(r.content) for r in results.results)
|
||||
await memory.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_embedding_function(tmp_path: Path) -> None:
|
||||
"""Test OpenAI embedding function configuration (without actual API call)."""
|
||||
config = PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="test_openai_embedding",
|
||||
allow_reset=True,
|
||||
persistence_path=str(tmp_path / "chroma_db_openai"),
|
||||
embedding_function_config=OpenAIEmbeddingFunctionConfig(
|
||||
api_key="test-key", model_name="text-embedding-3-small"
|
||||
),
|
||||
)
|
||||
|
||||
# Just test that the config is valid - don't actually try to use OpenAI API
|
||||
assert config.embedding_function_config.function_type == "openai"
|
||||
assert config.embedding_function_config.api_key == "test-key"
|
||||
assert config.embedding_function_config.model_name == "text-embedding-3-small"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_function_error_handling(tmp_path: Path) -> None:
|
||||
"""Test error handling for embedding function configurations."""
|
||||
|
||||
def failing_embedding_function() -> None:
|
||||
"""A function that raises an error."""
|
||||
raise ValueError("Test embedding function error")
|
||||
|
||||
config = PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="test_error_embedding",
|
||||
allow_reset=True,
|
||||
persistence_path=str(tmp_path / "chroma_db_error"),
|
||||
embedding_function_config=CustomEmbeddingFunctionConfig(function=failing_embedding_function, params={}),
|
||||
)
|
||||
|
||||
memory = ChromaDBVectorMemory(config=config)
|
||||
|
||||
# Should raise an error when trying to initialize
|
||||
with pytest.raises((ValueError, Exception)): # Catch ValueError or any other exception
|
||||
await memory.add(MemoryContent(content="This should fail", mime_type=MemoryMimeType.TEXT))
|
||||
|
||||
await memory.close()
|
||||
|
||||
|
||||
def test_embedding_function_config_validation() -> None:
|
||||
"""Test validation of embedding function configurations."""
|
||||
|
||||
# Test default config
|
||||
default_config = DefaultEmbeddingFunctionConfig()
|
||||
assert default_config.function_type == "default"
|
||||
|
||||
# Test SentenceTransformer config
|
||||
st_config = SentenceTransformerEmbeddingFunctionConfig(model_name="test-model")
|
||||
assert st_config.function_type == "sentence_transformer"
|
||||
assert st_config.model_name == "test-model"
|
||||
|
||||
# Test OpenAI config
|
||||
openai_config = OpenAIEmbeddingFunctionConfig(api_key="test-key", model_name="test-model")
|
||||
assert openai_config.function_type == "openai"
|
||||
assert openai_config.api_key == "test-key"
|
||||
assert openai_config.model_name == "test-model"
|
||||
|
||||
# Test custom config
|
||||
def dummy_function() -> None:
|
||||
return None
|
||||
|
||||
custom_config = CustomEmbeddingFunctionConfig(function=dummy_function, params={"test": "value"})
|
||||
assert custom_config.function_type == "custom"
|
||||
assert custom_config.function == dummy_function
|
||||
assert custom_config.params == {"test": "value"}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user