LightRAG/lightrag/kg/chroma_impl.py
2025-02-16 14:38:09 +01:00

203 lines
7.9 KiB
Python

import asyncio
from dataclasses import dataclass
from typing import Any, final
import numpy as np
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
try:
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
except ImportError as e:
raise ImportError(
"chromadb library is not installed. Please install it to proceed."
) from e
@final
@dataclass
class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation."""
def __post_init__(self):
try:
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
user_collection_settings = config.get("collection_settings", {})
# Default HNSW index settings for ChromaDB
default_collection_settings = {
# Distance metric used for similarity search (cosine similarity)
"hnsw:space": "cosine",
# Number of nearest neighbors to explore during index construction
# Higher values = better recall but slower indexing
"hnsw:construction_ef": 128,
# Number of nearest neighbors to explore during search
# Higher values = better recall but slower search
"hnsw:search_ef": 128,
# Number of connections per node in the HNSW graph
# Higher values = better recall but more memory usage
"hnsw:M": 16,
# Number of vectors to process in one batch during indexing
"hnsw:batch_size": 100,
# Number of updates before forcing index synchronization
# Lower values = more frequent syncs but slower indexing
"hnsw:sync_threshold": 1000,
}
collection_settings = {
**default_collection_settings,
**user_collection_settings,
}
local_path = config.get("local_path", None)
if local_path:
self._client = PersistentClient(
path=local_path,
settings=Settings(
allow_reset=True,
anonymized_telemetry=False,
),
)
else:
auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}
if "token_authn" in auth_provider:
headers = {
config.get(
"auth_header_name", "X-Chroma-Token"
): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")
self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)
self._collection = self._client.get_or_create_collection(
name=self.namespace,
metadata={
**collection_settings,
"dimension": self.embedding_func.embedding_dim,
},
)
# Use batch size from collection settings if specified
self._max_batch_size = self.global_config.get(
"embedding_batch_num", collection_settings.get("hnsw:batch_size", 32)
)
except Exception as e:
logger.error(f"ChromaDB initialization failed: {str(e)}")
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not data:
logger.warning("Empty data provided to vector DB")
return []
try:
ids = list(data.keys())
documents = [v["content"] for v in data.values()]
metadatas = [
{k: v for k, v in item.items() if k in self.meta_fields}
or {"_default": "true"}
for item in data.values()
]
# Process in batches
batches = [
documents[i : i + self._max_batch_size]
for i in range(0, len(documents), self._max_batch_size)
]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = []
# Pre-allocate embeddings_list with known size
embeddings_list = [None] * len(embedding_tasks)
# Use asyncio.gather instead of as_completed if order doesn't matter
embeddings_results = await asyncio.gather(*embedding_tasks)
embeddings_list = list(embeddings_results)
embeddings = np.concatenate(embeddings_list)
# Upsert in batches
for i in range(0, len(ids), self._max_batch_size):
batch_slice = slice(i, i + self._max_batch_size)
self._collection.upsert(
ids=ids[batch_slice],
embeddings=embeddings[batch_slice].tolist(),
documents=documents[batch_slice],
metadatas=metadatas[batch_slice],
)
return ids
except Exception as e:
logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
try:
embedding = await self.embedding_func([query])
results = self._collection.query(
query_embeddings=embedding.tolist()
if not isinstance(embedding, list)
else embedding,
n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)
# Filter results by cosine similarity threshold and take top k
# We request 2x results initially to have enough after filtering
# ChromaDB returns cosine similarity (1 = identical, 0 = orthogonal)
# We convert to distance (0 = identical, 1 = orthogonal) via (1 - similarity)
# Only keep results with distance below threshold, then take top k
return [
{
"id": results["ids"][0][i],
"distance": 1 - results["distances"][0][i],
"content": results["documents"][0][i],
**results["metadatas"][0][i],
}
for i in range(len(results["ids"][0]))
if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold
][:top_k]
except Exception as e:
logger.error(f"Error during ChromaDB query: {str(e)}")
raise
async def index_done_callback(self) -> None:
# ChromaDB handles persistence automatically
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError