Mirror of https://github.com/HKUDS/LightRAG.git (synced 2026-01-06 03:40:46 +00:00)

Merge pull request #1736 from danielaskdd/optimize-milvus

Refactoring implementation of Milvus vector storage

Commit ae8e2db43d
@@ -25,11 +25,11 @@ STORAGE_IMPLEMENTATIONS = {
        "implementations": [
            "NanoVectorDBStorage",
            "MilvusVectorDBStorage",
            "ChromaVectorDBStorage",
            "PGVectorStorage",
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
            "MongoVectorDBStorage",
            # "ChromaVectorDBStorage",
            # "TiDBVectorDBStorage",
        ],
        "required_methods": ["query", "upsert"],
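For context, this registry is what backend selection is validated against at startup. A minimal sketch of that lookup, assuming the LIGHTRAG_VECTOR_STORAGE environment variable is the selector (the variable name and the trimmed registry shape here are illustrative, not confirmed by this diff):

import os

# Pick the Milvus backend; the server checks the name against the registry.
os.environ.setdefault("LIGHTRAG_VECTOR_STORAGE", "MilvusVectorDBStorage")

STORAGE_IMPLEMENTATIONS = {
    "VECTOR_STORAGE": {
        "implementations": ["NanoVectorDBStorage", "MilvusVectorDBStorage"],
        "required_methods": ["query", "upsert"],
    }
}

chosen = os.environ["LIGHTRAG_VECTOR_STORAGE"]
assert chosen in STORAGE_IMPLEMENTATIONS["VECTOR_STORAGE"]["implementations"]
print(f"Using vector storage backend: {chosen}")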
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
    pm.install("pymilvus")

import configparser
-from pymilvus import MilvusClient  # type: ignore
+from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema  # type: ignore

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -24,16 +24,605 @@ config.read("config.ini", "utf-8")
@final
@dataclass
class MilvusVectorDBStorage(BaseVectorStorage):
-    @staticmethod
-    def create_collection_if_not_exist(
-        client: MilvusClient, collection_name: str, **kwargs
-    ):
-        if client.has_collection(collection_name):
-            return
-        client.create_collection(
-            collection_name, max_length=64, id_type="string", **kwargs
-        )
    def _create_schema_for_namespace(self) -> CollectionSchema:
        """Create schema based on the current instance's namespace"""

        # Get vector dimension from embedding_func
        dimension = self.embedding_func.embedding_dim

        # Base fields (common to all collections)
        base_fields = [
            FieldSchema(
                name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
            ),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name="created_at", dtype=DataType.INT64),
        ]

        # Determine specific fields based on namespace
        if "entities" in self.namespace.lower():
            specific_fields = [
                FieldSchema(
                    name="entity_name",
                    dtype=DataType.VARCHAR,
                    max_length=256,
                    nullable=True,
                ),
                FieldSchema(
                    name="entity_type",
                    dtype=DataType.VARCHAR,
                    max_length=64,
                    nullable=True,
                ),
                FieldSchema(
                    name="file_path",
                    dtype=DataType.VARCHAR,
                    max_length=512,
                    nullable=True,
                ),
            ]
            description = "LightRAG entities vector storage"

        elif "relationships" in self.namespace.lower():
            specific_fields = [
                FieldSchema(
                    name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
                ),
                FieldSchema(
                    name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
                ),
                FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True),
                FieldSchema(
                    name="file_path",
                    dtype=DataType.VARCHAR,
                    max_length=512,
                    nullable=True,
                ),
            ]
            description = "LightRAG relationships vector storage"

        elif "chunks" in self.namespace.lower():
            specific_fields = [
                FieldSchema(
                    name="full_doc_id",
                    dtype=DataType.VARCHAR,
                    max_length=64,
                    nullable=True,
                ),
                FieldSchema(
                    name="file_path",
                    dtype=DataType.VARCHAR,
                    max_length=512,
                    nullable=True,
                ),
            ]
            description = "LightRAG chunks vector storage"

        else:
            # Default generic schema (backward compatibility)
            specific_fields = [
                FieldSchema(
                    name="file_path",
                    dtype=DataType.VARCHAR,
                    max_length=512,
                    nullable=True,
                ),
            ]
            description = "LightRAG generic vector storage"

        # Merge all fields
        all_fields = base_fields + specific_fields

        return CollectionSchema(
            fields=all_fields,
            description=description,
            enable_dynamic_field=True,  # Support dynamic fields
        )
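To see what this produces for a concrete case: a standalone sketch of the chunks-style schema, with a hypothetical 1024-dimensional embedding model standing in for self.embedding_func.embedding_dim (the dimension and namespace are illustrative):

from pymilvus import CollectionSchema, DataType, FieldSchema

dimension = 1024  # stand-in for self.embedding_func.embedding_dim

# Base fields plus the chunks-specific fields, mirroring the method above.
fields = [
    FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
    FieldSchema(name="created_at", dtype=DataType.INT64),
    FieldSchema(name="full_doc_id", dtype=DataType.VARCHAR, max_length=64, nullable=True),
    FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=512, nullable=True),
]

schema = CollectionSchema(
    fields=fields,
    description="LightRAG chunks vector storage",
    enable_dynamic_field=True,
)
print([f.name for f in schema.fields])  # ['id', 'vector', 'created_at', 'full_doc_id', 'file_path']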
    def _get_index_params(self):
        """Get IndexParams in a version-compatible way"""
        try:
            # Try to use client's prepare_index_params method (most common)
            if hasattr(self._client, "prepare_index_params"):
                return self._client.prepare_index_params()
        except Exception:
            pass

        try:
            # Try to import IndexParams from different possible locations
            from pymilvus.client.prepare import IndexParams

            return IndexParams()
        except ImportError:
            pass

        try:
            from pymilvus.client.types import IndexParams

            return IndexParams()
        except ImportError:
            pass

        try:
            from pymilvus import IndexParams

            return IndexParams()
        except ImportError:
            pass

        # If all else fails, return None to use fallback method
        return None
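On recent pymilvus releases the first probe is the one that fires: MilvusClient.prepare_index_params() returns an IndexParams builder. A minimal sketch of that happy path (the URI is illustrative; a reachable Milvus server is assumed):

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")

# Build index parameters the same way the method above does on its primary path.
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="vector",
    index_type="HNSW",
    metric_type="COSINE",
    params={"M": 16, "efConstruction": 256},
)
# index_params can now be passed to
# client.create_index(collection_name=..., index_params=index_params)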
    def _create_vector_index_fallback(self):
        """Fallback method to create vector index using direct API"""
        try:
            self._client.create_index(
                collection_name=self.namespace,
                field_name="vector",
                index_params={
                    "index_type": "HNSW",
                    "metric_type": "COSINE",
                    "params": {"M": 16, "efConstruction": 256},
                },
            )
            logger.debug("Created vector index using fallback method")
        except Exception as e:
            logger.warning(f"Failed to create vector index using fallback method: {e}")

    def _create_scalar_index_fallback(self, field_name: str, index_type: str):
        """Fallback method to create scalar index using direct API"""
        # Skip unsupported index types
        if index_type == "SORTED":
            logger.info(
                f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
            )
            return

        try:
            self._client.create_index(
                collection_name=self.namespace,
                field_name=field_name,
                index_params={"index_type": index_type},
            )
            logger.debug(f"Created {field_name} index using fallback method")
        except Exception as e:
            logger.info(
                f"Could not create {field_name} index using fallback method: {e}"
            )

    def _create_indexes_after_collection(self):
        """Create indexes after collection is created"""
        try:
            # Try to get IndexParams in a version-compatible way
            IndexParamsClass = self._get_index_params()

            if IndexParamsClass is not None:
                # Use IndexParams approach if available
                try:
                    # Create vector index first (required for most operations)
                    vector_index = IndexParamsClass
                    vector_index.add_index(
                        field_name="vector",
                        index_type="HNSW",
                        metric_type="COSINE",
                        params={"M": 16, "efConstruction": 256},
                    )
                    self._client.create_index(
                        collection_name=self.namespace, index_params=vector_index
                    )
                    logger.debug("Created vector index using IndexParams")
                except Exception as e:
                    logger.debug(f"IndexParams method failed for vector index: {e}")
                    self._create_vector_index_fallback()

                # Create scalar indexes based on namespace
                if "entities" in self.namespace.lower():
                    # Create indexes for entity fields
                    try:
                        entity_name_index = self._get_index_params()
                        entity_name_index.add_index(
                            field_name="entity_name", index_type="INVERTED"
                        )
                        self._client.create_index(
                            collection_name=self.namespace,
                            index_params=entity_name_index,
                        )
                    except Exception as e:
                        logger.debug(f"IndexParams method failed for entity_name: {e}")
                        self._create_scalar_index_fallback("entity_name", "INVERTED")

                    try:
                        entity_type_index = self._get_index_params()
                        entity_type_index.add_index(
                            field_name="entity_type", index_type="INVERTED"
                        )
                        self._client.create_index(
                            collection_name=self.namespace,
                            index_params=entity_type_index,
                        )
                    except Exception as e:
                        logger.debug(f"IndexParams method failed for entity_type: {e}")
                        self._create_scalar_index_fallback("entity_type", "INVERTED")

                elif "relationships" in self.namespace.lower():
                    # Create indexes for relationship fields
                    try:
                        src_id_index = self._get_index_params()
                        src_id_index.add_index(
                            field_name="src_id", index_type="INVERTED"
                        )
                        self._client.create_index(
                            collection_name=self.namespace, index_params=src_id_index
                        )
                    except Exception as e:
                        logger.debug(f"IndexParams method failed for src_id: {e}")
                        self._create_scalar_index_fallback("src_id", "INVERTED")

                    try:
                        tgt_id_index = self._get_index_params()
                        tgt_id_index.add_index(
                            field_name="tgt_id", index_type="INVERTED"
                        )
                        self._client.create_index(
                            collection_name=self.namespace, index_params=tgt_id_index
                        )
                    except Exception as e:
                        logger.debug(f"IndexParams method failed for tgt_id: {e}")
                        self._create_scalar_index_fallback("tgt_id", "INVERTED")

                elif "chunks" in self.namespace.lower():
                    # Create indexes for chunk fields
                    try:
                        doc_id_index = self._get_index_params()
                        doc_id_index.add_index(
                            field_name="full_doc_id", index_type="INVERTED"
                        )
                        self._client.create_index(
                            collection_name=self.namespace, index_params=doc_id_index
                        )
                    except Exception as e:
                        logger.debug(f"IndexParams method failed for full_doc_id: {e}")
                        self._create_scalar_index_fallback("full_doc_id", "INVERTED")

                # No common indexes needed

            else:
                # Fallback to direct API calls if IndexParams is not available
                logger.info(
                    f"IndexParams not available, using fallback methods for {self.namespace}"
                )

                # Create vector index using fallback
                self._create_vector_index_fallback()

                # Create scalar indexes using fallback
                if "entities" in self.namespace.lower():
                    self._create_scalar_index_fallback("entity_name", "INVERTED")
                    self._create_scalar_index_fallback("entity_type", "INVERTED")
                elif "relationships" in self.namespace.lower():
                    self._create_scalar_index_fallback("src_id", "INVERTED")
                    self._create_scalar_index_fallback("tgt_id", "INVERTED")
                elif "chunks" in self.namespace.lower():
                    self._create_scalar_index_fallback("full_doc_id", "INVERTED")

            logger.info(f"Created indexes for collection: {self.namespace}")

        except Exception as e:
            logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
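HNSW with COSINE here matches the metric used by the search path further down, and the INVERTED indexes back the scalar filters on entity, relationship, and chunk keys. A quick way to confirm what was actually built, assuming the list_indexes/describe_index client methods of recent pymilvus and an illustrative collection name:

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")

# Enumerate indexes on a collection and dump their parameters.
for index_name in client.list_indexes(collection_name="lightrag_chunks"):
    print(index_name, client.describe_index("lightrag_chunks", index_name))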
    def _get_required_fields_for_namespace(self) -> dict:
        """Get required core field definitions for current namespace"""

        # Base fields (common to all types)
        base_fields = {
            "id": {"type": "VarChar", "is_primary": True},
            "vector": {"type": "FloatVector"},
            "created_at": {"type": "Int64"},
        }

        # Add specific fields based on namespace
        if "entities" in self.namespace.lower():
            specific_fields = {
                "entity_name": {"type": "VarChar"},
                "entity_type": {"type": "VarChar"},
                "file_path": {"type": "VarChar"},
            }
        elif "relationships" in self.namespace.lower():
            specific_fields = {
                "src_id": {"type": "VarChar"},
                "tgt_id": {"type": "VarChar"},
                "weight": {"type": "Double"},
                "file_path": {"type": "VarChar"},
            }
        elif "chunks" in self.namespace.lower():
            specific_fields = {
                "full_doc_id": {"type": "VarChar"},
                "file_path": {"type": "VarChar"},
            }
        else:
            specific_fields = {
                "file_path": {"type": "VarChar"},
            }

        return {**base_fields, **specific_fields}
    def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
        """Check compatibility of a single field"""
        field_name = existing_field.get("name", "unknown")
        existing_type = existing_field.get("type")
        expected_type = expected_config.get("type")

        logger.debug(
            f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
        )

        # Convert DataType enum values to string names if needed
        original_existing_type = existing_type
        if hasattr(existing_type, "name"):
            existing_type = existing_type.name
            logger.debug(
                f"Converted enum to name: {original_existing_type} -> {existing_type}"
            )
        elif isinstance(existing_type, int):
            # Map common Milvus internal type codes to type names for backward compatibility
            type_mapping = {
                21: "VarChar",
                101: "FloatVector",
                5: "Int64",
                9: "Double",
            }
            mapped_type = type_mapping.get(existing_type, str(existing_type))
            logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
            existing_type = mapped_type

        # Normalize type names for comparison
        type_aliases = {
            "VARCHAR": "VarChar",
            "String": "VarChar",
            "FLOAT_VECTOR": "FloatVector",
            "INT64": "Int64",
            "BigInt": "Int64",
            "DOUBLE": "Double",
            "Float": "Double",
        }

        original_existing = existing_type
        original_expected = expected_type
        existing_type = type_aliases.get(existing_type, existing_type)
        expected_type = type_aliases.get(expected_type, expected_type)

        if original_existing != existing_type or original_expected != expected_type:
            logger.debug(
                f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
            )

        # Basic type compatibility check
        type_compatible = existing_type == expected_type
        logger.debug(
            f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
        )

        if not type_compatible:
            logger.warning(
                f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
            )
            return False

        # Primary key check - be more flexible about primary key detection
        if expected_config.get("is_primary"):
            # Check multiple possible field names for primary key status
            is_primary = (
                existing_field.get("is_primary_key", False)
                or existing_field.get("is_primary", False)
                or existing_field.get("primary_key", False)
            )
            logger.debug(
                f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
            )
            logger.debug(f"Raw field data for '{field_name}': {existing_field}")

            # For ID field, be more lenient - if it's the ID field, assume it should be primary
            if field_name == "id" and not is_primary:
                logger.info(
                    f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
                )
                # Don't fail for ID field primary key mismatch
            elif not is_primary:
                logger.warning(
                    f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
                )
                return False

        logger.debug(f"Field '{field_name}' is compatible")
        return True
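The shape of the first argument comes from describe_collection(). Distilled to its core, the numeric-code path above amounts to the following (the field dict and expected config are illustrative):

# Distilled sketch of the type-code mapping and primary-key probe above.
type_mapping = {21: "VarChar", 101: "FloatVector", 5: "Int64", 9: "Double"}

existing_field = {"name": "id", "type": 21, "is_primary_key": True}  # as from describe_collection()
expected_config = {"type": "VarChar", "is_primary": True}

resolved = type_mapping.get(existing_field["type"], str(existing_field["type"]))
assert resolved == expected_config["type"]
assert existing_field.get("is_primary_key", False)
print("field 'id' is compatible")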
    def _check_vector_dimension(self, collection_info: dict):
        """Check vector dimension compatibility"""
        current_dimension = self.embedding_func.embedding_dim

        # Find vector field dimension
        for field in collection_info.get("fields", []):
            if field.get("name") == "vector":
                field_type = field.get("type")
                if field_type in ["FloatVector", "FLOAT_VECTOR"]:
                    existing_dimension = field.get("params", {}).get("dim")

                    if existing_dimension != current_dimension:
                        raise ValueError(
                            f"Vector dimension mismatch for collection '{self.namespace}': "
                            f"existing={existing_dimension}, current={current_dimension}"
                        )

                    logger.debug(f"Vector dimension check passed: {current_dimension}")
                    return

        # If no vector field found, this might be an old collection created with simple schema
        logger.warning(
            f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
        )
        logger.warning("Consider recreating the collection for optimal performance.")
        return
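What the guard protects against, distilled: an existing collection built for one embedding size queried with another. The numbers here are illustrative:

collection_info = {
    "fields": [{"name": "vector", "type": "FloatVector", "params": {"dim": 768}}]
}
current_dimension = 1024  # stand-in for self.embedding_func.embedding_dim

# Mirrors the check above: a 768-dim collection with a 1024-dim model raises.
for field in collection_info["fields"]:
    if field["name"] == "vector" and field["params"]["dim"] != current_dimension:
        raise ValueError(
            f"Vector dimension mismatch: existing={field['params']['dim']}, "
            f"current={current_dimension}"
        )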
    def _check_schema_compatibility(self, collection_info: dict):
        """Check schema field compatibility"""
        existing_fields = {
            field["name"]: field for field in collection_info.get("fields", [])
        }

        # Check if this is an old collection created with simple schema
        has_vector_field = any(
            field.get("name") == "vector" for field in collection_info.get("fields", [])
        )

        if not has_vector_field:
            logger.warning(
                f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
            )
            logger.warning(
                "This collection will work but may have suboptimal performance"
            )
            logger.warning("Consider recreating the collection for optimal performance")
            return

        # For collections with vector field, check basic compatibility
        # Only check for critical incompatibilities, not missing optional fields
        critical_fields = {"id": {"type": "VarChar", "is_primary": True}}

        incompatible_fields = []

        for field_name, expected_config in critical_fields.items():
            if field_name in existing_fields:
                existing_field = existing_fields[field_name]
                if not self._is_field_compatible(existing_field, expected_config):
                    incompatible_fields.append(
                        f"{field_name}: expected {expected_config['type']}, "
                        f"got {existing_field.get('type')}"
                    )

        if incompatible_fields:
            raise ValueError(
                f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
            )

        # Get all expected fields for informational purposes
        expected_fields = self._get_required_fields_for_namespace()
        missing_fields = [
            field for field in expected_fields if field not in existing_fields
        ]

        if missing_fields:
            logger.info(
                f"Collection {self.namespace} missing optional fields: {missing_fields}"
            )
            logger.info(
                "These fields would be available in a newly created collection for better performance"
            )

        logger.debug(f"Schema compatibility check passed for {self.namespace}")
    def _validate_collection_compatibility(self):
        """Validate existing collection's dimension and schema compatibility"""
        try:
            collection_info = self._client.describe_collection(self.namespace)

            # 1. Check vector dimension
            self._check_vector_dimension(collection_info)

            # 2. Check schema compatibility
            self._check_schema_compatibility(collection_info)

            logger.info(f"Collection {self.namespace} compatibility validation passed")

        except Exception as e:
            logger.error(
                f"Collection compatibility validation failed for {self.namespace}: {e}"
            )
            raise
    def _create_collection_if_not_exist(self):
        """Create the collection if it does not exist, and check an existing collection's compatibility"""

        try:
            # First, list all collections to see what actually exists
            try:
                all_collections = self._client.list_collections()
                logger.debug(f"All collections in database: {all_collections}")
            except Exception as list_error:
                logger.warning(f"Could not list collections: {list_error}")
                all_collections = []

            # Check if our specific collection exists
            collection_exists = self._client.has_collection(self.namespace)
            logger.info(
                f"Collection '{self.namespace}' exists check: {collection_exists}"
            )

            if collection_exists:
                # Double-check by trying to describe the collection
                try:
                    self._client.describe_collection(self.namespace)
                    logger.info(
                        f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
                    )
                    self._validate_collection_compatibility()
                    return
                except Exception as describe_error:
                    logger.warning(
                        f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
                    )
                    logger.info(
                        "Treating as if collection doesn't exist and creating new one..."
                    )
                    # Fall through to creation logic

            # Collection doesn't exist, create new collection
            logger.info(f"Creating new collection: {self.namespace}")
            schema = self._create_schema_for_namespace()

            # Create collection with schema only first
            self._client.create_collection(
                collection_name=self.namespace, schema=schema
            )

            # Then create indexes
            self._create_indexes_after_collection()

            logger.info(f"Successfully created Milvus collection: {self.namespace}")

        except Exception as e:
            logger.error(
                f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
            )

            # If there's any error, try to force create the collection
            logger.info(f"Attempting to force create collection {self.namespace}...")
            try:
                # Try to drop the collection first if it exists in a bad state
                try:
                    if self._client.has_collection(self.namespace):
                        logger.info(
                            f"Dropping potentially corrupted collection {self.namespace}"
                        )
                        self._client.drop_collection(self.namespace)
                except Exception as drop_error:
                    logger.warning(
                        f"Could not drop collection {self.namespace}: {drop_error}"
                    )

                # Create fresh collection
                schema = self._create_schema_for_namespace()
                self._client.create_collection(
                    collection_name=self.namespace, schema=schema
                )
                self._create_indexes_after_collection()
                logger.info(f"Successfully force-created collection {self.namespace}")

            except Exception as create_error:
                logger.error(
                    f"Failed to force-create collection {self.namespace}: {create_error}"
                )
                raise
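The double-check exists because has_collection() can report a collection that describe_collection() then fails on. The probe pattern, standalone (the URI and collection name are illustrative):

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")
name = "lightrag_entities"

if client.has_collection(name):
    try:
        info = client.describe_collection(name)
        print("exists and describable:", info.get("collection_name", name))
    except Exception as e:
        # This is the corner the method above handles by recreating.
        print("exists but undescribable, would be recreated:", e)
else:
    print("missing, would be created from the namespace schema")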
    def __post_init__(self):
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
            )
        self.cosine_better_than_threshold = cosine_threshold

        # Ensure created_at is in meta_fields
        if "created_at" not in self.meta_fields:
            self.meta_fields.add("created_at")

        self._client = MilvusClient(
            uri=os.environ.get(
                "MILVUS_URI",
@@ -68,11 +661,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
            ),
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
-        MilvusVectorDBStorage.create_collection_if_not_exist(
-            self._client,
-            self.namespace,
-            dimension=self.embedding_func.embedding_dim,
-        )

+        # Create collection and check compatibility
+        self._create_collection_if_not_exist()
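Connection settings enter through the environment first, with config.ini as fallback (the fallback keys are truncated out of this hunk). A sketch of pointing the storage at a local server; the companion variables are assumptions, not confirmed by this diff:

import os

os.environ["MILVUS_URI"] = "http://localhost:19530"  # read by __post_init__ above
# Hypothetical companions following the same pattern (not shown in this hunk):
# os.environ["MILVUS_USER"] = "root"
# os.environ["MILVUS_PASSWORD"] = "..."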
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        logger.debug(f"Inserting {len(data)} to {self.namespace}")
@@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
        embedding = await self.embedding_func(
            [query], _priority=5
        )  # higher priority for query

        # Include all meta_fields (created_at is now always included)
        output_fields = list(self.meta_fields)

        results = self._client.search(
            collection_name=self.namespace,
            data=embedding,
            limit=top_k,
-            output_fields=list(self.meta_fields) + ["created_at"],
+            output_fields=output_fields,
            search_params={
                "metric_type": "COSINE",
                "params": {"radius": self.cosine_better_than_threshold},
            },
        )
        return [
            {
                **dp["entity"],
                "id": dp["id"],
                "distance": dp["distance"],
                # created_at is requested in output_fields, so it should be a top-level key in the result dict (dp)
                "created_at": dp.get("created_at"),
            }
            for dp in results[0]
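The search above is a COSINE range search: radius sets the similarity floor, so hits scoring below cosine_better_than_threshold never come back. A standalone sketch (collection name, query vector, dimension, and threshold are illustrative):

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")

hits = client.search(
    collection_name="lightrag_chunks",
    data=[[0.1] * 1024],  # one query embedding
    limit=5,
    output_fields=["file_path", "created_at"],
    # radius=0.2 means: only return hits with cosine similarity above 0.2
    search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
)
for hit in hits[0]:
    print(hit["id"], hit["distance"], hit["entity"].get("file_path"))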
@@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
            The vector data if found, or None if not found
        """
        try:
            # Include all meta_fields (created_at is now always included) plus id
            output_fields = list(self.meta_fields) + ["id"]

            # Query Milvus for a specific ID
            result = self._client.query(
                collection_name=self.namespace,
                filter=f'id == "{id}"',
-                output_fields=list(self.meta_fields) + ["id", "created_at"],
+                output_fields=output_fields,
            )

            if not result or len(result) == 0:
                return None

            # Ensure the result contains created_at field
            if "created_at" not in result[0]:
                result[0]["created_at"] = None

            return result[0]
        except Exception as e:
            logger.error(f"Error retrieving vector data for ID {id}: {e}")
@@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
            return []

        try:
            # Include all meta_fields (created_at is now always included) plus id
            output_fields = list(self.meta_fields) + ["id"]

            # Prepare the ID filter expression
            id_list = '", "'.join(ids)
            filter_expr = f'id in ["{id_list}"]'
@@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
            result = self._client.query(
                collection_name=self.namespace,
                filter=filter_expr,
-                output_fields=list(self.meta_fields) + ["id", "created_at"],
+                output_fields=output_fields,
            )

            # Ensure each result contains created_at field
            for item in result:
                if "created_at" not in item:
                    item["created_at"] = None

            return result or []
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
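The filter string built above is plain Milvus boolean-expression syntax. For two ids it renders like so (note that an id containing a double quote would break the expression and would need escaping):

ids = ["chunk-123", "chunk-456"]
id_list = '", "'.join(ids)
filter_expr = f'id in ["{id_list}"]'
print(filter_expr)  # id in ["chunk-123", "chunk-456"]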
@@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
            self._client.drop_collection(self.namespace)

            # Recreate the collection
-            MilvusVectorDBStorage.create_collection_if_not_exist(
-                self._client,
-                self.namespace,
-                dimension=self.embedding_func.embedding_dim,
-            )
+            self._create_collection_if_not_exist()

            logger.info(
                f"Process {os.getpid()} drop Milvus collection {self.namespace}"
            )