import asyncio
import json
import os
import time
from dataclasses import dataclass, field
from typing import Any, Union, final

import numpy as np
import configparser

from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from ..base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    DocProcessingStatus,
    DocStatus,
    DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger

import pipmaster as pm

if not pm.is_installed("asyncpg"):
    pm.install("asyncpg")

import asyncpg  # type: ignore
from asyncpg import Pool  # type: ignore

from dotenv import load_dotenv

# Use the .env file in the current folder, which allows a different .env file
# for each LightRAG instance; OS environment variables take precedence over it.
load_dotenv(dotenv_path=".env", override=False)

# Get the maximum number of graph nodes from the environment variable (default: 1000)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
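
# Environment variables consumed in this module (each lookup appears where it is used):
#   POSTGRES_HOST / POSTGRES_PORT / POSTGRES_USER / POSTGRES_PASSWORD /
#   POSTGRES_DATABASE / POSTGRES_WORKSPACE  -> connection settings (ClientManager.get_config)
#   AGE_GRAPH_NAME                          -> Apache AGE graph name (PGGraphStorage)
#   MAX_GRAPH_NODES                         -> cap on returned graph nodes (above)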


class PostgreSQLDB:
    def __init__(self, config: dict[str, Any], **kwargs: Any):
        self.host = config.get("host", "localhost")
        self.port = config.get("port", 5432)
        self.user = config.get("user", "postgres")
        self.password = config.get("password", None)
        self.database = config.get("database", "postgres")
        self.workspace = config.get("workspace", "default")
        self.max = 12
        self.increment = 1
        self.pool: Pool | None = None

        if self.user is None or self.password is None or self.database is None:
            raise ValueError("Missing database user, password, or database")

    async def initdb(self):
        try:
            self.pool = await asyncpg.create_pool(  # type: ignore
                user=self.user,
                password=self.password,
                database=self.database,
                host=self.host,
                port=self.port,
                min_size=1,
                max_size=self.max,
            )

            logger.info(
                f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
            )
        except Exception as e:
            logger.error(
                f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got: {e}"
            )
            raise

    @staticmethod
    async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
        """Set up the Apache AGE environment and create a graph if it does not exist.

        This method:
        - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that
          Apache AGE functions can be used without specifying the schema.
        - Attempts to create a new graph with the provided `graph_name` if it does
          not already exist.
        - Silently ignores errors related to the graph already existing.
        """
        try:
            await connection.execute(  # type: ignore
                'SET search_path = ag_catalog, "$user", public'
            )
            await connection.execute(  # type: ignore
                f"select create_graph('{graph_name}')"
            )
        except (
            asyncpg.exceptions.InvalidSchemaNameError,
            asyncpg.exceptions.UniqueViolationError,
        ):
            pass

    async def check_tables(self):
        for k, v in TABLES.items():
            try:
                await self.query(f"SELECT 1 FROM {k} LIMIT 1")
            except Exception:
                try:
                    logger.info(f"PostgreSQL, Trying to create table {k} in database")
                    await self.execute(v["ddl"])
                    logger.info(
                        f"PostgreSQL, Successfully created table {k} in database"
                    )
                except Exception as e:
                    logger.error(
                        f"PostgreSQL, Failed to create table {k} in database, "
                        f"Please verify the connection with PostgreSQL database, Got: {e}"
                    )
                    raise e

            # Create an index on the id column of each table
            try:
                index_name = f"idx_{k.lower()}_id"
                check_index_sql = f"""
                SELECT 1 FROM pg_indexes
                WHERE indexname = '{index_name}'
                AND tablename = '{k.lower()}'
                """
                index_exists = await self.query(check_index_sql)

                if not index_exists:
                    create_index_sql = f"CREATE INDEX {index_name} ON {k} (id)"
                    logger.info(
                        f"PostgreSQL, Creating index {index_name} on table {k}"
                    )
                    await self.execute(create_index_sql)
            except Exception as e:
                logger.error(
                    f"PostgreSQL, Failed to create index on table {k}, Got: {e}"
                )

    async def query(
        self,
        sql: str,
        params: dict[str, Any] | None = None,
        multirows: bool = False,
        with_age: bool = False,
        graph_name: str | None = None,
    ) -> dict[str, Any] | None | list[dict[str, Any]]:
        # start_time = time.time()
        # logger.info(f"PostgreSQL, Querying:\n{sql}")

        async with self.pool.acquire() as connection:  # type: ignore
            if with_age and graph_name:
                await self.configure_age(connection, graph_name)  # type: ignore
            elif with_age and not graph_name:
                raise ValueError("Graph name is required when with_age is True")

            try:
                if params:
                    rows = await connection.fetch(sql, *params.values())
                else:
                    rows = await connection.fetch(sql)

                if multirows:
                    if rows:
                        columns = [col for col in rows[0].keys()]
                        data = [dict(zip(columns, row)) for row in rows]
                    else:
                        data = []
                else:
                    if rows:
                        columns = rows[0].keys()
                        data = dict(zip(columns, rows[0]))
                    else:
                        data = None

                # query_time = time.time() - start_time
                # logger.info(f"PostgreSQL, Query result len: {len(data)}")
                # logger.info(f"PostgreSQL, Query execution time: {query_time:.4f}s")

                return data
            except Exception as e:
                logger.error(f"PostgreSQL database, error: {e}")
                raise

    async def execute(
        self,
        sql: str,
        data: dict[str, Any] | None = None,
        upsert: bool = False,
        with_age: bool = False,
        graph_name: str | None = None,
    ):
        try:
            async with self.pool.acquire() as connection:  # type: ignore
                if with_age and graph_name:
                    await self.configure_age(connection, graph_name)  # type: ignore
                elif with_age and not graph_name:
                    raise ValueError("Graph name is required when with_age is True")

                if data is None:
                    await connection.execute(sql)  # type: ignore
                else:
                    await connection.execute(sql, *data.values())  # type: ignore
        except (
            asyncpg.exceptions.UniqueViolationError,
            asyncpg.exceptions.DuplicateTableError,
        ) as e:
            if upsert:
                print("Key value duplicate, but upsert succeeded.")
            else:
                logger.error(f"Upsert error: {e}")
        except Exception as e:
            logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
            raise
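
# Note on parameter passing (a sketch of the convention used throughout this file):
# query()/execute() expand the params dict positionally via ``*params.values()``,
# so the insertion order of the dict keys must match the ``$1 .. $n`` placeholders
# in the SQL, e.g.:
#
#     sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id=$2"
#     row = await db.query(sql, {"workspace": db.workspace, "id": doc_id})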


class ClientManager:
    _instances: dict[str, Any] = {"db": None, "ref_count": 0}
    _lock = asyncio.Lock()

    @staticmethod
    def get_config() -> dict[str, Any]:
        config = configparser.ConfigParser()
        config.read("config.ini", "utf-8")

        return {
            "host": os.environ.get(
                "POSTGRES_HOST",
                config.get("postgres", "host", fallback="localhost"),
            ),
            "port": os.environ.get(
                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
            ),
            "user": os.environ.get(
                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
            ),
            "password": os.environ.get(
                "POSTGRES_PASSWORD",
                config.get("postgres", "password", fallback=None),
            ),
            "database": os.environ.get(
                "POSTGRES_DATABASE",
                config.get("postgres", "database", fallback=None),
            ),
            "workspace": os.environ.get(
                "POSTGRES_WORKSPACE",
                config.get("postgres", "workspace", fallback="default"),
            ),
        }

    @classmethod
    async def get_client(cls) -> PostgreSQLDB:
        async with cls._lock:
            if cls._instances["db"] is None:
                config = ClientManager.get_config()
                db = PostgreSQLDB(config)
                await db.initdb()
                await db.check_tables()
                cls._instances["db"] = db
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: PostgreSQLDB):
        async with cls._lock:
            if db is not None:
                if db is cls._instances["db"]:
                    cls._instances["ref_count"] -= 1
                    if cls._instances["ref_count"] == 0:
                        await db.pool.close()
                        logger.info("Closed PostgreSQL database connection pool")
                        cls._instances["db"] = None
                else:
                    await db.pool.close()
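
# Typical lifecycle (a minimal sketch): each PG-backed storage below acquires the
# shared PostgreSQLDB in initialize() via ClientManager.get_client() and returns it
# in finalize() via ClientManager.release_client(); the connection pool is closed
# only when the reference count drops back to zero.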


@final
@dataclass
class PGKVStorage(BaseKVStorage):
    db: PostgreSQLDB = field(default=None)

    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    ################ QUERY METHODS ################

    async def get_all(self) -> dict[str, Any]:
        """Get all data from storage.

        Returns:
            Dictionary containing all stored data
        """
        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for get_all: {self.namespace}")
            return {}

        sql = f"SELECT * FROM {table_name} WHERE workspace=$1"
        params = {"workspace": self.db.workspace}

        try:
            results = await self.db.query(sql, params, multirows=True)

            if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
                result_dict = {}
                for row in results:
                    mode = row["mode"]
                    if mode not in result_dict:
                        result_dict[mode] = {}
                    result_dict[mode][row["id"]] = row
                return result_dict
            else:
                return {row["id"]: row for row in results}
        except Exception as e:
            logger.error(f"Error retrieving all data from {self.namespace}: {e}")
            return {}

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get doc_full data by id."""
        sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
        params = {"workspace": self.db.workspace, "id": id}
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            array_res = await self.db.query(sql, params, multirows=True)
            res = {}
            for row in array_res:
                res[row["id"]] = row
            return res if res else None
        else:
            response = await self.db.query(sql, params)
            return response if response else None

    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
        """Specifically for llm_response_cache."""
        sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
        params = {"workspace": self.db.workspace, "mode": mode, "id": id}
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            array_res = await self.db.query(sql, params, multirows=True)
            res = {}
            for row in array_res:
                res[row["id"]] = row
            return res
        else:
            return None

    # Query by ids
    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get doc_chunks data by ids."""
        sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
            ids=",".join([f"'{id}'" for id in ids])
        )
        params = {"workspace": self.db.workspace}
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            array_res = await self.db.query(sql, params, multirows=True)
            modes = set()
            dict_res: dict[str, dict] = {}
            for row in array_res:
                modes.add(row["mode"])
            for mode in modes:
                if mode not in dict_res:
                    dict_res[mode] = {}
            for row in array_res:
                dict_res[row["mode"]][row["id"]] = row
            return [{k: v} for k, v in dict_res.items()]
        else:
            return await self.db.query(sql, params, multirows=True)

    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
        """Get records filtered by status."""
        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
        params = {"workspace": self.db.workspace, "status": status}
        return await self.db.query(SQL, params, multirows=True)

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Filter out duplicated content."""
        sql = SQL_TEMPLATES["filter_keys"].format(
            table_name=namespace_to_table_name(self.namespace),
            ids=",".join([f"'{id}'" for id in keys]),
        )
        params = {"workspace": self.db.workspace}
        try:
            res = await self.db.query(sql, params, multirows=True)
            if res:
                exist_keys = [key["id"] for key in res]
            else:
                exist_keys = []
            new_keys = set([s for s in keys if s not in exist_keys])
            return new_keys
        except Exception as e:
            logger.error(
                f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
            )
            raise

    ################ INSERT METHODS ################
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
            pass
        elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
            for k, v in data.items():
                upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
                _data = {
                    "id": k,
                    "content": v["content"],
                    "workspace": self.db.workspace,
                }
                await self.db.execute(upsert_sql, _data)
        elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            for mode, items in data.items():
                for k, v in items.items():
                    upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
                    _data = {
                        "workspace": self.db.workspace,
                        "id": k,
                        "original_prompt": v["original_prompt"],
                        "return_value": v["return"],
                        "mode": mode,
                    }
                    await self.db.execute(upsert_sql, _data)

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete specific records from storage by their IDs

        Args:
            ids (list[str]): List of document IDs to be deleted from storage

        Returns:
            None
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for deletion: {self.namespace}")
            return

        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

        try:
            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "ids": ids}
            )
            logger.debug(
                f"Successfully deleted {len(ids)} records from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting records from {self.namespace}: {e}")

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            bool: True if successful, False otherwise
        """
        if not modes:
            return False

        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return False

            if table_name != "LIGHTRAG_LLM_CACHE":
                return False

            sql = f"""
            DELETE FROM {table_name}
            WHERE workspace = $1 AND mode = ANY($2)
            """
            params = {"workspace": self.db.workspace, "modes": modes}

            logger.info(f"Deleting cache by modes: {modes}")
            await self.db.execute(sql, params)
            return True
        except Exception as e:
            logger.error(f"Error deleting cache by modes {modes}: {e}")
            return False

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}


@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
    db: PostgreSQLDB | None = field(default=None)

    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]
        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = config.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    def _upsert_chunks(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        try:
            upsert_sql = SQL_TEMPLATES["upsert_chunk"]
            data: dict[str, Any] = {
                "workspace": self.db.workspace,
                "id": item["__id__"],
                "tokens": item["tokens"],
                "chunk_order_index": item["chunk_order_index"],
                "full_doc_id": item["full_doc_id"],
                "content": item["content"],
                "content_vector": json.dumps(item["__vector__"].tolist()),
                "file_path": item["file_path"],
            }
        except Exception as e:
            logger.error(f"Failed to prepare upsert,\nsql: {e}\nitem: {item}")
            raise

        return upsert_sql, data

    def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        upsert_sql = SQL_TEMPLATES["upsert_entity"]
        source_id = item["source_id"]
        if isinstance(source_id, str) and "<SEP>" in source_id:
            chunk_ids = source_id.split("<SEP>")
        else:
            chunk_ids = [source_id]

        data: dict[str, Any] = {
            "workspace": self.db.workspace,
            "id": item["__id__"],
            "entity_name": item["entity_name"],
            "content": item["content"],
            "content_vector": json.dumps(item["__vector__"].tolist()),
            "chunk_ids": chunk_ids,
            "file_path": item["file_path"],
            # TODO: add document_id
        }
        return upsert_sql, data

    def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        upsert_sql = SQL_TEMPLATES["upsert_relationship"]
        source_id = item["source_id"]
        if isinstance(source_id, str) and "<SEP>" in source_id:
            chunk_ids = source_id.split("<SEP>")
        else:
            chunk_ids = [source_id]

        data: dict[str, Any] = {
            "workspace": self.db.workspace,
            "id": item["__id__"],
            "source_id": item["src_id"],
            "target_id": item["tgt_id"],
            "content": item["content"],
            "content_vector": json.dumps(item["__vector__"].tolist()),
            "chunk_ids": chunk_ids,
            "file_path": item["file_path"],
            # TODO: add document_id
        }
        return upsert_sql, data
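
    # Each _upsert_* helper above maps one in-memory item to an (SQL template,
    # parameter dict) pair that upsert() below feeds to PostgreSQLDB.execute().
    # Illustrative chunk item shape (assumed, for reference only):
    #     {"__id__": "chunk-1", "__vector__": np.array([...]), "tokens": 120,
    #      "chunk_order_index": 0, "full_doc_id": "doc-1",
    #      "content": "...", "file_path": "..."}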

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        current_time = time.time()
        list_data = [
            {
                "__id__": k,
                "__created_at__": current_time,
                **{k1: v1 for k1, v1 in v.items()},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        # Embed all batches concurrently, then attach each vector to its item
        embedding_tasks = [self.embedding_func(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]

        for item in list_data:
            if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
                upsert_sql, data = self._upsert_chunks(item)
            elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
                upsert_sql, data = self._upsert_entities(item)
            elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
                upsert_sql, data = self._upsert_relationships(item)
            else:
                raise ValueError(f"{self.namespace} is not supported")

            await self.db.execute(upsert_sql, data)

    #################### query method ###############
    async def query(
        self, query: str, top_k: int, ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        embeddings = await self.embedding_func([query])
        embedding = embeddings[0]
        embedding_string = ",".join(map(str, embedding))
        # Use parameterized document IDs (None means search across all documents)
        sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
        params = {
            "workspace": self.db.workspace,
            "doc_ids": ids,
            "better_than_threshold": self.cosine_better_than_threshold,
            "top_k": top_k,
        }
        results = await self.db.query(sql, params=params, multirows=True)
        return results

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete vectors with specified IDs from the storage.

        Args:
            ids: List of vector IDs to be deleted
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
            return

        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

        try:
            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "ids": ids}
            )
            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity by its name from the vector storage.

        Args:
            entity_name: The name of the entity to delete
        """
        try:
            # Construct SQL to delete the entity
            delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
                            WHERE workspace=$1 AND entity_name=$2"""

            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
            )
            logger.debug(f"Successfully deleted entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity.

        Args:
            entity_name: The name of the entity whose relations should be deleted
        """
        try:
            # Delete relations where the entity is either the source or target
            delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
                            WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""

            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
            )
            logger.debug(f"Successfully deleted relations for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for entity {entity_name}: {e}")

    async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
        """Search for records with IDs starting with a specific prefix.

        Args:
            prefix: The prefix to search for in record IDs

        Returns:
            List of records with matching ID prefixes
        """
        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for prefix search: {self.namespace}")
            return []

        search_sql = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id LIKE $2"
        params = {"workspace": self.db.workspace, "prefix": f"{prefix}%"}

        try:
            results = await self.db.query(search_sql, params, multirows=True)
            logger.debug(f"Found {len(results)} records with prefix '{prefix}'")

            # Format results to match the expected return format
            formatted_results = []
            for record in results:
                formatted_record = dict(record)
                # Ensure the id field is available (for consistency with the
                # NanoVectorDB implementation)
                if "id" not in formatted_record:
                    formatted_record["id"] = record["id"]
                formatted_results.append(formatted_record)

            return formatted_results
        except Exception as e:
            logger.error(f"Error during prefix search for '{prefix}': {e}")
            return []

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get vector data by its ID.

        Args:
            id: The unique identifier of the vector

        Returns:
            The vector data if found, or None if not found
        """
        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
            return None

        query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id=$2"
        params = {"workspace": self.db.workspace, "id": id}

        try:
            result = await self.db.query(query, params)
            if result:
                return dict(result)
            return None
        except Exception as e:
            logger.error(f"Error retrieving vector data for ID {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple vector data by their IDs.

        Args:
            ids: List of unique identifiers

        Returns:
            List of vector data objects that were found
        """
        if not ids:
            return []

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
            return []

        ids_str = ",".join([f"'{id}'" for id in ids])
        query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
        params = {"workspace": self.db.workspace}

        try:
            results = await self.db.query(query, params, multirows=True)
            return [dict(record) for record in results]
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}


@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
    db: PostgreSQLDB = field(default=None)

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Filter out duplicated content"""
        sql = SQL_TEMPLATES["filter_keys"].format(
            table_name=namespace_to_table_name(self.namespace),
            ids=",".join([f"'{id}'" for id in keys]),
        )
        params = {"workspace": self.db.workspace}
        try:
            res = await self.db.query(sql, params, multirows=True)
            if res:
                exist_keys = [key["id"] for key in res]
            else:
                exist_keys = []
            new_keys = set([s for s in keys if s not in exist_keys])
            print(f"keys: {keys}")
            print(f"new_keys: {new_keys}")
            return new_keys
        except Exception as e:
            logger.error(
                f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
            )
            raise

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
        params = {"workspace": self.db.workspace, "id": id}
        result = await self.db.query(sql, params, True)
        if result is None or result == []:
            return None
        else:
            return dict(
                content=result[0]["content"],
                content_length=result[0]["content_length"],
                content_summary=result[0]["content_summary"],
                status=result[0]["status"],
                chunks_count=result[0]["chunks_count"],
                created_at=result[0]["created_at"],
                updated_at=result[0]["updated_at"],
                file_path=result[0]["file_path"],
            )

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get doc status data by multiple IDs."""
        if not ids:
            return []
        sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
        params = {"workspace": self.db.workspace, "ids": ids}

        results = await self.db.query(sql, params, True)

        if not results:
            return []
        return [
            {
                "content": row["content"],
                "content_length": row["content_length"],
                "content_summary": row["content_summary"],
                "status": row["status"],
                "chunks_count": row["chunks_count"],
                "created_at": row["created_at"],
                "updated_at": row["updated_at"],
                "file_path": row["file_path"],
            }
            for row in results
        ]

    async def get_status_counts(self) -> dict[str, int]:
        """Get counts of documents in each status"""
        sql = """SELECT status as "status", COUNT(1) as "count"
                 FROM LIGHTRAG_DOC_STATUS
                 where workspace=$1 GROUP BY STATUS
                 """
        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
        counts = {}
        for doc in result:
            counts[doc["status"]] = doc["count"]
        return counts

    async def get_docs_by_status(
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents with a specific status."""
        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
        params = {"workspace": self.db.workspace, "status": status.value}
        result = await self.db.query(sql, params, True)
        docs_by_status = {
            element["id"]: DocProcessingStatus(
                content=element["content"],
                content_summary=element["content_summary"],
                content_length=element["content_length"],
                status=element["status"],
                created_at=element["created_at"],
                updated_at=element["updated_at"],
                chunks_count=element["chunks_count"],
                file_path=element["file_path"],
            )
            for element in result
        }
        return docs_by_status

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete specific records from storage by their IDs

        Args:
            ids (list[str]): List of document IDs to be deleted from storage

        Returns:
            None
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for deletion: {self.namespace}")
            return

        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

        try:
            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "ids": ids}
            )
            logger.debug(
                f"Successfully deleted {len(ids)} records from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting records from {self.namespace}: {e}")

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Update or insert document status

        Args:
            data: dictionary of document IDs and their status data
        """
        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path)
                 values($1,$2,$3,$4,$5,$6,$7,$8)
                 on conflict(id,workspace) do update set
                 content = EXCLUDED.content,
                 content_summary = EXCLUDED.content_summary,
                 content_length = EXCLUDED.content_length,
                 chunks_count = EXCLUDED.chunks_count,
                 status = EXCLUDED.status,
                 file_path = EXCLUDED.file_path,
                 updated_at = CURRENT_TIMESTAMP"""
        for k, v in data.items():
            # chunks_count is optional
            await self.db.execute(
                sql,
                {
                    "workspace": self.db.workspace,
                    "id": k,
                    "content": v["content"],
                    "content_summary": v["content_summary"],
                    "content_length": v["content_length"],
                    "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
                    "status": v["status"],
                    "file_path": v["file_path"],
                },
            )

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}


class PGGraphQueryException(Exception):
    """Exception for the AGE queries."""

    def __init__(self, exception: Union[str, dict[str, Any]]) -> None:
        if isinstance(exception, dict):
            self.message = exception["message"] if "message" in exception else "unknown"
            self.details = exception["details"] if "details" in exception else "unknown"
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        return self.message

    def get_details(self) -> Any:
        return self.details


@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
    def __post_init__(self):
        self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
        self.db: PostgreSQLDB | None = None

    @staticmethod
    def _normalize_node_id(node_id: str) -> str:
        """
        Normalize a node ID so that special characters are properly handled in
        Cypher queries.

        Args:
            node_id: The original node ID

        Returns:
            Normalized node ID suitable for Cypher queries
        """
        # Escape backslashes and double quotes
        normalized_id = node_id
        normalized_id = normalized_id.replace("\\", "\\\\")
        normalized_id = normalized_id.replace('"', '\\"')
        return normalized_id

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

        # Execute each statement separately and ignore errors
        # (e.g. when the graph or an index already exists)
        queries = [
            f"SELECT create_graph('{self.graph_name}')",
            f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)',
            f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
            f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)',
            f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)',
            f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)',
            f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)',
            f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)',
            f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)',
            f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)',
            f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)',
            f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
            f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
            f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
            f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
        ]

        for query in queries:
            try:
                await self.db.execute(
                    query,
                    upsert=True,
                    with_age=True,
                    graph_name=self.graph_name,
                )
                logger.info(f"Successfully executed: {query}")
            except Exception:
                continue

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    @staticmethod
    def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
        """
        Convert a record returned from an age query to a dictionary

        Args:
            record (): a record from an age query result

        Returns:
            dict[str, Any]: a dictionary representation of the record where
                the dictionary key is the field name and the value is the
                value converted to a python type
        """
        # result holder
        d = {}

        # prebuild a mapping of vertex_id to vertex mappings to be used
        # later to build edges
        vertices = {}
        for k in record.keys():
            v = record[k]
            # agtype comes back '{key: value}::type' which must be parsed
            if isinstance(v, str) and "::" in v:
                if v.startswith("[") and v.endswith("]"):
                    if "::vertex" not in v:
                        continue
                    v = v.replace("::vertex", "")
                    vertexes = json.loads(v)
                    for vertex in vertexes:
                        vertices[vertex["id"]] = vertex.get("properties")
                else:
                    dtype = v.split("::")[-1]
                    v = v.split("::")[0]
                    if dtype == "vertex":
                        vertex = json.loads(v)
                        vertices[vertex["id"]] = vertex.get("properties")

        # iterate returned fields and parse appropriately
        for k in record.keys():
            v = record[k]
            if isinstance(v, str) and "::" in v:
                if v.startswith("[") and v.endswith("]"):
                    if "::vertex" in v:
                        v = v.replace("::vertex", "")
                        d[k] = json.loads(v)
                    elif "::edge" in v:
                        v = v.replace("::edge", "")
                        d[k] = json.loads(v)
                    else:
                        print("WARNING: unsupported type")
                        continue
                else:
                    dtype = v.split("::")[-1]
                    v = v.split("::")[0]
                    if dtype == "vertex":
                        d[k] = json.loads(v)
                    elif dtype == "edge":
                        d[k] = json.loads(v)
            else:
                try:
                    d[k] = (
                        json.loads(v)
                        if isinstance(v, str)
                        and (v.startswith("{") or v.startswith("["))
                        else v
                    )
                except json.JSONDecodeError:
                    d[k] = v

        return d
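
    # Example of the agtype strings handled above (illustrative, inferred from the
    # parsing logic rather than quoted from the AGE docs):
    #     '{"id": 1, "label": "base", "properties": {"entity_id": "A"}}::vertex'
    # The '::vertex' / '::edge' suffix is stripped and the remainder parsed as JSON;
    # plain scalar columns are returned unchanged.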

    @staticmethod
    def _format_properties(
        properties: dict[str, Any], _id: Union[str, None] = None
    ) -> str:
        """
        Convert a dictionary of properties to a string representation that
        can be used in a cypher query insert/merge statement.

        Args:
            properties (dict[str, str]): a dictionary containing node/edge properties
            _id (Union[str, None]): the id of the node or None if none exists

        Returns:
            str: the properties dictionary as a properly formatted string
        """
        props = []
        # wrap property key in backticks to escape
        for k, v in properties.items():
            prop = f"`{k}`: {json.dumps(v)}"
            props.append(prop)
        if _id is not None and "id" not in properties:
            props.append(
                f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
            )
        return "{" + ", ".join(props) + "}"
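
    # For example (illustrative): _format_properties({"entity_id": "A", "weight": 1})
    # returns the Cypher map literal '{`entity_id`: "A", `weight`: 1}'.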

    async def _query(
        self,
        query: str,
        readonly: bool = True,
        upsert: bool = False,
    ) -> list[dict[str, Any]]:
        """
        Query the graph by taking a cypher query, converting it to an
        age compatible query, executing it and converting the result

        Args:
            query (str): a cypher query to be executed

        Returns:
            list[dict[str, Any]]: a list of dictionaries containing the result set
        """
        try:
            if readonly:
                data = await self.db.query(
                    query,
                    multirows=True,
                    with_age=True,
                    graph_name=self.graph_name,
                )
            else:
                data = await self.db.execute(
                    query,
                    upsert=upsert,
                    with_age=True,
                    graph_name=self.graph_name,
                )
        except Exception as e:
            raise PGGraphQueryException(
                {
                    "message": f"Error executing graph query: {query}",
                    "wrapped": query,
                    "detail": str(e),
                }
            ) from e

        if data is None:
            result = []
        # decode records
        else:
            result = [self._record_to_dict(d) for d in data]

        return result

    async def has_node(self, node_id: str) -> bool:
        entity_name_label = self._normalize_node_id(node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     RETURN count(n) > 0 AS node_exists
                   $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)

        single_result = (await self._query(query))[0]
        return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        src_label = self._normalize_node_id(source_node_id)
        tgt_label = self._normalize_node_id(target_node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
                     RETURN COUNT(r) > 0 AS edge_exists
                   $$) AS (edge_exists bool)""" % (
            self.graph_name,
            src_label,
            tgt_label,
        )

        single_result = (await self._query(query))[0]
        return single_result["edge_exists"]

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Get a node by its label identifier, returning only the node properties"""
        label = self._normalize_node_id(node_id)
        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     RETURN n
                   $$) AS (n agtype)""" % (self.graph_name, label)
        record = await self._query(query)
        if record:
            node = record[0]
            node_dict = node["n"]["properties"]
            return node_dict
        return None

    async def node_degree(self, node_id: str) -> int:
        label = self._normalize_node_id(node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})-[r]-()
                     RETURN count(r) AS total_edge_count
                   $$) AS (total_edge_count integer)""" % (self.graph_name, label)
        record = (await self._query(query))[0]
        if record:
            edge_count = int(record["total_edge_count"])
            return edge_count

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)

        # Convert None to 0 for addition
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree
        degrees = int(src_degree) + int(trg_degree)

        return degrees

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        """Get edge properties between two nodes"""
        src_label = self._normalize_node_id(source_node_id)
        tgt_label = self._normalize_node_id(target_node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
                     RETURN properties(r) as edge_properties
                     LIMIT 1
                   $$) AS (edge_properties agtype)""" % (
            self.graph_name,
            src_label,
            tgt_label,
        )
        record = await self._query(query)
        if record and record[0] and record[0]["edge_properties"]:
            result = record[0]["edge_properties"]
            return result

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        """
        Retrieves all edges (relationships) for a particular node identified by its label.

        :return: list of (source_label, target_label) tuples for edges connected to the node
        """
        label = self._normalize_node_id(source_node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     OPTIONAL MATCH (n)-[]-(connected)
                     RETURN n, connected
                   $$) AS (n agtype, connected agtype)""" % (
            self.graph_name,
            label,
        )
        results = await self._query(query)
        edges = []
        for record in results:
            source_node = record["n"] if record["n"] else None
            connected_node = record["connected"] if record["connected"] else None

            if (
                source_node
                and connected_node
                and "properties" in source_node
                and "properties" in connected_node
            ):
                source_label = source_node["properties"].get("entity_id")
                target_label = connected_node["properties"].get("entity_id")

                if source_label and target_label:
                    edges.append((source_label, target_label))

        return edges

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((PGGraphQueryException,)),
    )
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """
        Upsert a node in the PostgreSQL (Apache AGE) graph.

        Args:
            node_id: The unique identifier for the node (used as label)
            node_data: Dictionary of node properties
        """
        if "entity_id" not in node_data:
            raise ValueError(
                "PostgreSQL: node properties must contain an 'entity_id' field"
            )

        label = self._normalize_node_id(node_id)
        properties = self._format_properties(node_data)

        query = """SELECT * FROM cypher('%s', $$
                     MERGE (n:base {entity_id: "%s"})
                     SET n += %s
                     RETURN n
                   $$) AS (n agtype)""" % (
            self.graph_name,
            label,
            properties,
        )
        try:
            await self._query(query, readonly=False, upsert=True)
        except Exception:
            logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((PGGraphQueryException,)),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Upsert an edge and its properties between two nodes identified by their labels.

        Args:
            source_node_id (str): Label of the source node (used as identifier)
            target_node_id (str): Label of the target node (used as identifier)
            edge_data (dict): dictionary of properties to set on the edge
        """
        src_label = self._normalize_node_id(source_node_id)
        tgt_label = self._normalize_node_id(target_node_id)
        edge_properties = self._format_properties(edge_data)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (source:base {entity_id: "%s"})
                     WITH source
                     MATCH (target:base {entity_id: "%s"})
                     MERGE (source)-[r:DIRECTED]-(target)
                     SET r += %s
                     RETURN r
                   $$) AS (r agtype)""" % (
            self.graph_name,
            src_label,
            tgt_label,
            edge_properties,
        )

        try:
            await self._query(query, readonly=False, upsert=True)
        except Exception:
            logger.error(
                f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
            )
            raise
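
    # Illustrative usage sketch (node names, properties, and the `storage`
    # variable are hypothetical, not part of this module):
    #   await storage.upsert_node("Alice", {"entity_id": "Alice", "description": "a person"})
    #   await storage.upsert_node("Bob", {"entity_id": "Bob"})
    #   await storage.upsert_edge("Alice", "Bob", {"weight": "1.0"})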

    async def delete_node(self, node_id: str) -> None:
        """
        Delete a node from the graph.

        Args:
            node_id (str): The ID of the node to delete.
        """
        label = self._normalize_node_id(node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     DETACH DELETE n
                   $$) AS (n agtype)""" % (self.graph_name, label)

        try:
            await self._query(query, readonly=False)
        except Exception as e:
            logger.error("Error during node deletion: {%s}", e)
            raise

    async def remove_nodes(self, node_ids: list[str]) -> None:
        """
        Remove multiple nodes from the graph.

        Args:
            node_ids (list[str]): A list of node IDs to remove.
        """
        node_ids = [self._normalize_node_id(node_id) for node_id in node_ids]
        node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base)
                     WHERE n.entity_id IN [%s]
                     DETACH DELETE n
                   $$) AS (n agtype)""" % (self.graph_name, node_id_list)

        try:
            await self._query(query, readonly=False)
        except Exception as e:
            logger.error("Error during node removal: {%s}", e)
            raise

    async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
        """
        Remove multiple edges from the graph.

        Args:
            edges (list[tuple[str, str]]): A list of edges to remove, where each edge
                is a tuple of (source_node_id, target_node_id).
        """
        for source, target in edges:
            src_label = self._normalize_node_id(source)
            tgt_label = self._normalize_node_id(target)

            query = """SELECT * FROM cypher('%s', $$
                         MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
                         DELETE r
                       $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)

            try:
                await self._query(query, readonly=False)
                logger.debug(f"Deleted edge from '{source}' to '{target}'")
            except Exception as e:
                logger.error(f"Error during edge deletion: {str(e)}")
                raise

    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
        """
        Retrieve multiple nodes in one query using UNWIND.

        Args:
            node_ids: List of node entity IDs to fetch.

        Returns:
            A dictionary mapping each found node_id to its node data; IDs that do not
            exist in the graph are simply absent from the result.
        """
        if not node_ids:
            return {}

        # Format node IDs for the query
        formatted_ids = ", ".join(
            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
        )

        query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     RETURN node_id, n
                   $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)

        results = await self._query(query)

        # Build result dictionary
        nodes_dict = {}
        for result in results:
            if result["node_id"] and result["n"]:
                node_dict = result["n"]["properties"]
                # Remove the 'base' label if present in a 'labels' property
                if "labels" in node_dict:
                    node_dict["labels"] = [
                        label for label in node_dict["labels"] if label != "base"
                    ]
                nodes_dict[result["node_id"]] = node_dict

        return nodes_dict
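
    # Expected shape (illustrative, with hypothetical IDs):
    #   await storage.get_nodes_batch(["Alice", "Bob"])
    #   -> {"Alice": {"entity_id": "Alice", ...}, "Bob": {"entity_id": "Bob", ...}}
    # IDs that do not exist in the graph are simply missing from the result.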

    async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
        """
        Retrieve the degree for multiple nodes in a single query using UNWIND.
        Calculates the total degree by counting distinct relationships.
        Uses separate queries for outgoing and incoming edges.

        Args:
            node_ids: List of node labels (entity_id values) to look up.

        Returns:
            A dictionary mapping each node_id to its degree (total number of relationships).
            If a node is not found, its degree will be set to 0.
        """
        if not node_ids:
            return {}

        # Format node IDs for the query
        formatted_ids = ", ".join(
            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
        )

        outgoing_query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n)-[r]->(a)
                     RETURN node_id, count(a) AS out_degree
                   $$) AS (node_id text, out_degree bigint)""" % (
            self.graph_name,
            formatted_ids,
        )

        incoming_query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n)<-[r]-(b)
                     RETURN node_id, count(b) AS in_degree
                   $$) AS (node_id text, in_degree bigint)""" % (
            self.graph_name,
            formatted_ids,
        )

        outgoing_results = await self._query(outgoing_query)
        incoming_results = await self._query(incoming_query)

        out_degrees = {}
        in_degrees = {}

        for result in outgoing_results:
            if result["node_id"] is not None:
                out_degrees[result["node_id"]] = int(result["out_degree"])

        for result in incoming_results:
            if result["node_id"] is not None:
                in_degrees[result["node_id"]] = int(result["in_degree"])

        degrees_dict = {}
        for node_id in node_ids:
            out_degree = out_degrees.get(node_id, 0)
            in_degree = in_degrees.get(node_id, 0)
            degrees_dict[node_id] = out_degree + in_degree

        return degrees_dict

    async def edge_degrees_batch(
        self, edges: list[tuple[str, str]]
    ) -> dict[tuple[str, str], int]:
        """
        Calculate the combined degree for each edge (sum of the source and target node degrees)
        in batch using the already implemented node_degrees_batch.

        Args:
            edges: List of (source_node_id, target_node_id) tuples

        Returns:
            Dictionary mapping edge tuples to their combined degrees
        """
        if not edges:
            return {}

        # Use node_degrees_batch to get all node degrees efficiently
        all_nodes = set()
        for src, tgt in edges:
            all_nodes.add(src)
            all_nodes.add(tgt)

        node_degrees = await self.node_degrees_batch(list(all_nodes))

        # Calculate edge degrees
        edge_degrees_dict = {}
        for src, tgt in edges:
            src_degree = node_degrees.get(src, 0)
            tgt_degree = node_degrees.get(tgt, 0)
            edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree

        return edge_degrees_dict
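
    # Illustrative example (hypothetical IDs): if node_degrees_batch reports
    # {"Alice": 3, "Bob": 2}, then edge_degrees_batch([("Alice", "Bob")])
    # returns {("Alice", "Bob"): 5}.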

    async def get_edges_batch(
        self, pairs: list[dict[str, str]]
    ) -> dict[tuple[str, str], dict]:
        """
        Retrieve edge properties for multiple (src, tgt) pairs in one query.
        Get forward and backward edges separately and merge them before returning.

        Args:
            pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]

        Returns:
            A dictionary mapping (src, tgt) tuples to their edge properties.
        """
        if not pairs:
            return {}

        src_nodes = []
        tgt_nodes = []
        for pair in pairs:
            src_nodes.append(self._normalize_node_id(pair["src"]))
            tgt_nodes.append(self._normalize_node_id(pair["tgt"]))

        src_array = ", ".join([f'"{src}"' for src in src_nodes])
        tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])

        forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     WITH [{src_array}] AS sources, [{tgt_array}] AS targets
                     UNWIND range(0, size(sources)-1) AS i
                     MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
                     RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
                   $$) AS (source text, target text, edge_properties agtype)"""

        backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     WITH [{src_array}] AS sources, [{tgt_array}] AS targets
                     UNWIND range(0, size(sources)-1) AS i
                     MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
                     RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
                   $$) AS (source text, target text, edge_properties agtype)"""

        forward_results = await self._query(forward_query)
        backward_results = await self._query(backward_query)

        edges_dict = {}

        for result in forward_results:
            if result["source"] and result["target"] and result["edge_properties"]:
                edges_dict[(result["source"], result["target"])] = result[
                    "edge_properties"
                ]

        for result in backward_results:
            if result["source"] and result["target"] and result["edge_properties"]:
                edges_dict[(result["source"], result["target"])] = result[
                    "edge_properties"
                ]

        return edges_dict

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> dict[str, list[tuple[str, str]]]:
        """
        Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.

        Args:
            node_ids: List of node IDs to get edges for

        Returns:
            Dictionary mapping node IDs to lists of (source, target) edge tuples
        """
        if not node_ids:
            return {}

        # Format node IDs for the query
        formatted_ids = ", ".join(
            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
        )

        outgoing_query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n:base)-[]->(connected:base)
                     RETURN node_id, connected.entity_id AS connected_id
                   $$) AS (node_id text, connected_id text)""" % (
            self.graph_name,
            formatted_ids,
        )

        incoming_query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n:base)<-[]-(connected:base)
                     RETURN node_id, connected.entity_id AS connected_id
                   $$) AS (node_id text, connected_id text)""" % (
            self.graph_name,
            formatted_ids,
        )

        outgoing_results = await self._query(outgoing_query)
        incoming_results = await self._query(incoming_query)

        nodes_edges_dict = {node_id: [] for node_id in node_ids}

        for result in outgoing_results:
            if result["node_id"] and result["connected_id"]:
                nodes_edges_dict[result["node_id"]].append(
                    (result["node_id"], result["connected_id"])
                )

        for result in incoming_results:
            if result["node_id"] and result["connected_id"]:
                nodes_edges_dict[result["node_id"]].append(
                    (result["connected_id"], result["node_id"])
                )

        return nodes_edges_dict

    async def get_all_labels(self) -> list[str]:
        """
        Get all labels (node IDs) in the graph.

        Returns:
            list[str]: A list of all labels in the graph.
        """
        query = (
            """SELECT * FROM cypher('%s', $$
                 MATCH (n:base)
                 WHERE n.entity_id IS NOT NULL
                 RETURN DISTINCT n.entity_id AS label
                 ORDER BY n.entity_id
               $$) AS (label text)"""
            % self.graph_name
        )

        results = await self._query(query)

        labels = []
        for result in results:
            if result and isinstance(result, dict) and "label" in result:
                labels.append(result["label"])
        return labels

    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.

        Args:
            node_label: Label of the starting node, * means all nodes
            max_depth: Maximum depth of the subgraph, defaults to 3
            max_nodes: Maximum number of nodes to return, defaults to 1000
                (neither BFS nor DFS traversal order is guaranteed)

        Returns:
            KnowledgeGraph object containing nodes and edges, with an is_truncated flag
            indicating whether the graph was truncated due to max_nodes limit
        """
        # First, count the total number of nodes that would be returned without limit
        if node_label == "*":
            count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (n:base)
                     RETURN count(distinct n) AS total_nodes
                   $$) AS (total_nodes bigint)"""
        else:
            strip_label = self._normalize_node_id(node_label)
            count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (n:base {{entity_id: "{strip_label}"}})-[r]-()
                     RETURN count(r) AS total_nodes
                   $$) AS (total_nodes bigint)"""

        count_result = await self._query(count_query)
        total_nodes = count_result[0]["total_nodes"] if count_result else 0
        is_truncated = total_nodes > max_nodes

        # Now get the actual data with limit
        if node_label == "*":
            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (node:base)
                     OPTIONAL MATCH (node)-[r]->()
                     RETURN collect(distinct node) AS n, collect(distinct r) AS r
                     LIMIT {max_nodes}
                   $$) AS (n agtype, r agtype)"""
        else:
            strip_label = self._normalize_node_id(node_label)
            if total_nodes > 0:
                query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (node:base {{entity_id: "{strip_label}"}})
                     OPTIONAL MATCH p = (node)-[*..{max_depth}]-()
                     RETURN nodes(p) AS n, relationships(p) AS r
                     LIMIT {max_nodes}
                   $$) AS (n agtype, r agtype)"""
            else:
                query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (node:base {{entity_id: "{strip_label}"}})
                     RETURN node AS n
                   $$) AS (n agtype)"""

        results = await self._query(query)

        # Process the query results with deduplication by node and edge IDs
        nodes_dict = {}
        edges_dict = {}
        for result in results:
            # Handle single node cases
            if result.get("n") and isinstance(result["n"], dict):
                node_id = str(result["n"]["id"])
                if node_id not in nodes_dict:
                    nodes_dict[node_id] = KnowledgeGraphNode(
                        id=node_id,
                        labels=[result["n"]["properties"]["entity_id"]],
                        properties=result["n"]["properties"],
                    )
            # Handle node list cases
            elif result.get("n") and isinstance(result["n"], list):
                for node in result["n"]:
                    if isinstance(node, dict) and "id" in node:
                        node_id = str(node["id"])
                        if node_id not in nodes_dict and "properties" in node:
                            nodes_dict[node_id] = KnowledgeGraphNode(
                                id=node_id,
                                labels=[node["properties"]["entity_id"]],
                                properties=node["properties"],
                            )

            # Handle single edge cases
            if result.get("r") and isinstance(result["r"], dict):
                edge_id = str(result["r"]["id"])
                if edge_id not in edges_dict:
                    edges_dict[edge_id] = KnowledgeGraphEdge(
                        id=edge_id,
                        type="DIRECTED",
                        source=str(result["r"]["start_id"]),
                        target=str(result["r"]["end_id"]),
                        properties=result["r"]["properties"],
                    )
            # Handle edge list cases
            elif result.get("r") and isinstance(result["r"], list):
                for edge in result["r"]:
                    if isinstance(edge, dict) and "id" in edge:
                        edge_id = str(edge["id"])
                        if edge_id not in edges_dict:
                            edges_dict[edge_id] = KnowledgeGraphEdge(
                                id=edge_id,
                                type="DIRECTED",
                                source=str(edge["start_id"]),
                                target=str(edge["end_id"]),
                                properties=edge["properties"],
                            )

        # Construct and return the KnowledgeGraph with deduplicated nodes and edges
        kg = KnowledgeGraph(
            nodes=list(nodes_dict.values()),
            edges=list(edges_dict.values()),
            is_truncated=is_truncated,
        )

        logger.info(
            f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
        )
        return kg
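
    # Illustrative sketch (hypothetical label; `storage` is an instance of this
    # class): fetch a 2-hop neighbourhood and check whether it was capped.
    #   kg = await storage.get_knowledge_graph("Alice", max_depth=2)
    #   if kg.is_truncated:
    #       logger.warning("Subgraph truncated to %d nodes", len(kg.nodes))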

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     MATCH (n)
                     DETACH DELETE n
                   $$) AS (result agtype)"""

            await self._query(drop_query, readonly=False)
            return {"status": "success", "message": "graph data dropped"}
        except Exception as e:
            logger.error(f"Error dropping graph: {e}")
            return {"status": "error", "message": str(e)}


NAMESPACE_TABLE_MAP = {
    NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
    NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
    NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
    NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
    NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
    NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
    NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
}


def namespace_to_table_name(namespace: str) -> str:
    for k, v in NAMESPACE_TABLE_MAP.items():
        if is_namespace(namespace, k):
            return v


TABLES = {
    "LIGHTRAG_DOC_FULL": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    doc_name VARCHAR(1024),
                    content TEXT,
                    meta JSONB,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_DOC_CHUNKS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    full_doc_id VARCHAR(256),
                    chunk_order_index INTEGER,
                    tokens INTEGER,
                    content TEXT,
                    content_vector VECTOR,
                    file_path VARCHAR(256),
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_ENTITY": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    entity_name VARCHAR(255),
                    content TEXT,
                    content_vector VECTOR,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP,
                    chunk_ids VARCHAR(255)[] NULL,
                    file_path TEXT NULL,
                    CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_RELATION": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    source_id VARCHAR(256),
                    target_id VARCHAR(256),
                    content TEXT,
                    content_vector VECTOR,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP,
                    chunk_ids VARCHAR(255)[] NULL,
                    file_path TEXT NULL,
                    CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_LLM_CACHE": {
        "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
                    mode varchar(32) NOT NULL,
                    original_prompt TEXT,
                    return_value TEXT,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP,
                    CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
                    )"""
    },
    "LIGHTRAG_DOC_STATUS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
                    content TEXT NULL,
                    content_summary varchar(255) NULL,
                    content_length int4 NULL,
                    chunks_count int4 NULL,
                    status varchar(64) NULL,
                    file_path TEXT NULL,
                    created_at timestamp DEFAULT CURRENT_TIMESTAMP NULL,
                    updated_at timestamp DEFAULT CURRENT_TIMESTAMP NULL,
                    CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
}
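
# Illustrative sketch (assumes `db` is an initialized PostgreSQLDB instance):
# the DDL strings above can also be applied manually through the asyncpg pool, e.g.
#   for table in TABLES.values():
#       await db.pool.execute(table["ddl"])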


SQL_TEMPLATES = {
    # SQL for KVStorage
    "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
                                FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
                             """,
    "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id, file_path
                                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
                             """,
    "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
                                """,
    "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
                                """,
    "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
                                FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
                            """,
    "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                 chunk_order_index, full_doc_id, file_path
                                 FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
                              """,
    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
                                """,
    "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
    "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
                          VALUES ($1, $2, $3)
                          ON CONFLICT (workspace, id) DO UPDATE
                          SET content = $2, update_time = CURRENT_TIMESTAMP
                       """,
    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode)
                                    VALUES ($1, $2, $3, $4, $5)
                                    ON CONFLICT (workspace, mode, id) DO UPDATE
                                    SET original_prompt = EXCLUDED.original_prompt,
                                        return_value = EXCLUDED.return_value,
                                        mode = EXCLUDED.mode,
                                        update_time = CURRENT_TIMESTAMP
                                 """,
    "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
                       chunk_order_index, full_doc_id, content, content_vector, file_path)
                       VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                       ON CONFLICT (workspace, id) DO UPDATE
                       SET tokens = EXCLUDED.tokens,
                           chunk_order_index = EXCLUDED.chunk_order_index,
                           full_doc_id = EXCLUDED.full_doc_id,
                           content = EXCLUDED.content,
                           content_vector = EXCLUDED.content_vector,
                           file_path = EXCLUDED.file_path,
                           update_time = CURRENT_TIMESTAMP
                    """,
    # SQL for VectorStorage
    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
                        content_vector, chunk_ids, file_path)
                        VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
                        ON CONFLICT (workspace, id) DO UPDATE
                        SET entity_name = EXCLUDED.entity_name,
                            content = EXCLUDED.content,
                            content_vector = EXCLUDED.content_vector,
                            chunk_ids = EXCLUDED.chunk_ids,
                            file_path = EXCLUDED.file_path,
                            update_time = CURRENT_TIMESTAMP
                     """,
    "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
                              target_id, content, content_vector, chunk_ids, file_path)
                              VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8)
                              ON CONFLICT (workspace, id) DO UPDATE
                              SET source_id = EXCLUDED.source_id,
                                  target_id = EXCLUDED.target_id,
                                  content = EXCLUDED.content,
                                  content_vector = EXCLUDED.content_vector,
                                  chunk_ids = EXCLUDED.chunk_ids,
                                  file_path = EXCLUDED.file_path,
                                  update_time = CURRENT_TIMESTAMP
                           """,
    "relationships": """
        WITH relevant_chunks AS (
            SELECT id as chunk_id
            FROM LIGHTRAG_DOC_CHUNKS
            WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
        )
        SELECT source_id as src_id, target_id as tgt_id
        FROM (
            SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
            FROM LIGHTRAG_VDB_RELATION r
            JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids)
            WHERE r.workspace = $1
        ) filtered
        WHERE distance > $3
        ORDER BY distance DESC
        LIMIT $4
    """,
    "entities": """
        WITH relevant_chunks AS (
            SELECT id as chunk_id
            FROM LIGHTRAG_DOC_CHUNKS
            WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
        )
        SELECT entity_name FROM
        (
            SELECT e.id, e.entity_name, 1 - (e.content_vector <=> '[{embedding_string}]'::vector) as distance
            FROM LIGHTRAG_VDB_ENTITY e
            JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids)
            WHERE e.workspace = $1
        ) as chunk_distances
        WHERE distance > $3
        ORDER BY distance DESC
        LIMIT $4
    """,
    "chunks": """
        WITH relevant_chunks AS (
            SELECT id as chunk_id
            FROM LIGHTRAG_DOC_CHUNKS
            WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
        )
        SELECT id, content, file_path FROM
        (
            SELECT id, content, file_path, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
            FROM LIGHTRAG_DOC_CHUNKS
            WHERE workspace = $1
            AND id IN (SELECT chunk_id FROM relevant_chunks)
        ) as chunk_distances
        WHERE distance > $3
        ORDER BY distance DESC
        LIMIT $4
    """,
    # DROP tables
    "drop_specifiy_table_workspace": """
        DELETE FROM {table_name} WHERE workspace = $1
    """,
}
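
# Illustrative sketch (hypothetical variables): the "{ids}" / "{table_name}"
# placeholders appear to be filled with str.format before execution, while
# $1, $2, ... are bound by asyncpg at query time, e.g.
#   sql = SQL_TEMPLATES["filter_keys"].format(
#       table_name="LIGHTRAG_DOC_FULL", ids="'id-1','id-2'"
#   )
#   rows = await db.pool.fetch(sql, workspace)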