LightRAG/lightrag/base.py
Magic_yuan aaaf617451 feat(lightrag): Implement mix search mode combining knowledge graph and vector retrieval
- Add 'mix' mode to QueryParam for hybrid search functionality
- Implement mix_kg_vector_query to combine knowledge graph and vector search results
- Update LightRAG class to handle 'mix' mode queries
- Enhance README with examples and explanations for the new mix search mode
- Introduce new prompt structure for generating responses based on combined search results
2024-12-28 11:56:28 +08:00

from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any
from enum import Enum

import numpy as np

from .utils import EmbeddingFunc

TextChunkSchema = TypedDict(
    "TextChunkSchema",
    {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
)

T = TypeVar("T")
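
# Illustrative sketch only (not part of LightRAG): a TextChunkSchema value is a
# plain dict produced by the chunking step; the field values below are made up.
_EXAMPLE_CHUNK: TextChunkSchema = {
    "tokens": 120,
    "content": "LightRAG combines knowledge-graph and vector retrieval ...",
    "full_doc_id": "doc-0001",
    "chunk_order_index": 0,
}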


@dataclass
class QueryParam:
    mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of document chunks to retrieve.
    # top_n: int = 10
    # Number of tokens for the original chunks.
    max_token_for_text_unit: int = 4000
    # Number of tokens for the relationship descriptions
    max_token_for_global_context: int = 4000
    # Number of tokens for the entity descriptions
    max_token_for_local_context: int = 4000
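
# Illustrative sketch only: building a QueryParam for the new "mix" mode, which
# combines knowledge-graph and vector retrieval. The top_k and token-budget
# values below are arbitrary examples, not recommended defaults.
_EXAMPLE_MIX_PARAM = QueryParam(
    mode="mix",
    top_k=40,
    max_token_for_text_unit=3000,
    response_type="Multiple Paragraphs",
)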


@dataclass
class StorageNameSpace:
    namespace: str
    global_config: dict

    async def index_done_callback(self):
        """commit the storage operations after indexing"""
        pass

    async def query_done_callback(self):
        """commit the storage operations after querying"""
        pass


@dataclass
class BaseVectorStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Use the 'content' field of each value for embedding and the key as id.

        If embedding_func is None, use the 'embedding' field from the value instead.
        """
        raise NotImplementedError
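
# A minimal in-memory sketch of a BaseVectorStorage subclass, for illustration
# only (it is not one of LightRAG's shipped storage backends). It assumes that
# embedding_func can be awaited with a list of strings and returns a 2-D
# np.ndarray of embeddings, and that callers accept "id"/"distance" keys in the
# query results; both are assumptions, not guarantees of this interface.
@dataclass
class _InMemoryVectorStorageSketch(BaseVectorStorage):
    def __post_init__(self):
        self._vectors: dict[str, np.ndarray] = {}
        self._payloads: dict[str, dict] = {}

    async def upsert(self, data: dict[str, dict]):
        # Embed the 'content' field of every value; keep only meta_fields as payload.
        embeddings = await self.embedding_func([v["content"] for v in data.values()])
        for (key, value), vector in zip(data.items(), embeddings):
            self._vectors[key] = np.asarray(vector, dtype=np.float32)
            self._payloads[key] = {k: v for k, v in value.items() if k in self.meta_fields}

    async def query(self, query: str, top_k: int) -> list[dict]:
        if not self._vectors:
            return []
        q = np.asarray((await self.embedding_func([query]))[0], dtype=np.float32)
        ids = list(self._vectors)
        matrix = np.stack([self._vectors[i] for i in ids])
        # Cosine similarity, highest first.
        scores = matrix @ q / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(q) + 1e-8)
        ranked = np.argsort(scores)[::-1][:top_k]
        return [
            {"id": ids[i], "distance": float(scores[i]), **self._payloads[ids[i]]}
            for i in ranked
        ]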


@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    embedding_func: EmbeddingFunc

    async def all_keys(self) -> list[str]:
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the keys in `data` that do not yet exist in storage."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        raise NotImplementedError

    async def drop(self):
        raise NotImplementedError
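
# A minimal in-memory sketch of a BaseKVStorage implementation, illustrative
# only, mainly to make the intended contract of filter_keys/upsert concrete:
# filter_keys returns the subset of keys that are not stored yet, so callers can
# upsert only new items.
@dataclass
class _InMemoryKVStorageSketch(BaseKVStorage[dict]):
    def __post_init__(self):
        self._data: dict[str, dict] = {}

    async def all_keys(self) -> list[str]:
        return list(self._data)

    async def get_by_id(self, id: str) -> Union[dict, None]:
        return self._data.get(id)

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[dict, None]]:
        if fields is None:
            return [self._data.get(i) for i in ids]
        return [
            {k: v for k, v in self._data[i].items() if k in fields}
            if i in self._data
            else None
            for i in ids
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        return {key for key in data if key not in self._data}

    async def upsert(self, data: dict[str, dict]):
        self._data.update(data)

    async def drop(self):
        self._data.clear()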


@dataclass
class BaseGraphStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc = None

    async def has_node(self, node_id: str) -> bool:
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        raise NotImplementedError

    async def delete_node(self, node_id: str):
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        raise NotImplementedError("Node embedding is not used in lightrag.")


class DocStatus(str, Enum):
    """Document processing status enum"""

    PENDING = "pending"
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"


@dataclass
class DocProcessingStatus:
    """Document processing status data structure"""

    content_summary: str  # First 100 chars of document content
    content_length: int  # Total length of document
    status: DocStatus  # Current processing status
    created_at: str  # ISO format timestamp
    updated_at: str  # ISO format timestamp
    chunks_count: Optional[int] = None  # Number of chunks after splitting
    error: Optional[str] = None  # Error message if failed
    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata
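
# Illustrative sketch only: how a pipeline might record a freshly enqueued
# document. The 100-character summary follows the comment on content_summary;
# the local datetime import keeps this example self-contained.
def _new_pending_status_sketch(content: str) -> DocProcessingStatus:
    from datetime import datetime

    now = datetime.now().isoformat()
    return DocProcessingStatus(
        content_summary=content[:100],
        content_length=len(content),
        status=DocStatus.PENDING,
        created_at=now,
        updated_at=now,
    )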


class DocStatusStorage(BaseKVStorage):
    """Base class for document status storage"""

    async def get_status_counts(self) -> Dict[str, int]:
        """Get counts of documents in each status"""
        raise NotImplementedError

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all failed documents"""
        raise NotImplementedError

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all pending documents"""
        raise NotImplementedError
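
# A partial in-memory sketch of DocStatusStorage, illustrative only: it shows
# how the three status queries are expected to behave over stored
# DocProcessingStatus values. The remaining BaseKVStorage methods are omitted
# for brevity, so this is not a complete storage implementation.
@dataclass
class _InMemoryDocStatusStorageSketch(DocStatusStorage):
    def __post_init__(self):
        self._statuses: Dict[str, DocProcessingStatus] = {}

    async def upsert(self, data: Dict[str, DocProcessingStatus]):
        self._statuses.update(data)

    async def get_status_counts(self) -> Dict[str, int]:
        counts = {status.value: 0 for status in DocStatus}
        for doc in self._statuses.values():
            counts[doc.status.value] += 1
        return counts

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        return {k: v for k, v in self._statuses.items() if v.status == DocStatus.FAILED}

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        return {k: v for k, v in self._statuses.items() if v.status == DocStatus.PENDING}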