LightRAG/lightrag/kg/qdrant_impl.py

import asyncio
import os
from typing import Any, final, List
from dataclasses import dataclass
import numpy as np
import hashlib
import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import configparser

import pipmaster as pm

if not pm.is_installed("qdrant-client"):
    pm.install("qdrant-client")

from qdrant_client import QdrantClient, models  # type: ignore

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


def compute_mdhash_id_for_qdrant(
    content: str, prefix: str = "", style: str = "simple"
) -> str:
    """
    Generate a deterministic UUID from the given content, in one of several formats.

    :param content: The content used to generate the UUID.
    :param prefix: Optional prefix hashed together with the content (e.g. "ent-").
    :param style: The UUID format; one of "simple", "hyphenated", or "urn".
    :return: A UUID string that meets the requirements of Qdrant point IDs.
    """
    if not content:
        raise ValueError("Content must not be empty.")

    # Derive the UUID from a SHA-256 hash of the prefixed content.
    hashed_content = hashlib.sha256((prefix + content).encode("utf-8")).digest()
    generated_uuid = uuid.UUID(bytes=hashed_content[:16], version=4)

    # Return the UUID in the requested format.
    if style == "simple":
        return generated_uuid.hex
    elif style == "hyphenated":
        return str(generated_uuid)
    elif style == "urn":
        return f"urn:uuid:{generated_uuid}"
    else:
        raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
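
# Illustrative usage of compute_mdhash_id_for_qdrant (a sketch; the inputs are
# hypothetical and the outputs are described rather than spelled out):
#
#     compute_mdhash_id_for_qdrant("Alice", prefix="ent-")
#     # -> 32-character hex UUID string (no hyphens)
#     compute_mdhash_id_for_qdrant("Alice", prefix="ent-", style="hyphenated")
#     # -> the same UUID in the usual 8-4-4-4-12 hyphenated form
#     compute_mdhash_id_for_qdrant("Alice", prefix="ent-", style="urn")
#     # -> "urn:uuid:" followed by the hyphenated form
#
# The mapping is deterministic: the same (prefix, content) pair always yields the
# same Qdrant point ID, which is what upsert/delete/get_by_id below rely on.
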
@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
    @staticmethod
    def create_collection_if_not_exist(
        client: QdrantClient, collection_name: str, **kwargs
    ):
        if client.collection_exists(collection_name):
            return
        client.create_collection(collection_name, **kwargs)

    def __post_init__(self):
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        self._client = QdrantClient(
            url=os.environ.get(
                "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
            ),
            api_key=os.environ.get(
                "QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)
            ),
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        QdrantVectorDBStorage.create_collection_if_not_exist(
            self._client,
            self.namespace,
            vectors_config=models.VectorParams(
                size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
            ),
        )
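
    # Connection settings sketch (the URL and key below are placeholders):
    # environment variables take precedence, with config.ini as a fallback.
    #
    #     export QDRANT_URL="http://localhost:6333"
    #     export QDRANT_API_KEY="<your-api-key>"
    #
    # Equivalent [qdrant] section in config.ini:
    #
    #     [qdrant]
    #     uri = http://localhost:6333
    #     apikey = <your-api-key>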

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        logger.debug(f"Inserting {len(data)} records into {self.namespace}")
        if not data:
            return

        import time

        current_time = int(time.time())

        list_data = [
            {
                "id": k,
                "created_at": current_time,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embedding_tasks = [self.embedding_func(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)
        embeddings = np.concatenate(embeddings_list)

        list_points = []
        for i, d in enumerate(list_data):
            list_points.append(
                models.PointStruct(
                    id=compute_mdhash_id_for_qdrant(d["id"]),
                    vector=embeddings[i],
                    payload=d,
                )
            )
        results = self._client.upsert(
            collection_name=self.namespace, points=list_points, wait=True
        )
        return results
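
    # Expected input shape for upsert (a sketch; "source_id" is a hypothetical
    # meta field, since the payload keeps only "id", "created_at" and whatever
    # keys appear in self.meta_fields):
    #
    #     await storage.upsert({
    #         "doc-1": {"content": "first chunk of text", "source_id": "chunk-1"},
    #         "doc-2": {"content": "second chunk of text", "source_id": "chunk-2"},
    #     })
    #
    # Only "content" is embedded; embeddings are computed in batches of
    # self._max_batch_size and written to Qdrant in a single upsert call.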

    async def query(
        self, query: str, top_k: int, ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        embedding = await self.embedding_func(
            [query], _priority=5
        )  # higher priority for query
        results = self._client.search(
            collection_name=self.namespace,
            query_vector=embedding[0],
            limit=top_k,
            with_payload=True,
            score_threshold=self.cosine_better_than_threshold,
        )

        logger.debug(f"query result: {results}")
        return [
            {
                **dp.payload,
                "distance": dp.score,
                "created_at": dp.payload.get("created_at"),
            }
            for dp in results
        ]
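
    # Each item returned by query() is the stored payload merged with "distance"
    # (the Qdrant cosine similarity score, already filtered server-side through
    # score_threshold=cosine_better_than_threshold) and "created_at".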

    async def index_done_callback(self) -> None:
        # Qdrant handles persistence automatically
        pass

    async def delete(self, ids: List[str]) -> None:
        """Delete vectors with specified IDs

        Args:
            ids: List of vector IDs to be deleted
        """
        try:
            # Convert regular ids to Qdrant-compatible ids
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
            # Delete points from the collection
            self._client.delete(
                collection_name=self.namespace,
                points_selector=models.PointIdsList(
                    points=qdrant_ids,
                ),
                wait=True,
            )
            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity by name

        Args:
            entity_name: Name of the entity to delete
        """
        try:
            # Generate the entity ID
            entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )
            # Delete the entity point from the collection
            self._client.delete(
                collection_name=self.namespace,
                points_selector=models.PointIdsList(
                    points=[entity_id],
                ),
                wait=True,
            )
            logger.debug(f"Successfully deleted entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity

        Args:
            entity_name: Name of the entity whose relations should be deleted
        """
        try:
            # Find relations where the entity is either source or target
            results = self._client.scroll(
                collection_name=self.namespace,
                scroll_filter=models.Filter(
                    should=[
                        models.FieldCondition(
                            key="src_id", match=models.MatchValue(value=entity_name)
                        ),
                        models.FieldCondition(
                            key="tgt_id", match=models.MatchValue(value=entity_name)
                        ),
                    ]
                ),
                with_payload=True,
                limit=1000,  # Adjust as needed for your use case
            )

            # Extract points that need to be deleted
            relation_points = results[0]
            ids_to_delete = [point.id for point in relation_points]

            if ids_to_delete:
                # Delete the relations
                self._client.delete(
                    collection_name=self.namespace,
                    points_selector=models.PointIdsList(
                        points=ids_to_delete,
                    ),
                    wait=True,
                )
                logger.debug(
                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                )
            else:
                logger.debug(f"No relations found for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get vector data by its ID

        Args:
            id: The unique identifier of the vector

        Returns:
            The vector data if found, or None if not found
        """
        try:
            # Convert to Qdrant-compatible ID
            qdrant_id = compute_mdhash_id_for_qdrant(id)
            # Retrieve the point by ID
            result = self._client.retrieve(
                collection_name=self.namespace,
                ids=[qdrant_id],
                with_payload=True,
            )

            if not result:
                return None

            # Ensure the result contains a created_at field
            payload = result[0].payload
            if "created_at" not in payload:
                payload["created_at"] = None

            return payload
        except Exception as e:
            logger.error(f"Error retrieving vector data for ID {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple vector data records by their IDs

        Args:
            ids: List of unique identifiers

        Returns:
            List of vector data objects that were found
        """
        if not ids:
            return []

        try:
            # Convert to Qdrant-compatible IDs
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
            # Retrieve the points by IDs
            results = self._client.retrieve(
                collection_name=self.namespace,
                ids=qdrant_ids,
                with_payload=True,
            )

            # Ensure each result contains a created_at field
            payloads = []
            for point in results:
                payload = point.payload
                if "created_at" not in payload:
                    payload["created_at"] = None
                payloads.append(payload)

            return payloads
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources

        This method will delete all data from the Qdrant collection.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            # Delete the collection and recreate it
            if self._client.collection_exists(self.namespace):
                self._client.delete_collection(self.namespace)

            # Recreate the collection
            QdrantVectorDBStorage.create_collection_if_not_exist(
                self._client,
                self.namespace,
                vectors_config=models.VectorParams(
                    size=self.embedding_func.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )

            logger.info(
                f"Process {os.getpid()} dropped Qdrant collection {self.namespace}"
            )
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}