LightRAG/lightrag/base.py
Magic_yuan 650b8e38b7 feat(lightrag): Add document status tracking and checkpoint support
功能(lightrag): 添加文档状态跟踪和断点续传支持

- Add DocStatus enum and DocProcessingStatus class for document processing state management
- 添加 DocStatus 枚举和 DocProcessingStatus 类用于文档处理状态管理

- Implement JsonDocStatusStorage for persistent status storage
- 实现 JsonDocStatusStorage 用于持久化状态存储

- Add document-level deduplication in batch processing
- 在批处理中添加文档级别的去重功能

- Add checkpoint support in ainsert method for resumable document processing
- 在 ainsert 方法中添加断点续传支持,实现可恢复的文档处理

- Add status query methods for monitoring processing progress
- 添加状态查询方法用于监控处理进度

- Update LightRAG initialization to support document status tracking
- 更新 LightRAG 初始化以支持文档状态跟踪
2024-12-28 00:11:25 +08:00

172 lines
5.1 KiB
Python

from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any
from enum import Enum
import numpy as np
from .utils import EmbeddingFunc
class TextChunkSchema(TypedDict):
    """Dictionary layout describing one text chunk of a document."""

    tokens: int
    content: str
    full_doc_id: str
    chunk_order_index: int


# Type variable for the value type handled by the generic key-value storage.
T = TypeVar("T")
@dataclass
class QueryParam:
    """Per-query options.

    All fields have defaults, so ``QueryParam()`` is a valid configuration.
    """

    mode: Literal["local", "global", "hybrid", "naive"] = "global"
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False

    # How many top-ranked items to retrieve: these are entities in "local"
    # mode and relationships in "global" mode.
    top_k: int = 60

    # Number of document chunks to retrieve (field currently disabled).
    # top_n: int = 10

    # Token budget for the original text chunks.
    max_token_for_text_unit: int = 4000

    # Token budget for the relationship descriptions.
    max_token_for_global_context: int = 4000

    # Token budget for the entity descriptions.
    max_token_for_local_context: int = 4000
@dataclass
class StorageNameSpace:
    """Base record shared by the storage classes: a namespace and a config dict.

    Both callbacks are no-op coroutines here; subclasses may override them.
    """

    namespace: str
    global_config: dict

    async def index_done_callback(self):
        """Commit the storage operations after indexing (no-op by default)."""

    async def query_done_callback(self):
        """Commit the storage operations after querying (no-op by default)."""
@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Interface for vector storages.

    Both operations raise ``NotImplementedError`` here and must be supplied
    by a concrete subclass.
    """

    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Run ``query`` against the storage, limited by ``top_k``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Insert or update entries, using each key of ``data`` as the id.

        The 'content' field of each value is used for embedding. If
        embedding_func is None, the 'embedding' field of the value is used
        instead.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Interface for key-value storages holding values of type ``T``.

    Every method raises ``NotImplementedError`` in this base class; concrete
    storages are expected to override them.
    """

    embedding_func: EmbeddingFunc

    async def all_keys(self) -> list[str]:
        """Return the keys of the storage as a list of strings."""
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        """Look up a single value by ``id``.

        The return type allows ``None`` (presumably for a missing key;
        confirm against the concrete storage).
        """
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        """Look up several values at once, one result per entry of ``ids``.

        ``fields`` defaults to ``None``; it presumably restricts which fields
        of each value are returned (confirm against the concrete storage).
        """
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the keys from ``data`` that do not exist in the storage."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        """Insert or update the key/value pairs given in ``data``."""
        raise NotImplementedError

    async def drop(self):
        """Drop the storage; the exact effect is defined by the subclass."""
        raise NotImplementedError
@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Interface for graph storages (nodes connected by edges).

    Every method raises ``NotImplementedError`` in this base class; concrete
    storages are expected to override them.
    """

    # The default is None, so the annotation is explicitly Optional rather
    # than relying on the implicit-Optional convention.
    embedding_func: Optional[EmbeddingFunc] = None

    async def has_node(self, node_id: str) -> bool:
        """Report whether a node with ``node_id`` exists."""
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Report whether an edge exists between the two given nodes."""
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        """Return the degree of the node ``node_id``."""
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the degree associated with the edge ``src_id`` -> ``tgt_id``."""
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the data of node ``node_id`` as a dict; may be ``None``."""
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the data of the given edge as a dict; may be ``None``."""
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        """Return the edges of ``source_node_id`` as string pairs; may be ``None``."""
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Insert or update node ``node_id`` with ``node_data``."""
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Insert or update the edge between the two nodes with ``edge_data``."""
        raise NotImplementedError

    async def delete_node(self, node_id: str):
        """Delete the node ``node_id``."""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Embed the graph nodes with ``algorithm``.

        Raises:
            NotImplementedError: always, in this base class, with a message
                stating that node embedding is not used in lightrag.
        """
        raise NotImplementedError("Node embedding is not used in lightrag.")
class DocStatus(str, Enum):
    """Processing states of a document.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"
@dataclass
class DocProcessingStatus:
    """Record describing the processing status of one document."""

    # First 100 characters of the document content.
    content_summary: str
    # Total length of the document.
    content_length: int
    # Current processing status.
    status: DocStatus
    # Creation timestamp, ISO format.
    created_at: str
    # Last-update timestamp, ISO format.
    updated_at: str
    # Number of chunks after splitting; None by default.
    chunks_count: Optional[int] = None
    # Error message if processing failed; None by default.
    error: Optional[str] = None
    # Additional metadata; each instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
class DocStatusStorage(BaseKVStorage):
    """Base class for storages that hold document processing statuses.

    The query methods below raise ``NotImplementedError`` here and must be
    provided by a concrete subclass.
    """

    async def get_status_counts(self) -> Dict[str, int]:
        """Return how many documents are in each status."""
        raise NotImplementedError

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        """Return all failed documents as a dict of status records."""
        raise NotImplementedError

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        """Return all pending documents as a dict of status records."""
        raise NotImplementedError