Merge pull request #1736 from danielaskdd/optimize-milvus

Refactoring implementation of Milvus vector storage
Daniel.y 2025-07-04 21:54:44 +08:00 committed by GitHub
commit ae8e2db43d
5 changed files with 620 additions and 34 deletions


@@ -25,11 +25,11 @@ STORAGE_IMPLEMENTATIONS = {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"MongoVectorDBStorage",
# "ChromaVectorDBStorage",
# "TiDBVectorDBStorage",
],
"required_methods": ["query", "upsert"],


@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
import configparser
-from pymilvus import MilvusClient  # type: ignore
+from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema  # type: ignore

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -24,16 +24,605 @@ config.read("config.ini", "utf-8")
@final
@dataclass
class MilvusVectorDBStorage(BaseVectorStorage):
-    @staticmethod
-    def create_collection_if_not_exist(
-        client: MilvusClient, collection_name: str, **kwargs
-    ):
-        if client.has_collection(collection_name):
-            return
-        client.create_collection(
-            collection_name, max_length=64, id_type="string", **kwargs
-        )

def _create_schema_for_namespace(self) -> CollectionSchema:
"""Create schema based on the current instance's namespace"""
# Get vector dimension from embedding_func
dimension = self.embedding_func.embedding_dim
# Base fields (common to all collections)
base_fields = [
FieldSchema(
name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="created_at", dtype=DataType.INT64),
]
# Determine specific fields based on namespace
if "entities" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="entity_name",
dtype=DataType.VARCHAR,
max_length=256,
nullable=True,
),
FieldSchema(
name="entity_type",
dtype=DataType.VARCHAR,
max_length=64,
nullable=True,
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG entities vector storage"
elif "relationships" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
),
FieldSchema(
name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
),
FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG relationships vector storage"
elif "chunks" in self.namespace.lower():
specific_fields = [
FieldSchema(
name="full_doc_id",
dtype=DataType.VARCHAR,
max_length=64,
nullable=True,
),
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG chunks vector storage"
else:
# Default generic schema (backward compatibility)
specific_fields = [
FieldSchema(
name="file_path",
dtype=DataType.VARCHAR,
max_length=512,
nullable=True,
),
]
description = "LightRAG generic vector storage"
# Merge all fields
all_fields = base_fields + specific_fields
return CollectionSchema(
fields=all_fields,
description=description,
enable_dynamic_field=True, # Support dynamic fields
)
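
For reference, for a namespace containing "entities" this method assembles the equivalent of the standalone pymilvus schema below (the dimension of 768 is an assumed embedding size; LightRAG reads the real value from embedding_func):

from pymilvus import CollectionSchema, DataType, FieldSchema

dim = 768  # assumed embedding dimension
entity_schema = CollectionSchema(
    fields=[
        FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dim),
        FieldSchema(name="created_at", dtype=DataType.INT64),
        FieldSchema(name="entity_name", dtype=DataType.VARCHAR, max_length=256, nullable=True),
        FieldSchema(name="entity_type", dtype=DataType.VARCHAR, max_length=64, nullable=True),
        FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=512, nullable=True),
    ],
    description="LightRAG entities vector storage",
    enable_dynamic_field=True,  # extra metadata keys still land in dynamic fields
)
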
def _get_index_params(self):
"""Get IndexParams in a version-compatible way"""
try:
# Try to use client's prepare_index_params method (most common)
if hasattr(self._client, "prepare_index_params"):
return self._client.prepare_index_params()
except Exception:
pass
try:
# Try to import IndexParams from different possible locations
from pymilvus.client.prepare import IndexParams
return IndexParams()
except ImportError:
pass
try:
from pymilvus.client.types import IndexParams
return IndexParams()
except ImportError:
pass
try:
from pymilvus import IndexParams
return IndexParams()
except ImportError:
pass
# If all else fails, return None to use fallback method
return None
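
On pymilvus 2.4+ the first branch succeeds, so in practice the helper chain reduces to the client-native calls sketched below; the import fallbacks only matter on older releases. (client and the collection name are illustrative.)

index_params = client.prepare_index_params()  # client: an existing MilvusClient
index_params.add_index(
    field_name="vector",
    index_type="HNSW",
    metric_type="COSINE",
    params={"M": 16, "efConstruction": 256},
)
client.create_index(collection_name="lightrag_entities", index_params=index_params)
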
def _create_vector_index_fallback(self):
"""Fallback method to create vector index using direct API"""
try:
self._client.create_index(
collection_name=self.namespace,
field_name="vector",
index_params={
"index_type": "HNSW",
"metric_type": "COSINE",
"params": {"M": 16, "efConstruction": 256},
},
)
logger.debug("Created vector index using fallback method")
except Exception as e:
logger.warning(f"Failed to create vector index using fallback method: {e}")
def _create_scalar_index_fallback(self, field_name: str, index_type: str):
"""Fallback method to create scalar index using direct API"""
# Skip unsupported index types
if index_type == "SORTED":
logger.info(
f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
)
return
try:
self._client.create_index(
collection_name=self.namespace,
field_name=field_name,
index_params={"index_type": index_type},
)
logger.debug(f"Created {field_name} index using fallback method")
except Exception as e:
logger.info(
f"Could not create {field_name} index using fallback method: {e}"
            )

    def _create_indexes_after_collection(self):
"""Create indexes after collection is created"""
try:
# Try to get IndexParams in a version-compatible way
IndexParamsClass = self._get_index_params()
if IndexParamsClass is not None:
# Use IndexParams approach if available
try:
# Create vector index first (required for most operations)
vector_index = IndexParamsClass
vector_index.add_index(
field_name="vector",
index_type="HNSW",
metric_type="COSINE",
params={"M": 16, "efConstruction": 256},
)
self._client.create_index(
collection_name=self.namespace, index_params=vector_index
)
logger.debug("Created vector index using IndexParams")
except Exception as e:
logger.debug(f"IndexParams method failed for vector index: {e}")
self._create_vector_index_fallback()
# Create scalar indexes based on namespace
if "entities" in self.namespace.lower():
# Create indexes for entity fields
try:
entity_name_index = self._get_index_params()
entity_name_index.add_index(
field_name="entity_name", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace,
index_params=entity_name_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for entity_name: {e}")
self._create_scalar_index_fallback("entity_name", "INVERTED")
try:
entity_type_index = self._get_index_params()
entity_type_index.add_index(
field_name="entity_type", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace,
index_params=entity_type_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for entity_type: {e}")
self._create_scalar_index_fallback("entity_type", "INVERTED")
elif "relationships" in self.namespace.lower():
# Create indexes for relationship fields
try:
src_id_index = self._get_index_params()
src_id_index.add_index(
field_name="src_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=src_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for src_id: {e}")
self._create_scalar_index_fallback("src_id", "INVERTED")
try:
tgt_id_index = self._get_index_params()
tgt_id_index.add_index(
field_name="tgt_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=tgt_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for tgt_id: {e}")
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower():
# Create indexes for chunk fields
try:
doc_id_index = self._get_index_params()
doc_id_index.add_index(
field_name="full_doc_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=doc_id_index
)
except Exception as e:
logger.debug(f"IndexParams method failed for full_doc_id: {e}")
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
# No common indexes needed
else:
# Fallback to direct API calls if IndexParams is not available
logger.info(
f"IndexParams not available, using fallback methods for {self.namespace}"
)
# Create vector index using fallback
self._create_vector_index_fallback()
# Create scalar indexes using fallback
if "entities" in self.namespace.lower():
self._create_scalar_index_fallback("entity_name", "INVERTED")
self._create_scalar_index_fallback("entity_type", "INVERTED")
elif "relationships" in self.namespace.lower():
self._create_scalar_index_fallback("src_id", "INVERTED")
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower():
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
logger.info(f"Created indexes for collection: {self.namespace}")
except Exception as e:
logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
def _get_required_fields_for_namespace(self) -> dict:
"""Get required core field definitions for current namespace"""
# Base fields (common to all types)
base_fields = {
"id": {"type": "VarChar", "is_primary": True},
"vector": {"type": "FloatVector"},
"created_at": {"type": "Int64"},
}
# Add specific fields based on namespace
if "entities" in self.namespace.lower():
specific_fields = {
"entity_name": {"type": "VarChar"},
"entity_type": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
elif "relationships" in self.namespace.lower():
specific_fields = {
"src_id": {"type": "VarChar"},
"tgt_id": {"type": "VarChar"},
"weight": {"type": "Double"},
"file_path": {"type": "VarChar"},
}
elif "chunks" in self.namespace.lower():
specific_fields = {
"full_doc_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
else:
specific_fields = {
"file_path": {"type": "VarChar"},
}
return {**base_fields, **specific_fields}
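
For a relationships namespace, for example, the merged result is:

required_fields = {
    "id": {"type": "VarChar", "is_primary": True},
    "vector": {"type": "FloatVector"},
    "created_at": {"type": "Int64"},
    "src_id": {"type": "VarChar"},
    "tgt_id": {"type": "VarChar"},
    "weight": {"type": "Double"},
    "file_path": {"type": "VarChar"},
}
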
def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
"""Check compatibility of a single field"""
field_name = existing_field.get("name", "unknown")
existing_type = existing_field.get("type")
expected_type = expected_config.get("type")
logger.debug(
f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
)
# Convert DataType enum values to string names if needed
original_existing_type = existing_type
if hasattr(existing_type, "name"):
existing_type = existing_type.name
logger.debug(
f"Converted enum to name: {original_existing_type} -> {existing_type}"
)
elif isinstance(existing_type, int):
# Map common Milvus internal type codes to type names for backward compatibility
type_mapping = {
21: "VarChar",
101: "FloatVector",
5: "Int64",
9: "Double",
}
mapped_type = type_mapping.get(existing_type, str(existing_type))
logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
existing_type = mapped_type
# Normalize type names for comparison
type_aliases = {
"VARCHAR": "VarChar",
"String": "VarChar",
"FLOAT_VECTOR": "FloatVector",
"INT64": "Int64",
"BigInt": "Int64",
"DOUBLE": "Double",
"Float": "Double",
}
original_existing = existing_type
original_expected = expected_type
existing_type = type_aliases.get(existing_type, existing_type)
expected_type = type_aliases.get(expected_type, expected_type)
if original_existing != existing_type or original_expected != expected_type:
logger.debug(
f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
)
# Basic type compatibility check
type_compatible = existing_type == expected_type
logger.debug(
f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
)
if not type_compatible:
logger.warning(
f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
)
return False
# Primary key check - be more flexible about primary key detection
if expected_config.get("is_primary"):
# Check multiple possible field names for primary key status
is_primary = (
existing_field.get("is_primary_key", False)
or existing_field.get("is_primary", False)
or existing_field.get("primary_key", False)
)
logger.debug(
f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
)
logger.debug(f"Raw field data for '{field_name}': {existing_field}")
# For ID field, be more lenient - if it's the ID field, assume it should be primary
if field_name == "id" and not is_primary:
logger.info(
f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
)
# Don't fail for ID field primary key mismatch
elif not is_primary:
logger.warning(
f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
)
return False
logger.debug(f"Field '{field_name}' is compatible")
return True
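
The descriptors this check consumes come from describe_collection(), where, depending on the pymilvus version, a field's type may surface as a DataType enum, its integer code, or a plain string, which is why the method normalizes all three. Illustrative inputs (values made up):

from pymilvus import DataType

existing_field = {"name": "id", "type": DataType.VARCHAR, "is_primary_key": True}
expected_config = {"type": "VarChar", "is_primary": True}
# _is_field_compatible(existing_field, expected_config) -> True:
# DataType.VARCHAR.name is "VARCHAR", which the alias table maps to "VarChar".
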
def _check_vector_dimension(self, collection_info: dict):
"""Check vector dimension compatibility"""
current_dimension = self.embedding_func.embedding_dim
# Find vector field dimension
for field in collection_info.get("fields", []):
if field.get("name") == "vector":
field_type = field.get("type")
if field_type in ["FloatVector", "FLOAT_VECTOR"]:
existing_dimension = field.get("params", {}).get("dim")
if existing_dimension != current_dimension:
raise ValueError(
f"Vector dimension mismatch for collection '{self.namespace}': "
f"existing={existing_dimension}, current={current_dimension}"
)
logger.debug(f"Vector dimension check passed: {current_dimension}")
return
# If no vector field found, this might be an old collection created with simple schema
logger.warning(
f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
)
logger.warning("Consider recreating the collection for optimal performance.")
        return

    def _check_schema_compatibility(self, collection_info: dict):
"""Check schema field compatibility"""
existing_fields = {
field["name"]: field for field in collection_info.get("fields", [])
}
# Check if this is an old collection created with simple schema
has_vector_field = any(
field.get("name") == "vector" for field in collection_info.get("fields", [])
)
if not has_vector_field:
logger.warning(
f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
)
logger.warning(
"This collection will work but may have suboptimal performance"
)
logger.warning("Consider recreating the collection for optimal performance")
return
# For collections with vector field, check basic compatibility
# Only check for critical incompatibilities, not missing optional fields
critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
incompatible_fields = []
for field_name, expected_config in critical_fields.items():
if field_name in existing_fields:
existing_field = existing_fields[field_name]
if not self._is_field_compatible(existing_field, expected_config):
incompatible_fields.append(
f"{field_name}: expected {expected_config['type']}, "
f"got {existing_field.get('type')}"
)
if incompatible_fields:
raise ValueError(
f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
)
# Get all expected fields for informational purposes
expected_fields = self._get_required_fields_for_namespace()
missing_fields = [
field for field in expected_fields if field not in existing_fields
]
if missing_fields:
logger.info(
f"Collection {self.namespace} missing optional fields: {missing_fields}"
)
logger.info(
"These fields would be available in a newly created collection for better performance"
)
logger.debug(f"Schema compatibility check passed for {self.namespace}")
def _validate_collection_compatibility(self):
"""Validate existing collection's dimension and schema compatibility"""
try:
collection_info = self._client.describe_collection(self.namespace)
# 1. Check vector dimension
self._check_vector_dimension(collection_info)
# 2. Check schema compatibility
self._check_schema_compatibility(collection_info)
logger.info(f"Collection {self.namespace} compatibility validation passed")
except Exception as e:
logger.error(
f"Collection compatibility validation failed for {self.namespace}: {e}"
)
            raise

    def _create_collection_if_not_exist(self):
"""Create collection if not exists and check existing collection compatibility"""
try:
# First, list all collections to see what actually exists
try:
all_collections = self._client.list_collections()
logger.debug(f"All collections in database: {all_collections}")
except Exception as list_error:
logger.warning(f"Could not list collections: {list_error}")
all_collections = []
# Check if our specific collection exists
collection_exists = self._client.has_collection(self.namespace)
logger.info(
f"Collection '{self.namespace}' exists check: {collection_exists}"
)
if collection_exists:
# Double-check by trying to describe the collection
try:
self._client.describe_collection(self.namespace)
logger.info(
f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
)
self._validate_collection_compatibility()
return
except Exception as describe_error:
logger.warning(
f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
)
logger.info(
"Treating as if collection doesn't exist and creating new one..."
)
# Fall through to creation logic
# Collection doesn't exist, create new collection
logger.info(f"Creating new collection: {self.namespace}")
schema = self._create_schema_for_namespace()
# Create collection with schema only first
self._client.create_collection(
collection_name=self.namespace, schema=schema
)
# Then create indexes
self._create_indexes_after_collection()
logger.info(f"Successfully created Milvus collection: {self.namespace}")
except Exception as e:
logger.error(
f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
)
# If there's any error, try to force create the collection
logger.info(f"Attempting to force create collection {self.namespace}...")
try:
# Try to drop the collection first if it exists in a bad state
try:
if self._client.has_collection(self.namespace):
logger.info(
f"Dropping potentially corrupted collection {self.namespace}"
)
self._client.drop_collection(self.namespace)
except Exception as drop_error:
logger.warning(
f"Could not drop collection {self.namespace}: {drop_error}"
)
# Create fresh collection
schema = self._create_schema_for_namespace()
self._client.create_collection(
collection_name=self.namespace, schema=schema
)
self._create_indexes_after_collection()
logger.info(f"Successfully force-created collection {self.namespace}")
except Exception as create_error:
logger.error(
f"Failed to force-create collection {self.namespace}: {create_error}"
)
raise
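
The except branch amounts to a forced clean rebuild. The same recovery can be run by hand against a collection known to be corrupt (a sketch; destructive, and the collection name is illustrative):

if client.has_collection("lightrag_entities"):   # client: an existing MilvusClient
    client.drop_collection("lightrag_entities")  # drops all stored vectors
# Re-instantiating the storage (or calling _create_collection_if_not_exist())
# then recreates the collection with the schema and indexes defined above.
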
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
# Ensure created_at is in meta_fields
if "created_at" not in self.meta_fields:
self.meta_fields.add("created_at")
self._client = MilvusClient(
uri=os.environ.get(
"MILVUS_URI",
@@ -68,11 +661,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
),
)
self._max_batch_size = self.global_config["embedding_batch_num"]
-        MilvusVectorDBStorage.create_collection_if_not_exist(
-            self._client,
-            self.namespace,
-            dimension=self.embedding_func.embedding_dim,
-        )
+        # Create collection and check compatibility
+        self._create_collection_if_not_exist()

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}")
@@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Include all meta_fields (created_at is now always included)
output_fields = list(self.meta_fields)
results = self._client.search(
collection_name=self.namespace,
data=embedding,
limit=top_k,
output_fields=list(self.meta_fields) + ["created_at"],
output_fields=output_fields,
search_params={
"metric_type": "COSINE",
"params": {"radius": self.cosine_better_than_threshold},
},
)
print(results)
return [
{
**dp["entity"],
"id": dp["id"],
"distance": dp["distance"],
# created_at is requested in output_fields, so it should be a top-level key in the result dict (dp)
"created_at": dp.get("created_at"),
}
for dp in results[0]
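
Under the new schema each element returned by query() then has the shape below (values illustrative; the namespace-specific keys come from meta_fields):

{
    "id": "ent-1f3c",
    "distance": 0.83,            # COSINE similarity reported by Milvus
    "created_at": 1720094400,    # epoch seconds, now always requested
    "entity_name": "Alan Turing",
    "entity_type": "Person",
    "file_path": "docs/turing.md",
}
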
@@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
The vector data if found, or None if not found
"""
try:
# Include all meta_fields (created_at is now always included) plus id
output_fields = list(self.meta_fields) + ["id"]
# Query Milvus for a specific ID
result = self._client.query(
collection_name=self.namespace,
filter=f'id == "{id}"',
output_fields=list(self.meta_fields) + ["id", "created_at"],
output_fields=output_fields,
)
if not result or len(result) == 0:
return None
-            # Ensure the result contains created_at field
-            if "created_at" not in result[0]:
-                result[0]["created_at"] = None
return result[0]
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
@@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
return []
try:
# Include all meta_fields (created_at is now always included) plus id
output_fields = list(self.meta_fields) + ["id"]
# Prepare the ID filter expression
id_list = '", "'.join(ids)
filter_expr = f'id in ["{id_list}"]'
@@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
result = self._client.query(
collection_name=self.namespace,
filter=filter_expr,
output_fields=list(self.meta_fields) + ["id", "created_at"],
output_fields=output_fields,
)
-            # Ensure each result contains created_at field
-            for item in result:
-                if "created_at" not in item:
-                    item["created_at"] = None
return result or []
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
@@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._client.drop_collection(self.namespace)
# Recreate the collection
-            MilvusVectorDBStorage.create_collection_if_not_exist(
-                self._client,
-                self.namespace,
-                dimension=self.embedding_func.embedding_dim,
-            )
+            self._create_collection_if_not_exist()
logger.info(
f"Process {os.getpid()} drop Milvus collection {self.namespace}"