LightRAG/lightrag/kg/qdrant_impl.py

278 lines
9.5 KiB
Python
Raw Normal View History

2025-02-10 00:57:28 +08:00
import asyncio
import os
2025-03-04 15:50:53 +08:00
from typing import Any, final, List
2025-02-10 00:57:28 +08:00
from dataclasses import dataclass
import numpy as np
import hashlib
import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import configparser
2025-02-10 00:57:28 +08:00
# Optional settings file. ConfigParser.read() ignores files it cannot open,
# so a missing config.ini simply leaves this parser empty; lookups against it
# therefore need a fallback value.
config = configparser.ConfigParser()
config.read("config.ini", encoding="utf-8")
2025-02-16 15:08:50 +01:00
import pipmaster as pm
2025-02-19 19:51:39 +01:00
# Install the qdrant-client package at import time if it is missing, so that
# the `from qdrant_client import ...` statement that follows can succeed.
if not pm.is_installed("qdrant-client"):
    pm.install("qdrant-client")
2025-02-16 15:08:50 +01:00
2025-02-19 19:51:39 +01:00
from qdrant_client import QdrantClient, models
2025-02-11 03:29:40 +08:00
2025-02-10 00:57:28 +08:00
def compute_mdhash_id_for_qdrant(
    content: str, prefix: str = "", style: str = "simple"
) -> str:
    """
    Derive a deterministic UUID string, usable as a Qdrant point ID, from text.

    The first 16 bytes of SHA-256(prefix + content) seed the UUID, so equal
    inputs always map to the same ID.

    :param content: Text to hash; must be non-empty.
    :param prefix: String prepended to ``content`` before hashing.
    :param style: Output format: "simple" (32 hex digits), "hyphenated", or "urn".
    :return: The UUID rendered in the requested format.
    :raises ValueError: If ``content`` is empty or ``style`` is not recognised.
    """
    if not content:
        raise ValueError("Content must not be empty.")

    seed = prefix + content
    digest = hashlib.sha256(seed.encode("utf-8")).digest()
    # version=4 overwrites the version/variant bits of the raw digest bytes.
    point_uuid = uuid.UUID(bytes=digest[:16], version=4)

    if style == "simple":
        return point_uuid.hex
    if style == "hyphenated":
        return str(point_uuid)
    if style == "urn":
        return "urn:uuid:" + str(point_uuid)
    raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by a Qdrant collection named ``self.namespace``.

    Point IDs are derived from the caller's record keys with
    ``compute_mdhash_id_for_qdrant``; the original key is also stored in each
    point's payload under ``"id"`` so results can be mapped back to it.

    NOTE(review): every method calls the synchronous ``QdrantClient`` directly
    (no await / thread offload), so each request blocks the event loop while
    it runs.
    """

    @staticmethod
    def create_collection_if_not_exist(
        client: QdrantClient, collection_name: str, **kwargs
    ):
        """Create ``collection_name`` with ``kwargs`` unless it already exists."""
        if client.collection_exists(collection_name):
            return
        client.create_collection(collection_name, **kwargs)

    def __post_init__(self):
        """Validate settings, connect to Qdrant and ensure the collection exists.

        The connection URL / API key come from the ``QDRANT_URL`` /
        ``QDRANT_API_KEY`` environment variables, falling back to the
        ``[qdrant] uri`` / ``apikey`` entries of ``config.ini``.

        Raises:
            ValueError: If ``cosine_better_than_threshold`` is missing from
                ``global_config["vector_db_storage_cls_kwargs"]``.
        """
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        self._client = QdrantClient(
            url=os.environ.get(
                "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
            ),
            api_key=os.environ.get(
                "QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)
            ),
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        QdrantVectorDBStorage.create_collection_if_not_exist(
            self._client,
            self.namespace,
            vectors_config=models.VectorParams(
                size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
            ),
        )

    @staticmethod
    def _record_from_point(point: Any, **extra: Any) -> dict[str, Any]:
        """Build a result dict from a Qdrant point, restoring the original key.

        ``upsert`` stores the caller's record key in the payload under
        ``"id"``; the Qdrant point ID is only a hash of that key, so it is used
        as ``"id"`` only when the payload carries none.
        """
        payload = point.payload or {}
        return {**payload, "id": payload.get("id", point.id), **extra}

    def _scroll_all(
        self,
        scroll_filter: Any = None,
        with_payload: bool = True,
        page_size: int = 1000,
    ) -> list[Any]:
        """Return every point matching ``scroll_filter``, following all scroll pages."""
        points: list[Any] = []
        offset = None
        while True:
            # scroll() returns (page_of_points, next_page_offset); the offset is
            # None once the last page has been read.
            page, offset = self._client.scroll(
                collection_name=self.namespace,
                scroll_filter=scroll_filter,
                limit=page_size,
                offset=offset,
                with_payload=with_payload,
                with_vectors=False,
            )
            points.extend(page)
            if offset is None:
                return points

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Embed the records in ``data`` and upsert them into the collection.

        Args:
            data: Mapping of record key -> record. Each record must contain a
                ``"content"`` entry (the text that gets embedded). Only fields
                listed in ``self.meta_fields`` are stored in the payload,
                together with the record key under ``"id"``.

        Returns:
            The value returned by ``QdrantClient.upsert``, or ``None`` when
            ``data`` is empty.
        """
        logger.info(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        # Embed all batches concurrently; gather() keeps batch order, so the
        # concatenated rows line up with list_data.
        embeddings_list = await asyncio.gather(
            *(self.embedding_func(batch) for batch in batches)
        )
        embeddings = np.concatenate(embeddings_list)

        list_points = [
            models.PointStruct(
                id=compute_mdhash_id_for_qdrant(d["id"]),
                # Hand the client plain Python floats rather than a numpy row.
                vector=embeddings[i].tolist(),
                payload=d,
            )
            for i, d in enumerate(list_data)
        ]
        results = self._client.upsert(
            collection_name=self.namespace, points=list_points, wait=True
        )
        return results

    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
        """Return up to ``top_k`` stored records most similar to ``query``.

        Hits scoring below ``self.cosine_better_than_threshold`` are excluded
        by Qdrant via ``score_threshold``.

        Returns:
            One dict per hit: the stored payload plus ``"distance"`` (the
            Qdrant score). ``"id"`` is the original record key from the
            payload, so it can be passed back to ``delete`` or other stores.
        """
        embedding = await self.embedding_func([query])
        results = self._client.search(
            collection_name=self.namespace,
            query_vector=embedding[0],
            limit=top_k,
            with_payload=True,
            score_threshold=self.cosine_better_than_threshold,
        )

        logger.debug(f"query result: {results}")
        return [self._record_from_point(dp, distance=dp.score) for dp in results]

    async def index_done_callback(self) -> None:
        """No-op: nothing is buffered locally between writes."""
        # Qdrant handles persistence automatically
        pass

    async def delete(self, ids: List[str]) -> None:
        """Delete the points stored under the given record keys.

        Errors are logged rather than raised.

        Args:
            ids: Original record keys (as passed to ``upsert``); each is hashed
                with ``compute_mdhash_id_for_qdrant`` to locate its point.
        """
        try:
            # Map caller keys to the Qdrant point IDs used by upsert().
            qdrant_ids = [compute_mdhash_id_for_qdrant(doc_id) for doc_id in ids]
            self._client.delete(
                collection_name=self.namespace,
                points_selector=models.PointIdsList(points=qdrant_ids),
                wait=True,
            )
            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity's point by entity name.

        Errors are logged rather than raised.

        Args:
            entity_name: Name of the entity to delete
        """
        try:
            # NOTE(review): upsert() derives point IDs from the record key with
            # no prefix, i.e. compute_mdhash_id_for_qdrant(key), whereas this
            # hashes "ent-" + entity_name. The two only coincide if entities are
            # upserted under the literal key "ent-<entity_name>" — verify
            # against the callers that populate this storage.
            entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )

            self._client.delete(
                collection_name=self.namespace,
                points_selector=models.PointIdsList(points=[entity_id]),
                wait=True,
            )
            logger.debug(f"Successfully deleted entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete every point whose ``src_id`` or ``tgt_id`` equals ``entity_name``.

        All scroll pages are collected before deleting, so entities with more
        relations than one page holds are fully cleaned up. Errors are logged
        rather than raised.

        Args:
            entity_name: Name of the entity whose relations should be deleted
        """
        try:
            # Match relations where the entity is either source or target.
            relation_filter = models.Filter(
                should=[
                    models.FieldCondition(
                        key="src_id", match=models.MatchValue(value=entity_name)
                    ),
                    models.FieldCondition(
                        key="tgt_id", match=models.MatchValue(value=entity_name)
                    ),
                ]
            )
            # Only point IDs are needed, so payloads are not fetched.
            ids_to_delete = [
                point.id
                for point in self._scroll_all(relation_filter, with_payload=False)
            ]

            if ids_to_delete:
                self._client.delete(
                    collection_name=self.namespace,
                    points_selector=models.PointIdsList(points=ids_to_delete),
                    wait=True,
                )
                logger.debug(
                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                )
            else:
                logger.debug(f"No relations found for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")

    async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
        """Return records whose original record key starts with ``prefix``.

        Qdrant's ``MatchText`` condition is a full-text/substring match with no
        prefix option. This method therefore scans the whole collection page
        by page and filters on the payload ``"id"`` (the key given to
        ``upsert``) client-side. Cost grows with collection size.

        Args:
            prefix: The prefix to match against the original record keys.

        Returns:
            Payload dicts of matching records, with ``"id"`` set to the original
            record key. Returns ``[]`` if the scan fails (the error is logged).
        """
        try:
            formatted_results = []
            for point in self._scroll_all(with_payload=True):
                payload = point.payload or {}
                record_id = payload.get("id")
                if isinstance(record_id, str) and record_id.startswith(prefix):
                    formatted_results.append({**payload, "id": record_id})
            logger.debug(
                f"Found {len(formatted_results)} records with prefix '{prefix}'"
            )
            return formatted_results
        except Exception as e:
            logger.error(f"Error searching for prefix '{prefix}': {e}")
            return []