LightRAG/lightrag/kg/milvus_impl.py

import asyncio
import os
from typing import Any, final
from dataclasses import dataclass
import numpy as np
from lightrag.utils import logger, compute_mdhash_id
from ..base import BaseVectorStorage
from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH
from ..kg.shared_storage import get_storage_lock
import pipmaster as pm
if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
import configparser
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
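# Illustrative only: the ConfigParser lookups below read an optional [milvus]
# section from config.ini. The values shown here are hypothetical examples
# rather than shipped defaults, and the MILVUS_* environment variables used in
# __post_init__ take precedence over them:
#
#   [milvus]
#   uri = http://localhost:19530
#   user = root
#   password = Milvus
#   token =
#   db_name = lightrag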
@final
@dataclass
class MilvusVectorDBStorage(BaseVectorStorage):
def _create_schema_for_namespace(self) -> CollectionSchema:
"""Create schema based on the current instance's namespace"""
# Get vector dimension from embedding_func
dimension = self.embedding_func.embedding_dim
# Base fields (common to all collections)
base_fields = [
FieldSchema(
name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="created_at", dtype=DataType.INT64),
]
# Determine specific fields based on namespace
if self.namespace.endswith("entities"):
specific_fields = [
FieldSchema(
name="entity_name",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=DEFAULT_MAX_FILE_PATH_LENGTH,
nullable=True,
),
]
description = "LightRAG entities vector storage"
elif self.namespace.endswith("relationships"):
specific_fields = [
FieldSchema(
name="src_id", dtype=DataType.VARCHAR, max_length=512, nullable=True
),
FieldSchema(
name="tgt_id", dtype=DataType.VARCHAR, max_length=512, nullable=True
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=DEFAULT_MAX_FILE_PATH_LENGTH,
nullable=True,
),
]
description = "LightRAG relationships vector storage"
elif self.namespace.endswith("chunks"):
specific_fields = [
FieldSchema(
name="full_doc_id",
dtype=DataType.VARCHAR,
max_length=64,
nullable=True,
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=1024,
nullable=True,
),
]
description = "LightRAG chunks vector storage"
else:
# Default generic schema (backward compatibility)
specific_fields = [
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=1024,
nullable=True,
),
]
description = "LightRAG generic vector storage"
# Merge all fields
all_fields = base_fields + specific_fields
return CollectionSchema(
fields=all_fields,
description=description,
enable_dynamic_field=True, # Support dynamic fields
)
def _get_index_params(self):
"""Get IndexParams in a version-compatible way"""
try:
# Try to use client's prepare_index_params method (most common)
if hasattr(self._client, "prepare_index_params"):
return self._client.prepare_index_params()
except Exception:
pass
try:
# Try to import IndexParams from different possible locations
from pymilvus.client.prepare import IndexParams
return IndexParams()
except ImportError:
pass
try:
from pymilvus.client.types import IndexParams
return IndexParams()
except ImportError:
pass
try:
from pymilvus import IndexParams
return IndexParams()
except ImportError:
pass
# If all else fails, return None to use fallback method
return None
def _create_vector_index_fallback(self):
"""Fallback method to create vector index using direct API"""
try:
self._client.create_index(
collection_name=self.final_namespace,
field_name="vector",
index_params={
"index_type": "HNSW",
"metric_type": "COSINE",
"params": {"M": 16, "efConstruction": 256},
},
)
logger.debug(
f"[{self.workspace}] Created vector index using fallback method"
)
except Exception as e:
logger.warning(
f"[{self.workspace}] Failed to create vector index using fallback method: {e}"
)
def _create_scalar_index_fallback(self, field_name: str, index_type: str):
"""Fallback method to create scalar index using direct API"""
# Skip unsupported index types
if index_type == "SORTED":
logger.info(
f"[{self.workspace}] Skipping SORTED index for {field_name} (not supported in this Milvus version)"
)
return
try:
self._client.create_index(
collection_name=self.final_namespace,
field_name=field_name,
index_params={"index_type": index_type},
)
logger.debug(
f"[{self.workspace}] Created {field_name} index using fallback method"
)
except Exception as e:
logger.info(
f"[{self.workspace}] Could not create {field_name} index using fallback method: {e}"
)
def _create_indexes_after_collection(self):
"""Create indexes after collection is created"""
try:
# Try to get IndexParams in a version-compatible way
IndexParamsClass = self._get_index_params()
if IndexParamsClass is not None:
# Use IndexParams approach if available
try:
# Create vector index first (required for most operations)
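                    # HNSW parameters (mirrored by the fallback path below):
                    # M is the maximum number of graph links per node and
                    # efConstruction the candidate-list size used while building
                    # the index; larger values trade build time and memory for
                    # better recall.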
vector_index = IndexParamsClass
vector_index.add_index(
field_name="vector",
index_type="HNSW",
metric_type="COSINE",
params={"M": 16, "efConstruction": 256},
)
self._client.create_index(
collection_name=self.final_namespace, index_params=vector_index
)
logger.debug(
f"[{self.workspace}] Created vector index using IndexParams"
)
except Exception as e:
logger.debug(
f"[{self.workspace}] IndexParams method failed for vector index: {e}"
)
self._create_vector_index_fallback()
# Create scalar indexes based on namespace
if self.namespace.endswith("entities"):
# Create indexes for entity fields
try:
entity_name_index = self._get_index_params()
entity_name_index.add_index(
field_name="entity_name", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.final_namespace,
index_params=entity_name_index,
)
except Exception as e:
logger.debug(
f"[{self.workspace}] IndexParams method failed for entity_name: {e}"
)
self._create_scalar_index_fallback("entity_name", "INVERTED")
elif self.namespace.endswith("relationships"):
# Create indexes for relationship fields
try:
src_id_index = self._get_index_params()
src_id_index.add_index(
field_name="src_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.final_namespace,
index_params=src_id_index,
)
except Exception as e:
logger.debug(
f"[{self.workspace}] IndexParams method failed for src_id: {e}"
)
self._create_scalar_index_fallback("src_id", "INVERTED")
try:
tgt_id_index = self._get_index_params()
tgt_id_index.add_index(
field_name="tgt_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.final_namespace,
index_params=tgt_id_index,
)
except Exception as e:
logger.debug(
f"[{self.workspace}] IndexParams method failed for tgt_id: {e}"
)
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif self.namespace.endswith("chunks"):
# Create indexes for chunk fields
try:
doc_id_index = self._get_index_params()
doc_id_index.add_index(
field_name="full_doc_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.final_namespace,
index_params=doc_id_index,
)
except Exception as e:
logger.debug(
f"[{self.workspace}] IndexParams method failed for full_doc_id: {e}"
)
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
# No common indexes needed
else:
# Fallback to direct API calls if IndexParams is not available
logger.info(
f"[{self.workspace}] IndexParams not available, using fallback methods for {self.namespace}"
)
# Create vector index using fallback
self._create_vector_index_fallback()
# Create scalar indexes using fallback
if self.namespace.endswith("entities"):
self._create_scalar_index_fallback("entity_name", "INVERTED")
elif self.namespace.endswith("relationships"):
self._create_scalar_index_fallback("src_id", "INVERTED")
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif self.namespace.endswith("chunks"):
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
logger.info(
f"[{self.workspace}] Created indexes for collection: {self.namespace}"
)
except Exception as e:
logger.warning(
f"[{self.workspace}] Failed to create some indexes for {self.namespace}: {e}"
)
def _get_required_fields_for_namespace(self) -> dict:
"""Get required core field definitions for current namespace"""
# Base fields (common to all types)
base_fields = {
"id": {"type": "VarChar", "is_primary": True},
"vector": {"type": "FloatVector"},
"created_at": {"type": "Int64"},
}
# Add specific fields based on namespace
if self.namespace.endswith("entities"):
specific_fields = {
"entity_name": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
elif self.namespace.endswith("relationships"):
specific_fields = {
"src_id": {"type": "VarChar"},
"tgt_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
elif self.namespace.endswith("chunks"):
specific_fields = {
"full_doc_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
else:
specific_fields = {
"file_path": {"type": "VarChar"},
}
return {**base_fields, **specific_fields}
def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
"""Check compatibility of a single field"""
field_name = existing_field.get("name", "unknown")
existing_type = existing_field.get("type")
expected_type = expected_config.get("type")
logger.debug(
f"[{self.workspace}] Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
)
# Convert DataType enum values to string names if needed
original_existing_type = existing_type
if hasattr(existing_type, "name"):
existing_type = existing_type.name
logger.debug(
f"[{self.workspace}] Converted enum to name: {original_existing_type} -> {existing_type}"
)
elif isinstance(existing_type, int):
# Map common Milvus internal type codes to type names for backward compatibility
type_mapping = {
21: "VarChar",
101: "FloatVector",
5: "Int64",
9: "Double",
}
mapped_type = type_mapping.get(existing_type, str(existing_type))
logger.debug(
f"[{self.workspace}] Mapped numeric type: {existing_type} -> {mapped_type}"
)
existing_type = mapped_type
# Normalize type names for comparison
type_aliases = {
"VARCHAR": "VarChar",
"String": "VarChar",
"FLOAT_VECTOR": "FloatVector",
"INT64": "Int64",
"BigInt": "Int64",
"DOUBLE": "Double",
"Float": "Double",
}
original_existing = existing_type
original_expected = expected_type
existing_type = type_aliases.get(existing_type, existing_type)
expected_type = type_aliases.get(expected_type, expected_type)
if original_existing != existing_type or original_expected != expected_type:
logger.debug(
f"[{self.workspace}] Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
)
# Basic type compatibility check
type_compatible = existing_type == expected_type
logger.debug(
f"[{self.workspace}] Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
)
if not type_compatible:
logger.warning(
f"[{self.workspace}] Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
)
return False
# Primary key check - be more flexible about primary key detection
if expected_config.get("is_primary"):
# Check multiple possible field names for primary key status
is_primary = (
existing_field.get("is_primary_key", False)
or existing_field.get("is_primary", False)
or existing_field.get("primary_key", False)
)
logger.debug(
f"[{self.workspace}] Primary key check for '{field_name}': expected=True, actual={is_primary}"
)
logger.debug(
f"[{self.workspace}] Raw field data for '{field_name}': {existing_field}"
)
# For ID field, be more lenient - if it's the ID field, assume it should be primary
if field_name == "id" and not is_primary:
logger.info(
f"[{self.workspace}] ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
)
# Don't fail for ID field primary key mismatch
elif not is_primary:
logger.warning(
f"[{self.workspace}] Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
)
return False
logger.debug(f"[{self.workspace}] Field '{field_name}' is compatible")
return True
def _check_vector_dimension(self, collection_info: dict):
"""Check vector dimension compatibility"""
current_dimension = self.embedding_func.embedding_dim
# Find vector field dimension
for field in collection_info.get("fields", []):
if field.get("name") == "vector":
field_type = field.get("type")
# Extract type name from DataType enum or string
type_name = None
if hasattr(field_type, "name"):
type_name = field_type.name
elif isinstance(field_type, str):
type_name = field_type
else:
type_name = str(field_type)
# Check if it's a vector type (supports multiple formats)
if type_name in ["FloatVector", "FLOAT_VECTOR"]:
existing_dimension = field.get("params", {}).get("dim")
if existing_dimension != current_dimension:
raise ValueError(
f"Vector dimension mismatch for collection '{self.final_namespace}': "
f"existing={existing_dimension}, current={current_dimension}"
)
logger.debug(
f"[{self.workspace}] Vector dimension check passed: {current_dimension}"
)
return
# If no vector field found, this might be an old collection created with simple schema
logger.warning(
f"[{self.workspace}] Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
)
logger.warning(
f"[{self.workspace}] Consider recreating the collection for optimal performance."
)
return
def _check_schema_compatibility(self, collection_info: dict):
"""Check schema field compatibility"""
existing_fields = {
field["name"]: field for field in collection_info.get("fields", [])
}
# Check if this is an old collection created with simple schema
has_vector_field = any(
field.get("name") == "vector" for field in collection_info.get("fields", [])
)
if not has_vector_field:
logger.warning(
f"[{self.workspace}] Collection {self.namespace} appears to be created with old simple schema (no vector field)"
)
logger.warning(
f"[{self.workspace}] This collection will work but may have suboptimal performance"
)
logger.warning(
f"[{self.workspace}] Consider recreating the collection for optimal performance"
)
return
# For collections with vector field, check basic compatibility
# Only check for critical incompatibilities, not missing optional fields
critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
incompatible_fields = []
for field_name, expected_config in critical_fields.items():
if field_name in existing_fields:
existing_field = existing_fields[field_name]
if not self._is_field_compatible(existing_field, expected_config):
incompatible_fields.append(
f"{field_name}: expected {expected_config['type']}, "
f"got {existing_field.get('type')}"
)
if incompatible_fields:
raise ValueError(
f"Critical schema incompatibility in collection '{self.final_namespace}': {incompatible_fields}"
)
# Get all expected fields for informational purposes
expected_fields = self._get_required_fields_for_namespace()
missing_fields = [
field for field in expected_fields if field not in existing_fields
]
if missing_fields:
logger.info(
f"[{self.workspace}] Collection {self.namespace} missing optional fields: {missing_fields}"
)
logger.info(
"These fields would be available in a newly created collection for better performance"
)
logger.debug(
f"[{self.workspace}] Schema compatibility check passed for {self.namespace}"
)
def _validate_collection_compatibility(self):
"""Validate existing collection's dimension and schema compatibility"""
try:
collection_info = self._client.describe_collection(self.final_namespace)
# 1. Check vector dimension
self._check_vector_dimension(collection_info)
# 2. Check schema compatibility
self._check_schema_compatibility(collection_info)
logger.info(
f"[{self.workspace}] VectorDB Collection '{self.namespace}' compatibility validation passed"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Collection compatibility validation failed for {self.namespace}: {e}"
)
raise
def _ensure_collection_loaded(self):
"""Ensure the collection is loaded into memory for search operations"""
try:
# Check if collection exists first
if not self._client.has_collection(self.final_namespace):
logger.error(
f"[{self.workspace}] Collection {self.namespace} does not exist"
)
raise ValueError(f"Collection {self.final_namespace} does not exist")
# Load the collection if it's not already loaded
# In Milvus, collections need to be loaded before they can be searched
self._client.load_collection(self.final_namespace)
# logger.debug(f"[{self.workspace}] Collection {self.namespace} loaded successfully")
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to load collection {self.namespace}: {e}"
)
raise
def _create_collection_if_not_exist(self):
"""Create collection if not exists and check existing collection compatibility"""
try:
# Check if our specific collection exists
collection_exists = self._client.has_collection(self.final_namespace)
logger.info(
f"[{self.workspace}] VectorDB collection '{self.namespace}' exists check: {collection_exists}"
)
if collection_exists:
# Double-check by trying to describe the collection
try:
self._client.describe_collection(self.final_namespace)
self._validate_collection_compatibility()
# Ensure the collection is loaded after validation
self._ensure_collection_loaded()
return
except Exception as describe_error:
logger.warning(
f"[{self.workspace}] Collection '{self.namespace}' exists but cannot be described: {describe_error}"
)
logger.info(
f"[{self.workspace}] Treating as if collection doesn't exist and creating new one..."
)
# Fall through to creation logic
# Collection doesn't exist, create new collection
logger.info(f"[{self.workspace}] Creating new collection: {self.namespace}")
schema = self._create_schema_for_namespace()
# Create collection with schema only first
self._client.create_collection(
collection_name=self.final_namespace, schema=schema
)
# Then create indexes
self._create_indexes_after_collection()
# Load the newly created collection
self._ensure_collection_loaded()
logger.info(
f"[{self.workspace}] Successfully created Milvus collection: {self.namespace}"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Error in _create_collection_if_not_exist for {self.namespace}: {e}"
)
# If there's any error, try to force create the collection
logger.info(
f"[{self.workspace}] Attempting to force create collection {self.namespace}..."
)
try:
# Try to drop the collection first if it exists in a bad state
try:
if self._client.has_collection(self.final_namespace):
logger.info(
f"[{self.workspace}] Dropping potentially corrupted collection {self.namespace}"
)
self._client.drop_collection(self.final_namespace)
except Exception as drop_error:
logger.warning(
f"[{self.workspace}] Could not drop collection {self.namespace}: {drop_error}"
)
# Create fresh collection
schema = self._create_schema_for_namespace()
self._client.create_collection(
collection_name=self.final_namespace, schema=schema
)
self._create_indexes_after_collection()
# Load the newly created collection
self._ensure_collection_loaded()
logger.info(
f"[{self.workspace}] Successfully force-created collection {self.namespace}"
)
except Exception as create_error:
logger.error(
f"[{self.workspace}] Failed to force-create collection {self.namespace}: {create_error}"
)
raise
def __post_init__(self):
# Check for MILVUS_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all Milvus storage instances
milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
if milvus_workspace and milvus_workspace.strip():
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = milvus_workspace.strip()
logger.info(
f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
)
else:
# Use the workspace parameter passed during initialization
effective_workspace = self.workspace
if effective_workspace:
logger.debug(
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self.workspace = "_"
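        # Illustrative naming (hypothetical values): workspace "space1" with
        # namespace "chunks" gives the Milvus collection "space1_chunks";
        # without a workspace the collection is simply named "chunks".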
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
# Ensure created_at is in meta_fields
if "created_at" not in self.meta_fields:
self.meta_fields.add("created_at")
self._client = MilvusClient(
uri=os.environ.get(
"MILVUS_URI",
config.get(
"milvus",
"uri",
fallback=os.path.join(
self.global_config["working_dir"], "milvus_lite.db"
),
),
),
user=os.environ.get(
"MILVUS_USER", config.get("milvus", "user", fallback=None)
),
password=os.environ.get(
"MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)
),
token=os.environ.get(
"MILVUS_TOKEN", config.get("milvus", "token", fallback=None)
),
db_name=os.environ.get(
"MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)
),
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._initialized = False
async def initialize(self):
"""Initialize Milvus collection"""
async with get_storage_lock(enable_logging=True):
if self._initialized:
return
try:
# Create collection and check compatibility
self._create_collection_if_not_exist()
self._initialized = True
logger.info(
f"[{self.workspace}] Milvus collection '{self.namespace}' initialized successfully"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to initialize Milvus collection '{self.namespace}': {e}"
)
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
# Ensure collection is loaded before upserting
self._ensure_collection_loaded()
import time
current_time = int(time.time())
list_data: list[dict[str, Any]] = [
{
"id": k,
"created_at": current_time,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["vector"] = embeddings[i]
results = self._client.upsert(
collection_name=self.final_namespace, data=list_data
)
return results
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
# Ensure collection is loaded before querying
self._ensure_collection_loaded()
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Include all meta_fields (created_at is now always included)
output_fields = list(self.meta_fields)
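        # With the COSINE metric, Milvus interprets "radius" as a lower bound on
        # similarity, so hits scoring at or below cosine_better_than_threshold
        # are filtered out server-side.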
results = self._client.search(
collection_name=self.final_namespace,
data=embedding,
limit=top_k,
output_fields=output_fields,
search_params={
"metric_type": "COSINE",
"params": {"radius": self.cosine_better_than_threshold},
},
)
return [
{
**dp["entity"],
"id": dp["id"],
"distance": dp["distance"],
"created_at": dp.get("created_at"),
}
for dp in results[0]
]
async def index_done_callback(self) -> None:
# Milvus handles persistence automatically
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity from the vector database
Args:
entity_name: The name of the entity to delete
"""
try:
# Compute entity ID from name
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity from Milvus collection
result = self._client.delete(
collection_name=self.final_namespace, pks=[entity_id]
)
if result and result.get("delete_count", 0) > 0:
logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
else:
logger.debug(
f"[{self.workspace}] Entity {entity_name} not found in storage"
)
except Exception as e:
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity
Args:
entity_name: The name of the entity whose relations should be deleted
"""
try:
# Ensure collection is loaded before querying
self._ensure_collection_loaded()
# Search for relations where entity is either source or target
expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"'
# Find all relations involving this entity
results = self._client.query(
collection_name=self.final_namespace, filter=expr, output_fields=["id"]
)
if not results or len(results) == 0:
logger.debug(
f"[{self.workspace}] No relations found for entity {entity_name}"
)
return
# Extract IDs of relations to delete
relation_ids = [item["id"] for item in results]
logger.debug(
f"[{self.workspace}] Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations
if relation_ids:
delete_result = self._client.delete(
collection_name=self.final_namespace, pks=relation_ids
)
logger.debug(
f"[{self.workspace}] Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
)
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs
Args:
ids: List of vector IDs to be deleted
"""
try:
# Ensure collection is loaded before deleting
self._ensure_collection_loaded()
# Delete vectors by IDs
result = self._client.delete(collection_name=self.final_namespace, pks=ids)
if result and result.get("delete_count", 0) > 0:
logger.debug(
f"[{self.workspace}] Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}"
)
else:
logger.debug(
f"[{self.workspace}] No vectors were deleted from {self.namespace}"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID
Args:
id: The unique identifier of the vector
Returns:
The vector data if found, or None if not found
"""
try:
# Ensure collection is loaded before querying
self._ensure_collection_loaded()
# Include all meta_fields (created_at is now always included) plus id
output_fields = list(self.meta_fields) + ["id"]
# Query Milvus for a specific ID
result = self._client.query(
collection_name=self.final_namespace,
filter=f'id == "{id}"',
output_fields=output_fields,
)
if not result or len(result) == 0:
return None
return result[0]
except Exception as e:
logger.error(
f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
)
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs
Args:
ids: List of unique identifiers
Returns:
List of vector data objects that were found
"""
if not ids:
return []
try:
# Ensure collection is loaded before querying
self._ensure_collection_loaded()
# Include all meta_fields (created_at is now always included) plus id
output_fields = list(self.meta_fields) + ["id"]
# Prepare the ID filter expression
id_list = '", "'.join(ids)
filter_expr = f'id in ["{id_list}"]'
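            # Note: LightRAG IDs are hash-based strings (e.g. the "ent-" prefixed
            # hashes used in delete_entity above), so embedding them directly in
            # the quoted filter expression is assumed to be safe here.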
# Query Milvus with the filter
result = self._client.query(
collection_name=self.final_namespace,
filter=filter_expr,
output_fields=output_fields,
)
return result or []
except Exception as e:
logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
)
return []
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will delete all data from the Milvus collection.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
async with get_storage_lock(enable_logging=True):
try:
# Drop the collection and recreate it
if self._client.has_collection(self.final_namespace):
self._client.drop_collection(self.final_namespace)
# Recreate the collection
self._create_collection_if_not_exist()
logger.info(
f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)}