LightRAG/lightrag/kg/qdrant_impl.py

489 lines
17 KiB
Python
Raw Normal View History

2025-02-10 00:57:28 +08:00
import asyncio
import os
2025-03-04 15:50:53 +08:00
from typing import Any, final, List
2025-02-10 00:57:28 +08:00
from dataclasses import dataclass
import numpy as np
import hashlib
import uuid
from ..utils import logger
from ..base import BaseVectorStorage
from ..kg.shared_storage import get_data_init_lock, get_storage_lock
import configparser
2025-02-16 15:08:50 +01:00
import pipmaster as pm
2025-02-19 19:51:39 +01:00
# Make sure the Qdrant client library is present before importing it:
# pipmaster installs "qdrant-client" on the fly when it is missing, so this
# check has to run before the import statement below.
if not pm.is_installed("qdrant-client"):
    pm.install("qdrant-client")

from qdrant_client import QdrantClient, models  # type: ignore

# Module-level configuration parsed from "config.ini" (resolved relative to the
# current working directory). configparser silently ignores a missing file.
# QdrantVectorDBStorage.initialize() reads the [qdrant] "uri" / "apikey" keys
# from it as fallbacks when QDRANT_URL / QDRANT_API_KEY are not set.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


def compute_mdhash_id_for_qdrant(
content: str, prefix: str = "", style: str = "simple"
) -> str:
"""
Generate a UUID based on the content and support multiple formats.
:param content: The content used to generate the UUID.
:param style: The format of the UUID, optional values are "simple", "hyphenated", "urn".
:return: A UUID that meets the requirements of Qdrant.
"""
if not content:
raise ValueError("Content must not be empty.")
# Use the hash value of the content to create a UUID.
hashed_content = hashlib.sha256((prefix + content).encode("utf-8")).digest()
generated_uuid = uuid.UUID(bytes=hashed_content[:16], version=4)
# Return the UUID according to the specified format.
if style == "simple":
return generated_uuid.hex
elif style == "hyphenated":
return str(generated_uuid)
elif style == "urn":
return f"urn:uuid:{generated_uuid}"
else:
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by a single Qdrant collection.

    The collection name is ``final_namespace``: ``"<workspace>_<namespace>"``
    when a workspace is in effect, otherwise just ``namespace``. Point IDs are
    derived from the caller's string IDs with ``compute_mdhash_id_for_qdrant``,
    and the original string ID is stored in the point payload under ``"id"``.

    NOTE(review): ``self._client`` is the synchronous ``QdrantClient``. Its calls
    inside the ``async`` methods below are not awaited, so each one blocks the
    event loop while it runs.
    """

    def __init__(
        self, namespace, global_config, embedding_func, workspace=None, meta_fields=None
    ):
        # @dataclass does not overwrite an __init__ defined in the class body,
        # so this explicit constructor is the one that runs.
        # NOTE(review): if BaseVectorStorage's generated __init__ also invokes
        # __post_init__, the explicit call below runs it a second time. Confirm
        # against base.py, because a second run would re-prefix the workspace.
        super().__init__(
            namespace=namespace,
            workspace=workspace or "",
            global_config=global_config,
            embedding_func=embedding_func,
            meta_fields=meta_fields or set(),
        )
        self.__post_init__()

    @staticmethod
    def create_collection_if_not_exist(
        client: QdrantClient, collection_name: str, **kwargs
    ):
        """Create ``collection_name`` with ``kwargs`` unless it already exists."""
        if client.collection_exists(collection_name):
            return
        client.create_collection(collection_name, **kwargs)

    def __post_init__(self):
        """Resolve the effective workspace and collection name, and read settings.

        Raises:
            ValueError: If ``cosine_better_than_threshold`` is missing from
                ``global_config["vector_db_storage_cls_kwargs"]``.
        """
        # Check for QDRANT_WORKSPACE environment variable first (higher priority).
        # This allows administrators to force a specific workspace for all Qdrant
        # storage instances.
        qdrant_workspace = os.environ.get("QDRANT_WORKSPACE")
        if qdrant_workspace and qdrant_workspace.strip():
            # Use environment variable value, overriding the passed workspace parameter
            effective_workspace = qdrant_workspace.strip()
            logger.info(
                f"Using QDRANT_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
            )
        else:
            # Use the workspace parameter passed during initialization
            effective_workspace = self.workspace
            if effective_workspace:
                logger.debug(
                    f"Using passed workspace parameter: '{effective_workspace}'"
                )

        # Build final_namespace with workspace prefix for data isolation.
        # Keep original namespace unchanged for type detection logic.
        if effective_workspace:
            self.final_namespace = f"{effective_workspace}_{self.namespace}"
            logger.debug(
                f"Final namespace with workspace prefix: '{self.final_namespace}'"
            )
        else:
            # When workspace is empty, final_namespace equals original namespace.
            self.final_namespace = self.namespace
            # "_" is only used as the log-line prefix for the default workspace.
            self.workspace = "_"
            logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")

        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        # Initialize client as None - it is created lazily in initialize().
        self._client = None
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self._initialized = False

    async def initialize(self):
        """Create the Qdrant client and ensure the collection exists.

        Safe to call more than once: it runs under the shared data-init lock
        and returns early once initialization has succeeded. Connection
        settings come from QDRANT_URL / QDRANT_API_KEY, with the [qdrant]
        "uri" / "apikey" entries of config.ini as fallbacks.

        Raises:
            Exception: Any error from client creation or collection setup is
                logged and re-raised.
        """
        async with get_data_init_lock():
            if self._initialized:
                return

            try:
                # Create QdrantClient if not already created
                if self._client is None:
                    self._client = QdrantClient(
                        url=os.environ.get(
                            "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
                        ),
                        api_key=os.environ.get(
                            "QDRANT_API_KEY",
                            config.get("qdrant", "apikey", fallback=None),
                        ),
                    )
                    logger.debug(
                        f"[{self.workspace}] QdrantClient created successfully"
                    )

                # Create collection if not exists (cosine distance, sized to the embedder)
                QdrantVectorDBStorage.create_collection_if_not_exist(
                    self._client,
                    self.final_namespace,
                    vectors_config=models.VectorParams(
                        size=self.embedding_func.embedding_dim,
                        distance=models.Distance.COSINE,
                    ),
                )
                self._initialized = True
                logger.info(
                    f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
                )
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Failed to initialize Qdrant collection '{self.namespace}': {e}"
                )
                raise

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Embed and upsert records into the collection.

        Args:
            data: Mapping of record ID -> record dict. Each record must have a
                ``"content"`` key, which is the text to embed. Only keys listed
                in ``self.meta_fields`` are copied into the point payload,
                alongside ``"id"`` and a ``"created_at"`` Unix timestamp.

        Returns:
            Nothing for empty ``data``; otherwise the result of the client's
            ``upsert`` call (despite the ``None`` annotation, kept for
            compatibility).
        """
        logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        import time

        current_time = int(time.time())

        list_data = [
            {
                "id": k,
                "created_at": current_time,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]

        # Embed in batches of at most embedding_batch_num texts, concurrently.
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embedding_tasks = [self.embedding_func(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)
        embeddings = np.concatenate(embeddings_list)

        # Row i of `embeddings` belongs to list_data[i]: batches preserve order.
        list_points = [
            models.PointStruct(
                id=compute_mdhash_id_for_qdrant(d["id"]),
                vector=embeddings[i],
                payload=d,
            )
            for i, d in enumerate(list_data)
        ]

        results = self._client.upsert(
            collection_name=self.final_namespace, points=list_points, wait=True
        )
        return results

    async def query(
        self, query: str, top_k: int, query_embedding: list[float] | None = None
    ) -> list[dict[str, Any]]:
        """Return up to ``top_k`` payloads most similar to ``query``.

        Args:
            query: Query text. It is embedded only when ``query_embedding`` is
                not given.
            top_k: Maximum number of results.
            query_embedding: Optional precomputed embedding for ``query``.

        Returns:
            One dict per hit: the stored payload plus ``"distance"`` (the Qdrant
            score) and ``"created_at"`` (``None`` if absent). Hits scoring below
            ``cosine_better_than_threshold`` are filtered out by Qdrant.
        """
        if query_embedding is not None:
            embedding = query_embedding
        else:
            embedding_result = await self.embedding_func(
                [query], _priority=5
            )  # higher priority for query
            embedding = embedding_result[0]

        results = self._client.search(
            collection_name=self.final_namespace,
            query_vector=embedding,
            limit=top_k,
            with_payload=True,
            score_threshold=self.cosine_better_than_threshold,
        )

        return [
            {
                **dp.payload,
                "distance": dp.score,
                "created_at": dp.payload.get("created_at"),
            }
            for dp in results
        ]

    async def index_done_callback(self) -> None:
        """No-op: Qdrant persists writes itself (upserts use ``wait=True``)."""
        # Qdrant handles persistence automatically
        pass

    async def delete(self, ids: List[str]) -> None:
        """Delete vectors with the specified IDs.

        Best-effort: errors are logged, not raised.

        Args:
            ids: List of vector IDs (original, un-hashed IDs) to be deleted.
        """
        if not ids:
            # Nothing to do; avoid sending an empty PointIdsList to the server.
            return

        try:
            # Convert regular ids to Qdrant compatible ids
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
            self._client.delete(
                collection_name=self.final_namespace,
                points_selector=models.PointIdsList(
                    points=qdrant_ids,
                ),
                wait=True,
            )
            logger.debug(
                f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
            )

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity's vector by entity name.

        The point ID is derived from ``entity_name`` with the ``"ent-"`` prefix.
        Best-effort: errors are logged, not raised.

        Args:
            entity_name: Name of the entity to delete.
        """
        try:
            entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
            self._client.delete(
                collection_name=self.final_namespace,
                points_selector=models.PointIdsList(
                    points=[entity_id],
                ),
                wait=True,
            )
            logger.debug(
                f"[{self.workspace}] Successfully deleted entity {entity_name}"
            )
        except Exception as e:
            logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity.

        Every point whose payload ``src_id`` or ``tgt_id`` equals
        ``entity_name`` is removed. The scroll is paged until exhausted, so
        entities with more than one page (1000) of relations are fully cleaned
        up. Best-effort: errors are logged, not raised.

        Args:
            entity_name: Name of the entity whose relations should be deleted.
        """
        try:
            relation_filter = models.Filter(
                should=[
                    models.FieldCondition(
                        key="src_id", match=models.MatchValue(value=entity_name)
                    ),
                    models.FieldCondition(
                        key="tgt_id", match=models.MatchValue(value=entity_name)
                    ),
                ]
            )

            # Collect every matching point ID first, then delete once, so the
            # deletion cannot disturb the pagination cursor.
            ids_to_delete = []
            next_offset = None
            while True:
                relation_points, next_offset = self._client.scroll(
                    collection_name=self.final_namespace,
                    scroll_filter=relation_filter,
                    with_payload=False,  # only the point IDs are needed
                    limit=1000,
                    offset=next_offset,
                )
                ids_to_delete.extend(point.id for point in relation_points)
                if next_offset is None:
                    break

            if ids_to_delete:
                self._client.delete(
                    collection_name=self.final_namespace,
                    points_selector=models.PointIdsList(
                        points=ids_to_delete,
                    ),
                    wait=True,
                )
                logger.debug(
                    f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
                )
            else:
                logger.debug(
                    f"[{self.workspace}] No relations found for entity {entity_name}"
                )
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
            )

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get vector data by its ID.

        Args:
            id: The unique identifier of the vector.

        Returns:
            The stored payload (with ``"created_at"`` defaulted to ``None``),
            or ``None`` if not found or on error.
        """
        try:
            qdrant_id = compute_mdhash_id_for_qdrant(id)
            result = self._client.retrieve(
                collection_name=self.final_namespace,
                ids=[qdrant_id],
                with_payload=True,
            )
            if not result:
                return None

            # Ensure the result contains created_at field
            payload = result[0].payload
            payload.setdefault("created_at", None)
            return payload
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
            )
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple vector data by their IDs.

        Args:
            ids: List of unique identifiers.

        Returns:
            Payloads of the points that were found. Missing IDs are skipped,
            and each payload has ``"created_at"`` defaulted to ``None``.
            Payloads follow the order in which their ``"id"`` first appears in
            ``ids``; any payload whose ``"id"`` is not in ``ids`` comes last.
            Returns ``[]`` on error.
        """
        if not ids:
            return []

        try:
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
            results = self._client.retrieve(
                collection_name=self.final_namespace,
                ids=qdrant_ids,
                with_payload=True,
            )

            payloads = []
            for point in results:
                payload = point.payload
                payload.setdefault("created_at", None)
                payloads.append(payload)

            # Do not rely on retrieve() yielding points in request order:
            # stably re-sort by the position of each payload's original "id".
            first_position: dict[str, int] = {}
            for position, requested_id in enumerate(ids):
                first_position.setdefault(requested_id, position)
            payloads.sort(
                key=lambda payload: first_position.get(payload.get("id"), len(ids))
            )
            return payloads
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
            )
            return []

    async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
        """Get vectors by their IDs, returning only ID and vector data.

        Args:
            ids: List of unique identifiers.

        Returns:
            Dictionary mapping each found original ID (read from the payload's
            ``"id"``) to its vector as a list. Returns ``{}`` on error.
        """
        if not ids:
            return {}

        try:
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
            results = self._client.retrieve(
                collection_name=self.final_namespace,
                ids=qdrant_ids,
                with_vectors=True,  # Important: request vectors
                with_payload=True,
            )

            vectors_dict = {}
            for point in results:
                if point and point.vector is not None and point.payload:
                    original_id = point.payload.get("id")
                    if original_id:
                        # Convert numpy array to list if needed
                        vector_data = point.vector
                        if isinstance(vector_data, np.ndarray):
                            vector_data = vector_data.tolist()
                        vectors_dict[original_id] = vector_data
            return vectors_dict
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
            )
            return {}

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources.

        Deletes the Qdrant collection and recreates it empty with the same
        vector size and cosine distance, under the shared storage lock.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        async with get_storage_lock():
            try:
                if self._client.collection_exists(self.final_namespace):
                    self._client.delete_collection(self.final_namespace)

                QdrantVectorDBStorage.create_collection_if_not_exist(
                    self._client,
                    self.final_namespace,
                    vectors_config=models.VectorParams(
                        size=self.embedding_func.embedding_dim,
                        distance=models.Distance.COSINE,
                    ),
                )

                logger.info(
                    f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}"
                )
                return {"status": "success", "message": "data dropped"}
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}"
                )
                return {"status": "error", "message": str(e)}