import asyncio
import json
import os
import datetime
from datetime import timezone
from dataclasses import dataclass, field
from typing import Any, Union, final
import numpy as np
import configparser

from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from ..base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    DocProcessingStatus,
    DocStatus,
    DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger

import pipmaster as pm

if not pm.is_installed("asyncpg"):
    pm.install("asyncpg")

import asyncpg  # type: ignore
from asyncpg import Pool  # type: ignore

from dotenv import load_dotenv

# Load the .env file from the current working directory so that each LightRAG
# instance can use its own .env file; OS environment variables take precedence
# over values defined in .env.
load_dotenv(dotenv_path=".env", override=False)

# Maximum number of graph nodes returned by graph queries (default: 1000)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
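
# Illustrative .env overrides (not exhaustive; the names follow the lookups above
# and in ClientManager.get_config below):
#   POSTGRES_HOST=db.internal
#   POSTGRES_PORT=5432
#   MAX_GRAPH_NODES=2000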


class PostgreSQLDB:
    def __init__(self, config: dict[str, Any], **kwargs: Any):
        self.host = config.get("host", "localhost")
        self.port = config.get("port", 5432)
        self.user = config.get("user", "postgres")
        self.password = config.get("password", None)
        self.database = config.get("database", "postgres")
        self.workspace = config.get("workspace", "default")
        self.max = 12
        self.increment = 1
        self.pool: Pool | None = None

        if self.user is None or self.password is None or self.database is None:
            raise ValueError("Missing database user, password, or database")

    async def initdb(self):
        try:
            self.pool = await asyncpg.create_pool(  # type: ignore
                user=self.user,
                password=self.password,
                database=self.database,
                host=self.host,
                port=self.port,
                min_size=1,
                max_size=self.max,
            )
            logger.info(
                f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
            )
        except Exception as e:
            logger.error(
                f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got: {e}"
            )
            raise

    @staticmethod
    async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
        """Set up the Apache AGE environment and create the graph if it does not exist.

        This method:
        - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that
          Apache AGE functions can be used without specifying the schema.
        - Attempts to create a new graph with the provided `graph_name` if it does
          not already exist.
        - Silently ignores errors related to the graph already existing.
        """
        try:
            await connection.execute(  # type: ignore
                'SET search_path = ag_catalog, "$user", public'
            )
            await connection.execute(  # type: ignore
                f"select create_graph('{graph_name}')"
            )
        except (
            asyncpg.exceptions.InvalidSchemaNameError,
            asyncpg.exceptions.UniqueViolationError,
        ):
            pass
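
    # Illustrative only: once configure_age() has run on a connection, AGE's cypher()
    # function resolves without the ag_catalog prefix, e.g.
    #   SELECT * FROM cypher('my_graph', $$ MATCH (n:base) RETURN count(n) $$) AS (cnt agtype);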

    async def _migrate_timestamp_columns(self):
        """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time"""
        # Tables and columns that need migration
        tables_to_migrate = {
            "LIGHTRAG_VDB_ENTITY": ["create_time", "update_time"],
            "LIGHTRAG_VDB_RELATION": ["create_time", "update_time"],
            "LIGHTRAG_DOC_CHUNKS": ["create_time", "update_time"],
        }

        for table_name, columns in tables_to_migrate.items():
            for column_name in columns:
                try:
                    # Check if the column exists
                    check_column_sql = f"""
                    SELECT column_name, data_type
                    FROM information_schema.columns
                    WHERE table_name = '{table_name.lower()}'
                    AND column_name = '{column_name}'
                    """

                    column_info = await self.query(check_column_sql)
                    if not column_info:
                        logger.warning(
                            f"Column {table_name}.{column_name} does not exist, skipping migration"
                        )
                        continue

                    # Check the column type
                    data_type = column_info.get("data_type")
                    if data_type == "timestamp with time zone":
                        logger.info(
                            f"Column {table_name}.{column_name} is already timezone-aware, no migration needed"
                        )
                        continue

                    # Execute the migration, explicitly interpreting the original data as UTC
                    logger.info(
                        f"Migrating {table_name}.{column_name} to timezone-aware type"
                    )
                    migration_sql = f"""
                    ALTER TABLE {table_name}
                    ALTER COLUMN {column_name} TYPE TIMESTAMP(0) WITH TIME ZONE
                    USING {column_name} AT TIME ZONE 'UTC'
                    """

                    await self.execute(migration_sql)
                    logger.info(
                        f"Successfully migrated {table_name}.{column_name} to timezone-aware type"
                    )
                except Exception as e:
                    # Log the error but don't interrupt the process
                    logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}")

    async def check_tables(self):
        # First create all tables
        for k, v in TABLES.items():
            try:
                await self.query(f"SELECT 1 FROM {k} LIMIT 1")
            except Exception:
                try:
                    logger.info(f"PostgreSQL, Try Creating table {k} in database")
                    await self.execute(v["ddl"])
                    logger.info(
                        f"PostgreSQL, Creation success table {k} in PostgreSQL database"
                    )
                except Exception as e:
                    logger.error(
                        f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
                    )
                    raise e

            # Create an index on the id column of each table
            try:
                index_name = f"idx_{k.lower()}_id"
                check_index_sql = f"""
                SELECT 1 FROM pg_indexes
                WHERE indexname = '{index_name}'
                AND tablename = '{k.lower()}'
                """
                index_exists = await self.query(check_index_sql)

                if not index_exists:
                    create_index_sql = f"CREATE INDEX {index_name} ON {k} (id)"
                    logger.info(f"PostgreSQL, Creating index {index_name} on table {k}")
                    await self.execute(create_index_sql)
            except Exception as e:
                logger.error(
                    f"PostgreSQL, Failed to create index on table {k}, Got: {e}"
                )

        # After all tables are created, attempt to migrate timestamp fields
        try:
            await self._migrate_timestamp_columns()
        except Exception as e:
            logger.error(f"PostgreSQL, Failed to migrate timestamp columns: {e}")
            # Don't raise; allow the initialization process to continue

    async def query(
        self,
        sql: str,
        params: dict[str, Any] | None = None,
        multirows: bool = False,
        with_age: bool = False,
        graph_name: str | None = None,
    ) -> dict[str, Any] | None | list[dict[str, Any]]:
        # start_time = time.time()
        # logger.info(f"PostgreSQL, Querying:\n{sql}")

        async with self.pool.acquire() as connection:  # type: ignore
            if with_age and graph_name:
                await self.configure_age(connection, graph_name)  # type: ignore
            elif with_age and not graph_name:
                raise ValueError("Graph name is required when with_age is True")

            try:
                if params:
                    rows = await connection.fetch(sql, *params.values())
                else:
                    rows = await connection.fetch(sql)

                if multirows:
                    if rows:
                        columns = [col for col in rows[0].keys()]
                        data = [dict(zip(columns, row)) for row in rows]
                    else:
                        data = []
                else:
                    if rows:
                        columns = rows[0].keys()
                        data = dict(zip(columns, rows[0]))
                    else:
                        data = None

                # query_time = time.time() - start_time
                # logger.info(f"PostgreSQL, Query result len: {len(data)}")
                # logger.info(f"PostgreSQL, Query execution time: {query_time:.4f}s")

                return data
            except Exception as e:
                logger.error(f"PostgreSQL database, error: {e}")
                raise

    async def execute(
        self,
        sql: str,
        data: dict[str, Any] | None = None,
        upsert: bool = False,
        with_age: bool = False,
        graph_name: str | None = None,
    ):
        try:
            async with self.pool.acquire() as connection:  # type: ignore
                if with_age and graph_name:
                    await self.configure_age(connection, graph_name)  # type: ignore
                elif with_age and not graph_name:
                    raise ValueError("Graph name is required when with_age is True")

                if data is None:
                    await connection.execute(sql)  # type: ignore
                else:
                    await connection.execute(sql, *data.values())  # type: ignore
        except (
            asyncpg.exceptions.UniqueViolationError,
            asyncpg.exceptions.DuplicateTableError,
        ) as e:
            if upsert:
                print("Key value duplicate, but upsert succeeded.")
            else:
                logger.error(f"Upsert error: {e}")
        except Exception as e:
            logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
            raise
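
# Illustrative usage sketch (assumed caller context, not part of the storage API):
# parameters are bound positionally to $1, $2, ... from the dict's value order.
#
#   db = PostgreSQLDB({"user": "postgres", "password": "secret", "database": "postgres"})
#   await db.initdb()
#   await db.check_tables()
#   rows = await db.query(
#       "SELECT id, status FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1",
#       params={"workspace": db.workspace},
#       multirows=True,
#   )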


class ClientManager:
    _instances: dict[str, Any] = {"db": None, "ref_count": 0}
    _lock = asyncio.Lock()

    @staticmethod
    def get_config() -> dict[str, Any]:
        config = configparser.ConfigParser()
        config.read("config.ini", "utf-8")

        return {
            "host": os.environ.get(
                "POSTGRES_HOST",
                config.get("postgres", "host", fallback="localhost"),
            ),
            "port": os.environ.get(
                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
            ),
            "user": os.environ.get(
                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
            ),
            "password": os.environ.get(
                "POSTGRES_PASSWORD",
                config.get("postgres", "password", fallback=None),
            ),
            "database": os.environ.get(
                "POSTGRES_DATABASE",
                config.get("postgres", "database", fallback=None),
            ),
            "workspace": os.environ.get(
                "POSTGRES_WORKSPACE",
                config.get("postgres", "workspace", fallback="default"),
            ),
        }

    @classmethod
    async def get_client(cls) -> PostgreSQLDB:
        async with cls._lock:
            if cls._instances["db"] is None:
                config = ClientManager.get_config()
                db = PostgreSQLDB(config)
                await db.initdb()
                await db.check_tables()
                cls._instances["db"] = db
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: PostgreSQLDB):
        async with cls._lock:
            if db is not None:
                if db is cls._instances["db"]:
                    cls._instances["ref_count"] -= 1
                    if cls._instances["ref_count"] == 0:
                        await db.pool.close()
                        logger.info("Closed PostgreSQL database connection pool")
                        cls._instances["db"] = None
                else:
                    await db.pool.close()
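
# Illustrative only: storages share one pooled PostgreSQLDB via reference counting.
#
#   db = await ClientManager.get_client()      # creates the pool on first use
#   ...                                        # use db.query() / db.execute()
#   await ClientManager.release_client(db)     # pool is closed when the count drops to 0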


@final
@dataclass
class PGKVStorage(BaseKVStorage):
    db: PostgreSQLDB = field(default=None)

    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    ################ QUERY METHODS ################

    async def get_all(self) -> dict[str, Any]:
        """Get all data from storage

        Returns:
            Dictionary containing all stored data
        """
        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for get_all: {self.namespace}")
            return {}

        sql = f"SELECT * FROM {table_name} WHERE workspace=$1"
        params = {"workspace": self.db.workspace}

        try:
            results = await self.db.query(sql, params, multirows=True)

            if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
                result_dict = {}
                for row in results:
                    mode = row["mode"]
                    if mode not in result_dict:
                        result_dict[mode] = {}
                    result_dict[mode][row["id"]] = row
                return result_dict
            else:
                return {row["id"]: row for row in results}
        except Exception as e:
            logger.error(f"Error retrieving all data from {self.namespace}: {e}")
            return {}

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get doc_full data by id."""
        sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
        params = {"workspace": self.db.workspace, "id": id}
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            array_res = await self.db.query(sql, params, multirows=True)
            res = {}
            for row in array_res:
                res[row["id"]] = row
            return res if res else None
        else:
            response = await self.db.query(sql, params)
            return response if response else None

    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
        """Specifically for llm_response_cache."""
        sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
        params = {"workspace": self.db.workspace, "mode": mode, "id": id}
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            array_res = await self.db.query(sql, params, multirows=True)
            res = {}
            for row in array_res:
                res[row["id"]] = row
            return res
        else:
            return None

    # Query by id
    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get doc_chunks data by id"""
        sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
            ids=",".join([f"'{id}'" for id in ids])
        )
        params = {"workspace": self.db.workspace}
        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            array_res = await self.db.query(sql, params, multirows=True)
            modes = set()
            dict_res: dict[str, dict] = {}
            for row in array_res:
                modes.add(row["mode"])
            for mode in modes:
                if mode not in dict_res:
                    dict_res[mode] = {}
            for row in array_res:
                dict_res[row["mode"]][row["id"]] = row
            return [{k: v} for k, v in dict_res.items()]
        else:
            return await self.db.query(sql, params, multirows=True)

    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
        """Specifically for llm_response_cache."""
        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
        params = {"workspace": self.db.workspace, "status": status}
        return await self.db.query(SQL, params, multirows=True)

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Filter out duplicated content"""
        sql = SQL_TEMPLATES["filter_keys"].format(
            table_name=namespace_to_table_name(self.namespace),
            ids=",".join([f"'{id}'" for id in keys]),
        )
        params = {"workspace": self.db.workspace}
        try:
            res = await self.db.query(sql, params, multirows=True)
            if res:
                exist_keys = [key["id"] for key in res]
            else:
                exist_keys = []
            new_keys = set([s for s in keys if s not in exist_keys])
            return new_keys
        except Exception as e:
            logger.error(
                f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
            )
            raise

    ################ INSERT METHODS ################
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
            pass
        elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
            for k, v in data.items():
                upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
                _data = {
                    "id": k,
                    "content": v["content"],
                    "workspace": self.db.workspace,
                }
                await self.db.execute(upsert_sql, _data)
        elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            for mode, items in data.items():
                for k, v in items.items():
                    upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
                    _data = {
                        "workspace": self.db.workspace,
                        "id": k,
                        "original_prompt": v["original_prompt"],
                        "return_value": v["return"],
                        "mode": mode,
                    }
                    await self.db.execute(upsert_sql, _data)

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete specific records from storage by their IDs

        Args:
            ids (list[str]): List of document IDs to be deleted from storage

        Returns:
            None
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for deletion: {self.namespace}")
            return

        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

        try:
            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "ids": ids}
            )
            logger.debug(
                f"Successfully deleted {len(ids)} records from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting records from {self.namespace}: {e}")

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            bool: True if successful, False otherwise
        """
        if not modes:
            return False

        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return False

            if table_name != "LIGHTRAG_LLM_CACHE":
                return False

            sql = f"""
            DELETE FROM {table_name}
            WHERE workspace = $1 AND mode = ANY($2)
            """
            params = {"workspace": self.db.workspace, "modes": modes}

            logger.info(f"Deleting cache by modes: {modes}")
            await self.db.execute(sql, params)
            return True
        except Exception as e:
            logger.error(f"Error deleting cache by modes {modes}: {e}")
            return False

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}


@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
    db: PostgreSQLDB | None = field(default=None)

    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]
        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = config.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    def _upsert_chunks(
        self, item: dict[str, Any], current_time: datetime.datetime
    ) -> tuple[str, dict[str, Any]]:
        try:
            upsert_sql = SQL_TEMPLATES["upsert_chunk"]
            data: dict[str, Any] = {
                "workspace": self.db.workspace,
                "id": item["__id__"],
                "tokens": item["tokens"],
                "chunk_order_index": item["chunk_order_index"],
                "full_doc_id": item["full_doc_id"],
                "content": item["content"],
                "content_vector": json.dumps(item["__vector__"].tolist()),
                "file_path": item["file_path"],
                "create_time": current_time,
                "update_time": current_time,
            }
        except Exception as e:
            logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}")
            raise

        return upsert_sql, data

    def _upsert_entities(
        self, item: dict[str, Any], current_time: datetime.datetime
    ) -> tuple[str, dict[str, Any]]:
        upsert_sql = SQL_TEMPLATES["upsert_entity"]
        source_id = item["source_id"]
        if isinstance(source_id, str) and "<SEP>" in source_id:
            chunk_ids = source_id.split("<SEP>")
        else:
            chunk_ids = [source_id]

        data: dict[str, Any] = {
            "workspace": self.db.workspace,
            "id": item["__id__"],
            "entity_name": item["entity_name"],
            "content": item["content"],
            "content_vector": json.dumps(item["__vector__"].tolist()),
            "chunk_ids": chunk_ids,
            "file_path": item.get("file_path", None),
            "create_time": current_time,
            "update_time": current_time,
        }
        return upsert_sql, data

    def _upsert_relationships(
        self, item: dict[str, Any], current_time: datetime.datetime
    ) -> tuple[str, dict[str, Any]]:
        upsert_sql = SQL_TEMPLATES["upsert_relationship"]
        source_id = item["source_id"]
        if isinstance(source_id, str) and "<SEP>" in source_id:
            chunk_ids = source_id.split("<SEP>")
        else:
            chunk_ids = [source_id]

        data: dict[str, Any] = {
            "workspace": self.db.workspace,
            "id": item["__id__"],
            "source_id": item["src_id"],
            "target_id": item["tgt_id"],
            "content": item["content"],
            "content_vector": json.dumps(item["__vector__"].tolist()),
            "chunk_ids": chunk_ids,
            "file_path": item.get("file_path", None),
            "create_time": current_time,
            "update_time": current_time,
        }
        return upsert_sql, data

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        # Get the current time as a timezone-aware UTC datetime
        current_time = datetime.datetime.now(timezone.utc)
        list_data = [
            {
                "__id__": k,
                **{k1: v1 for k1, v1 in v.items()},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        embedding_tasks = [self.embedding_func(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]

        for item in list_data:
            if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
                upsert_sql, data = self._upsert_chunks(item, current_time)
            elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
                upsert_sql, data = self._upsert_entities(item, current_time)
            elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
                upsert_sql, data = self._upsert_relationships(item, current_time)
            else:
                raise ValueError(f"{self.namespace} is not supported")

            await self.db.execute(upsert_sql, data)

    #################### query method ###############
    async def query(
        self, query: str, top_k: int, ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        embeddings = await self.embedding_func(
            [query], _priority=5
        )  # higher priority for query
        embedding = embeddings[0]
        embedding_string = ",".join(map(str, embedding))

        # Use parameterized document IDs (None means search across all documents)
        sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
        params = {
            "workspace": self.db.workspace,
            "doc_ids": ids,
            "better_than_threshold": self.cosine_better_than_threshold,
            "top_k": top_k,
        }
        results = await self.db.query(sql, params=params, multirows=True)
        return results
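
    # Illustrative call (assumed caller context): fetch the top-5 matches above the
    # configured cosine threshold, optionally restricted to specific document IDs.
    #   hits = await chunks_vdb.query("What is Apache AGE?", top_k=5, ids=None)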

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete vectors with specified IDs from the storage.

        Args:
            ids: List of vector IDs to be deleted
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
            return

        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

        try:
            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "ids": ids}
            )
            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity by its name from the vector storage.

        Args:
            entity_name: The name of the entity to delete
        """
        try:
            # Construct SQL to delete the entity
            delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
                            WHERE workspace=$1 AND entity_name=$2"""

            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
            )
            logger.debug(f"Successfully deleted entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity.

        Args:
            entity_name: The name of the entity whose relations should be deleted
        """
        try:
            # Delete relations where the entity is either the source or target
            delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
                            WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""

            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
            )
            logger.debug(f"Successfully deleted relations for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for entity {entity_name}: {e}")

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get vector data by its ID

        Args:
            id: The unique identifier of the vector

        Returns:
            The vector data if found, or None if not found
        """
        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
            return None

        query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2"
        params = {"workspace": self.db.workspace, "id": id}

        try:
            result = await self.db.query(query, params)
            if result:
                return dict(result)
            return None
        except Exception as e:
            logger.error(f"Error retrieving vector data for ID {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple vector data by their IDs

        Args:
            ids: List of unique identifiers

        Returns:
            List of vector data objects that were found
        """
        if not ids:
            return []

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
            return []

        ids_str = ",".join([f"'{id}'" for id in ids])
        query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
        params = {"workspace": self.db.workspace}

        try:
            results = await self.db.query(query, params, multirows=True)
            return [dict(record) for record in results]
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}


@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
    db: PostgreSQLDB = field(default=None)

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Filter out duplicated content"""
        sql = SQL_TEMPLATES["filter_keys"].format(
            table_name=namespace_to_table_name(self.namespace),
            ids=",".join([f"'{id}'" for id in keys]),
        )
        params = {"workspace": self.db.workspace}
        try:
            res = await self.db.query(sql, params, multirows=True)
            if res:
                exist_keys = [key["id"] for key in res]
            else:
                exist_keys = []
            new_keys = set([s for s in keys if s not in exist_keys])
            print(f"keys: {keys}")
            print(f"new_keys: {new_keys}")
            return new_keys
        except Exception as e:
            logger.error(
                f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
            )
            raise

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
        params = {"workspace": self.db.workspace, "id": id}
        result = await self.db.query(sql, params, True)
        if result is None or result == []:
            return None
        else:
            return dict(
                content=result[0]["content"],
                content_length=result[0]["content_length"],
                content_summary=result[0]["content_summary"],
                status=result[0]["status"],
                chunks_count=result[0]["chunks_count"],
                created_at=result[0]["created_at"],
                updated_at=result[0]["updated_at"],
                file_path=result[0]["file_path"],
            )

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get doc_chunks data by multiple IDs."""
        if not ids:
            return []

        sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
        params = {"workspace": self.db.workspace, "ids": ids}

        results = await self.db.query(sql, params, True)

        if not results:
            return []
        return [
            {
                "content": row["content"],
                "content_length": row["content_length"],
                "content_summary": row["content_summary"],
                "status": row["status"],
                "chunks_count": row["chunks_count"],
                "created_at": row["created_at"],
                "updated_at": row["updated_at"],
                "file_path": row["file_path"],
            }
            for row in results
        ]

    async def get_status_counts(self) -> dict[str, int]:
        """Get counts of documents in each status"""
        sql = """SELECT status as "status", COUNT(1) as "count"
                 FROM LIGHTRAG_DOC_STATUS
                 where workspace=$1 GROUP BY STATUS
              """
        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
        counts = {}
        for doc in result:
            counts[doc["status"]] = doc["count"]
        return counts

    async def get_docs_by_status(
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents with a specific status"""
        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
        params = {"workspace": self.db.workspace, "status": status.value}
        result = await self.db.query(sql, params, True)
        docs_by_status = {
            element["id"]: DocProcessingStatus(
                content=element["content"],
                content_summary=element["content_summary"],
                content_length=element["content_length"],
                status=element["status"],
                created_at=element["created_at"],
                updated_at=element["updated_at"],
                chunks_count=element["chunks_count"],
                file_path=element["file_path"],
            )
            for element in result
        }
        return docs_by_status

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete specific records from storage by their IDs

        Args:
            ids (list[str]): List of document IDs to be deleted from storage

        Returns:
            None
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(f"Unknown namespace for deletion: {self.namespace}")
            return

        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"

        try:
            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "ids": ids}
            )
            logger.debug(
                f"Successfully deleted {len(ids)} records from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting records from {self.namespace}: {e}")

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Update or insert document status

        Args:
            data: dictionary of document IDs and their status data
        """
        logger.debug(f"Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        def parse_datetime(dt_str):
            if dt_str is None:
                return None
            if isinstance(dt_str, (datetime.date, datetime.datetime)):
                # If it's already a datetime object, strip any timezone info
                if isinstance(dt_str, datetime.datetime):
                    # Return a naive datetime object
                    return dt_str.replace(tzinfo=None)
                return dt_str
            try:
                # Process an ISO-format string, which may carry a timezone
                dt = datetime.datetime.fromisoformat(dt_str)
                # Remove timezone info and return a naive datetime object
                return dt.replace(tzinfo=None)
            except (ValueError, TypeError):
                logger.warning(f"Unable to parse datetime string: {dt_str}")
                return None

        # SQL includes created_at and updated_at in both INSERT and UPDATE operations;
        # both fields are taken from the input data in either case
        sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,created_at,updated_at)
                 values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
                 on conflict(id,workspace) do update set
                 content = EXCLUDED.content,
                 content_summary = EXCLUDED.content_summary,
                 content_length = EXCLUDED.content_length,
                 chunks_count = EXCLUDED.chunks_count,
                 status = EXCLUDED.status,
                 file_path = EXCLUDED.file_path,
                 created_at = EXCLUDED.created_at,
                 updated_at = EXCLUDED.updated_at"""
        for k, v in data.items():
            # Remove timezone information; store UTC time in the database
            created_at = parse_datetime(v.get("created_at"))
            updated_at = parse_datetime(v.get("updated_at"))

            # chunks_count is optional
            await self.db.execute(
                sql,
                {
                    "workspace": self.db.workspace,
                    "id": k,
                    "content": v["content"],
                    "content_summary": v["content_summary"],
                    "content_length": v["content_length"],
                    "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
                    "status": v["status"],
                    "file_path": v["file_path"],
                    "created_at": created_at,  # Use the converted datetime object
                    "updated_at": updated_at,  # Use the converted datetime object
                },
            )
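
    # Illustrative input shape for upsert() (assumed from the fields used above):
    #   {"doc-1": {"content": "...", "content_summary": "...", "content_length": 120,
    #              "status": DocStatus.PROCESSED, "file_path": "a.txt",
    #              "created_at": "2025-05-01T02:04:17+00:00", "updated_at": None}}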

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}


class PGGraphQueryException(Exception):
    """Exception for the AGE queries."""

    def __init__(self, exception: Union[str, dict[str, Any]]) -> None:
        if isinstance(exception, dict):
            self.message = exception["message"] if "message" in exception else "unknown"
            self.details = exception["details"] if "details" in exception else "unknown"
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        return self.message

    def get_details(self) -> Any:
        return self.details


@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
    def __post_init__(self):
        self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
        self.db: PostgreSQLDB | None = None

    @staticmethod
    def _normalize_node_id(node_id: str) -> str:
        """
        Normalize a node ID so that special characters are properly escaped in Cypher queries.

        Args:
            node_id: The original node ID

        Returns:
            Normalized node ID suitable for Cypher queries
        """
        # Escape backslashes and double quotes
        normalized_id = node_id
        normalized_id = normalized_id.replace("\\", "\\\\")
        normalized_id = normalized_id.replace('"', '\\"')
        return normalized_id
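
    # Illustrative only (assumed input): a node ID such as  He said "hi"  becomes
    #   He said \"hi\"
    # so it can be embedded safely inside the double-quoted Cypher literals used below.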

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

        # Execute each statement separately and ignore errors
        queries = [
            f"SELECT create_graph('{self.graph_name}')",
            f"SELECT create_vlabel('{self.graph_name}', 'base');",
            f"SELECT create_elabel('{self.graph_name}', 'DIRECTED');",
            # f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)',
            f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
            # f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)',
            f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)',
            f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)',
            f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)',
            f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)',
            f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)',
            f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)',
            f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)',
            f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
            f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
            f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
            f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
        ]

        for query in queries:
            try:
                await self.db.execute(
                    query,
                    upsert=True,
                    with_age=True,
                    graph_name=self.graph_name,
                )
                logger.info(f"Successfully executed: {query}")
            except Exception:
                continue

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    @staticmethod
    def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
        """
        Convert a record returned from an AGE query to a dictionary

        Args:
            record (): a record from an AGE query result

        Returns:
            dict[str, Any]: a dictionary representation of the record where
                the dictionary key is the field name and the value is the
                value converted to a python type
        """
        # result holder
        d = {}

        # prebuild a mapping of vertex_id to vertex mappings to be used
        # later to build edges
        vertices = {}
        for k in record.keys():
            v = record[k]
            # agtype comes back '{key: value}::type' which must be parsed
            if isinstance(v, str) and "::" in v:
                if v.startswith("[") and v.endswith("]"):
                    if "::vertex" not in v:
                        continue
                    v = v.replace("::vertex", "")
                    vertexes = json.loads(v)
                    for vertex in vertexes:
                        vertices[vertex["id"]] = vertex.get("properties")
                else:
                    dtype = v.split("::")[-1]
                    v = v.split("::")[0]
                    if dtype == "vertex":
                        vertex = json.loads(v)
                        vertices[vertex["id"]] = vertex.get("properties")

        # iterate returned fields and parse appropriately
        for k in record.keys():
            v = record[k]
            if isinstance(v, str) and "::" in v:
                if v.startswith("[") and v.endswith("]"):
                    if "::vertex" in v:
                        v = v.replace("::vertex", "")
                        d[k] = json.loads(v)
                    elif "::edge" in v:
                        v = v.replace("::edge", "")
                        d[k] = json.loads(v)
                    else:
                        print("WARNING: unsupported type")
                        continue
                else:
                    dtype = v.split("::")[-1]
                    v = v.split("::")[0]
                    if dtype == "vertex":
                        d[k] = json.loads(v)
                    elif dtype == "edge":
                        d[k] = json.loads(v)
            else:
                d[k] = v  # Keep as string

        return d
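
    # Illustrative only: an agtype column value such as
    #   '{"id": 844424930131969, "label": "base", "properties": {"entity_id": "Alice"}}::vertex'
    # is parsed here into the plain dict {"id": ..., "label": "base", "properties": {...}}.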

    @staticmethod
    def _format_properties(
        properties: dict[str, Any], _id: Union[str, None] = None
    ) -> str:
        """
        Convert a dictionary of properties to a string representation that
        can be used in a cypher query insert/merge statement.

        Args:
            properties (dict[str, str]): a dictionary containing node/edge properties
            _id (Union[str, None]): the id of the node or None if none exists

        Returns:
            str: the properties dictionary as a properly formatted string
        """
        props = []
        # wrap property key in backticks to escape
        for k, v in properties.items():
            prop = f"`{k}`: {json.dumps(v)}"
            props.append(prop)
        if _id is not None and "id" not in properties:
            props.append(
                f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
            )
        return "{" + ", ".join(props) + "}"

    async def _query(
        self,
        query: str,
        readonly: bool = True,
        upsert: bool = False,
    ) -> list[dict[str, Any]]:
        """
        Query the graph by taking a cypher query, converting it to an
        AGE compatible query, executing it and converting the result

        Args:
            query (str): a cypher query to be executed

        Returns:
            list[dict[str, Any]]: a list of dictionaries containing the result set
        """
        try:
            if readonly:
                data = await self.db.query(
                    query,
                    multirows=True,
                    with_age=True,
                    graph_name=self.graph_name,
                )
            else:
                data = await self.db.execute(
                    query,
                    upsert=upsert,
                    with_age=True,
                    graph_name=self.graph_name,
                )

        except Exception as e:
            raise PGGraphQueryException(
                {
                    "message": f"Error executing graph query: {query}",
                    "wrapped": query,
                    "detail": str(e),
                }
            ) from e

        if data is None:
            result = []
        # decode records
        else:
            result = [self._record_to_dict(d) for d in data]

        return result

    async def has_node(self, node_id: str) -> bool:
        entity_name_label = self._normalize_node_id(node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     RETURN count(n) > 0 AS node_exists
                   $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)

        single_result = (await self._query(query))[0]

        return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        src_label = self._normalize_node_id(source_node_id)
        tgt_label = self._normalize_node_id(target_node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
                     RETURN COUNT(r) > 0 AS edge_exists
                   $$) AS (edge_exists bool)""" % (
            self.graph_name,
            src_label,
            tgt_label,
        )

        single_result = (await self._query(query))[0]

        return single_result["edge_exists"]

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Get node by its label identifier, return only node properties"""
        label = self._normalize_node_id(node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     RETURN n
                   $$) AS (n agtype)""" % (self.graph_name, label)

        record = await self._query(query)
        if record:
            node = record[0]
            node_dict = node["n"]["properties"]
            return node_dict
        return None

    async def node_degree(self, node_id: str) -> int:
        label = self._normalize_node_id(node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})-[r]-()
                     RETURN count(r) AS total_edge_count
                   $$) AS (total_edge_count integer)""" % (self.graph_name, label)

        record = (await self._query(query))[0]
        if record:
            edge_count = int(record["total_edge_count"])
            return edge_count

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)

        # Convert None to 0 for addition
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree

        degrees = int(src_degree) + int(trg_degree)
        return degrees

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        """Get edge properties between two nodes"""
        src_label = self._normalize_node_id(source_node_id)
        tgt_label = self._normalize_node_id(target_node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
                     RETURN properties(r) as edge_properties
                     LIMIT 1
                   $$) AS (edge_properties agtype)""" % (
            self.graph_name,
            src_label,
            tgt_label,
        )

        record = await self._query(query)
        if record and record[0] and record[0]["edge_properties"]:
            result = record[0]["edge_properties"]
            return result

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        """
        Retrieves all edges (relationships) for a particular node identified by its label.

        :return: list of (source_entity_id, target_entity_id) tuples for the connected edges
        """
        label = self._normalize_node_id(source_node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     OPTIONAL MATCH (n)-[]-(connected)
                     RETURN n, connected
                   $$) AS (n agtype, connected agtype)""" % (
            self.graph_name,
            label,
        )

        results = await self._query(query)
        edges = []
        for record in results:
            source_node = record["n"] if record["n"] else None
            connected_node = record["connected"] if record["connected"] else None

            if (
                source_node
                and connected_node
                and "properties" in source_node
                and "properties" in connected_node
            ):
                source_label = source_node["properties"].get("entity_id")
                target_label = connected_node["properties"].get("entity_id")

                if source_label and target_label:
                    edges.append((source_label, target_label))

        return edges

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((PGGraphQueryException,)),
    )
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """
        Upsert a node in the PostgreSQL (Apache AGE) graph.

        Args:
            node_id: The unique identifier for the node (used as label)
            node_data: Dictionary of node properties
        """
        if "entity_id" not in node_data:
            raise ValueError(
                "PostgreSQL: node properties must contain an 'entity_id' field"
            )

        label = self._normalize_node_id(node_id)
        properties = self._format_properties(node_data)

        query = """SELECT * FROM cypher('%s', $$
                     MERGE (n:base {entity_id: "%s"})
                     SET n += %s
                     RETURN n
                   $$) AS (n agtype)""" % (
            self.graph_name,
            label,
            properties,
        )

        try:
            await self._query(query, readonly=False, upsert=True)
        except Exception:
            logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((PGGraphQueryException,)),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Upsert an edge and its properties between two nodes identified by their labels.

        Args:
            source_node_id (str): Label of the source node (used as identifier)
            target_node_id (str): Label of the target node (used as identifier)
            edge_data (dict): dictionary of properties to set on the edge
        """
        src_label = self._normalize_node_id(source_node_id)
        tgt_label = self._normalize_node_id(target_node_id)
        edge_properties = self._format_properties(edge_data)

        # The edge properties are applied twice as a workaround for an AGE
        # property-merge issue; see the linked GitHub comment below.
        query = """SELECT * FROM cypher('%s', $$
                     MATCH (source:base {entity_id: "%s"})
                     WITH source
                     MATCH (target:base {entity_id: "%s"})
                     MERGE (source)-[r:DIRECTED]-(target)
                     SET r += %s
                     SET r += %s
                     RETURN r
                   $$) AS (r agtype)""" % (
            self.graph_name,
            src_label,
            tgt_label,
            edge_properties,
            edge_properties,  # https://github.com/HKUDS/LightRAG/issues/1438#issuecomment-2826000195
        )

        try:
            await self._query(query, readonly=False, upsert=True)
        except Exception:
            logger.error(
                f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
            )
            raise
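
    # For reference, the rendered query for upsert_edge looks roughly like:
    #   SELECT * FROM cypher('graph', $$
    #     MATCH (source:base {entity_id: "A"}) WITH source
    #     MATCH (target:base {entity_id: "B"})
    #     MERGE (source)-[r:DIRECTED]-(target)
    #     SET r += {weight: 1.0}
    #     SET r += {weight: 1.0}
    #     RETURN r
    #   $$) AS (r agtype)
    # The graph name, entity ids and property values shown here are placeholders,
    # not actual data from this module.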

    async def delete_node(self, node_id: str) -> None:
        """
        Delete a node from the graph.

        Args:
            node_id (str): The ID of the node to delete.
        """
        label = self._normalize_node_id(node_id)

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     DETACH DELETE n
                   $$) AS (n agtype)""" % (self.graph_name, label)

        try:
            await self._query(query, readonly=False)
        except Exception as e:
            logger.error("Error during node deletion: %s", e)
            raise

    async def remove_nodes(self, node_ids: list[str]) -> None:
        """
        Remove multiple nodes from the graph.

        Args:
            node_ids (list[str]): A list of node IDs to remove.
        """
        node_ids = [self._normalize_node_id(node_id) for node_id in node_ids]
        node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])

        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base)
                     WHERE n.entity_id IN [%s]
                     DETACH DELETE n
                   $$) AS (n agtype)""" % (self.graph_name, node_id_list)

        try:
            await self._query(query, readonly=False)
        except Exception as e:
            logger.error("Error during node removal: %s", e)
            raise

    async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
        """
        Remove multiple edges from the graph.

        Args:
            edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
        """
        for source, target in edges:
            src_label = self._normalize_node_id(source)
            tgt_label = self._normalize_node_id(target)

            query = """SELECT * FROM cypher('%s', $$
                         MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
                         DELETE r
                       $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)

            try:
                await self._query(query, readonly=False)
                logger.debug(f"Deleted edge from '{source}' to '{target}'")
            except Exception as e:
                logger.error(f"Error during edge deletion: {str(e)}")
                raise

    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
        """
        Retrieve multiple nodes in one query using UNWIND.

        Args:
            node_ids: List of node entity IDs to fetch.

        Returns:
            A dictionary mapping each node_id to its node data (or None if not found).
        """
        if not node_ids:
            return {}

        # Format node IDs for the query
        formatted_ids = ", ".join(
            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
        )

        query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     RETURN node_id, n
                   $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)

        results = await self._query(query)

        # Build result dictionary
        nodes_dict = {}
        for result in results:
            if result["node_id"] and result["n"]:
                node_dict = result["n"]["properties"]
                # Remove the 'base' label if present in a 'labels' property
                if "labels" in node_dict:
                    node_dict["labels"] = [
                        label for label in node_dict["labels"] if label != "base"
                    ]
                nodes_dict[result["node_id"]] = node_dict

        return nodes_dict

    async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
        """
        Retrieve the degree for multiple nodes in a single query using UNWIND.
        Calculates the total degree by counting distinct relationships.
        Uses separate queries for outgoing and incoming edges.

        Args:
            node_ids: List of node labels (entity_id values) to look up.

        Returns:
            A dictionary mapping each node_id to its degree (total number of relationships).
            If a node is not found, its degree will be set to 0.
        """
        if not node_ids:
            return {}

        # Format node IDs for the query
        formatted_ids = ", ".join(
            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
        )

        outgoing_query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n)-[r]->(a)
                     RETURN node_id, count(a) AS out_degree
                   $$) AS (node_id text, out_degree bigint)""" % (
            self.graph_name,
            formatted_ids,
        )

        incoming_query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n)<-[r]-(b)
                     RETURN node_id, count(b) AS in_degree
                   $$) AS (node_id text, in_degree bigint)""" % (
            self.graph_name,
            formatted_ids,
        )

        outgoing_results = await self._query(outgoing_query)
        incoming_results = await self._query(incoming_query)

        out_degrees = {}
        in_degrees = {}

        for result in outgoing_results:
            if result["node_id"] is not None:
                out_degrees[result["node_id"]] = int(result["out_degree"])

        for result in incoming_results:
            if result["node_id"] is not None:
                in_degrees[result["node_id"]] = int(result["in_degree"])

        degrees_dict = {}
        for node_id in node_ids:
            out_degree = out_degrees.get(node_id, 0)
            in_degree = in_degrees.get(node_id, 0)
            degrees_dict[node_id] = out_degree + in_degree

        return degrees_dict
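
    # Example (sketch): given two entity_ids, the two UNWIND queries above return
    # per-node out/in counts which are summed here, e.g.
    #   degrees = await storage.node_degrees_batch(["Apple Inc.", "Tim Cook"])
    #   # -> {"Apple Inc.": 12, "Tim Cook": 5}
    # The variable name `storage` and the counts are illustrative assumptions.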

    async def edge_degrees_batch(
        self, edges: list[tuple[str, str]]
    ) -> dict[tuple[str, str], int]:
        """
        Calculate the combined degree for each edge (sum of the source and target node degrees)
        in batch using the already implemented node_degrees_batch.

        Args:
            edges: List of (source_node_id, target_node_id) tuples

        Returns:
            Dictionary mapping edge tuples to their combined degrees
        """
        if not edges:
            return {}

        # Use node_degrees_batch to get all node degrees efficiently
        all_nodes = set()
        for src, tgt in edges:
            all_nodes.add(src)
            all_nodes.add(tgt)

        node_degrees = await self.node_degrees_batch(list(all_nodes))

        # Calculate edge degrees
        edge_degrees_dict = {}
        for src, tgt in edges:
            src_degree = node_degrees.get(src, 0)
            tgt_degree = node_degrees.get(tgt, 0)
            edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree

        return edge_degrees_dict

    async def get_edges_batch(
        self, pairs: list[dict[str, str]]
    ) -> dict[tuple[str, str], dict]:
        """
        Retrieve edge properties for multiple (src, tgt) pairs in one query.
        Get forward and backward edges separately and merge them before returning.

        Args:
            pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]

        Returns:
            A dictionary mapping (src, tgt) tuples to their edge properties.
        """
        if not pairs:
            return {}

        src_nodes = []
        tgt_nodes = []
        for pair in pairs:
            src_nodes.append(self._normalize_node_id(pair["src"]))
            tgt_nodes.append(self._normalize_node_id(pair["tgt"]))

        src_array = ", ".join([f'"{src}"' for src in src_nodes])
        tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])

        forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     WITH [{src_array}] AS sources, [{tgt_array}] AS targets
                     UNWIND range(0, size(sources)-1) AS i
                     MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
                     RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
                   $$) AS (source text, target text, edge_properties agtype)"""

        backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                     WITH [{src_array}] AS sources, [{tgt_array}] AS targets
                     UNWIND range(0, size(sources)-1) AS i
                     MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
                     RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
                   $$) AS (source text, target text, edge_properties agtype)"""

        forward_results = await self._query(forward_query)
        backward_results = await self._query(backward_query)

        edges_dict = {}

        for result in forward_results:
            if result["source"] and result["target"] and result["edge_properties"]:
                edges_dict[(result["source"], result["target"])] = result[
                    "edge_properties"
                ]

        for result in backward_results:
            if result["source"] and result["target"] and result["edge_properties"]:
                edges_dict[(result["source"], result["target"])] = result[
                    "edge_properties"
                ]

        return edges_dict
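
    # Example (sketch): pairs go in as dicts and come back keyed by tuple, e.g.
    #   edges = await storage.get_edges_batch([{"src": "Apple Inc.", "tgt": "Tim Cook"}])
    #   # -> {("Apple Inc.", "Tim Cook"): {"weight": 1.0, ...}} when an edge exists
    # `storage` and the property values are illustrative, not fixed by this module.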

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> dict[str, list[tuple[str, str]]]:
        """
        Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.

        Args:
            node_ids: List of node IDs to get edges for

        Returns:
            Dictionary mapping node IDs to lists of (source, target) edge tuples
        """
        if not node_ids:
            return {}

        # Format node IDs for the query
        formatted_ids = ", ".join(
            ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids]
        )

        outgoing_query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n:base)-[]->(connected:base)
                     RETURN node_id, connected.entity_id AS connected_id
                   $$) AS (node_id text, connected_id text)""" % (
            self.graph_name,
            formatted_ids,
        )

        incoming_query = """SELECT * FROM cypher('%s', $$
                     UNWIND [%s] AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n:base)<-[]-(connected:base)
                     RETURN node_id, connected.entity_id AS connected_id
                   $$) AS (node_id text, connected_id text)""" % (
            self.graph_name,
            formatted_ids,
        )

        outgoing_results = await self._query(outgoing_query)
        incoming_results = await self._query(incoming_query)

        nodes_edges_dict = {node_id: [] for node_id in node_ids}

        for result in outgoing_results:
            if result["node_id"] and result["connected_id"]:
                nodes_edges_dict[result["node_id"]].append(
                    (result["node_id"], result["connected_id"])
                )

        for result in incoming_results:
            if result["node_id"] and result["connected_id"]:
                nodes_edges_dict[result["node_id"]].append(
                    (result["connected_id"], result["node_id"])
                )

        return nodes_edges_dict

    async def get_all_labels(self) -> list[str]:
        """
        Get all labels (node IDs) in the graph.

        Returns:
            list[str]: A list of all labels in the graph.
        """
        query = (
            """SELECT * FROM cypher('%s', $$
                 MATCH (n:base)
                 WHERE n.entity_id IS NOT NULL
                 RETURN DISTINCT n.entity_id AS label
                 ORDER BY n.entity_id
               $$) AS (label text)"""
            % self.graph_name
        )

        results = await self._query(query)

        labels = []
        for result in results:
            if result and isinstance(result, dict) and "label" in result:
                labels.append(result["label"])

        return labels

    async def _bfs_subgraph(
        self, node_label: str, max_depth: int, max_nodes: int
    ) -> KnowledgeGraph:
        """
        Implements a true breadth-first search algorithm for subgraph retrieval.
        This method is used as a fallback when the standard Cypher query is too slow
        or when we need to guarantee BFS ordering.

        Args:
            node_label: Label of the starting node
            max_depth: Maximum depth of the subgraph
            max_nodes: Maximum number of nodes to return

        Returns:
            KnowledgeGraph object containing nodes and edges
        """
        from collections import deque

        result = KnowledgeGraph()
        visited_nodes = set()
        visited_node_ids = set()
        visited_edges = set()
        visited_edge_pairs = set()

        # Get starting node data
        label = self._normalize_node_id(node_label)
        query = """SELECT * FROM cypher('%s', $$
                     MATCH (n:base {entity_id: "%s"})
                     RETURN id(n) as node_id, n
                   $$) AS (node_id bigint, n agtype)""" % (self.graph_name, label)

        node_result = await self._query(query)
        if not node_result or not node_result[0].get("n"):
            return result

        # Create initial KnowledgeGraphNode
        start_node_data = node_result[0]["n"]
        entity_id = start_node_data["properties"]["entity_id"]
        internal_id = str(start_node_data["id"])
        start_node = KnowledgeGraphNode(
            id=internal_id,
            labels=[entity_id],
            properties=start_node_data["properties"],
        )

        # Initialize BFS queue, each element is a tuple of (node, depth)
        queue = deque([(start_node, 0)])
        visited_nodes.add(entity_id)
        visited_node_ids.add(internal_id)
        result.nodes.append(start_node)
        result.is_truncated = False

        # BFS search main loop
        while queue:
            # Get all nodes at the current depth
            current_level_nodes = []
            current_depth = None

            # Determine current depth
            if queue:
                current_depth = queue[0][1]

            # Extract all nodes at current depth from the queue
            while queue and queue[0][1] == current_depth:
                node, depth = queue.popleft()
                if depth > max_depth:
                    continue
                current_level_nodes.append(node)

            if not current_level_nodes:
                continue

            # Check depth limit
            if current_depth > max_depth:
                continue

            # Prepare node IDs list
            node_ids = [node.labels[0] for node in current_level_nodes]
            formatted_ids = ", ".join(
                [f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
            )

            # Construct batch query for outgoing edges
            outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                UNWIND [{formatted_ids}] AS node_id
                MATCH (n:base {{entity_id: node_id}})
                OPTIONAL MATCH (n)-[r]->(neighbor:base)
                RETURN node_id AS current_id,
                       id(n) AS current_internal_id,
                       id(neighbor) AS neighbor_internal_id,
                       neighbor.entity_id AS neighbor_id,
                       id(r) AS edge_id,
                       r,
                       neighbor,
                       true AS is_outgoing
            $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
                    neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""

            # Construct batch query for incoming edges
            incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                UNWIND [{formatted_ids}] AS node_id
                MATCH (n:base {{entity_id: node_id}})
                OPTIONAL MATCH (n)<-[r]-(neighbor:base)
                RETURN node_id AS current_id,
                       id(n) AS current_internal_id,
                       id(neighbor) AS neighbor_internal_id,
                       neighbor.entity_id AS neighbor_id,
                       id(r) AS edge_id,
                       r,
                       neighbor,
                       false AS is_outgoing
            $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
                    neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""

            # Execute queries
            outgoing_results = await self._query(outgoing_query)
            incoming_results = await self._query(incoming_query)

            # Combine results
            neighbors = outgoing_results + incoming_results

            # Create mapping from node ID to node object
            node_map = {node.labels[0]: node for node in current_level_nodes}

            # Process all results in a single loop
            for record in neighbors:
                if not record.get("neighbor") or not record.get("r"):
                    continue

                # Get current node information
                current_entity_id = record["current_id"]
                current_node = node_map[current_entity_id]

                # Get neighbor node information
                neighbor_entity_id = record["neighbor_id"]
                neighbor_internal_id = str(record["neighbor_internal_id"])
                is_outgoing = record["is_outgoing"]

                # Determine edge direction
                if is_outgoing:
                    source_id = current_node.id
                    target_id = neighbor_internal_id
                else:
                    source_id = neighbor_internal_id
                    target_id = current_node.id

                if not neighbor_entity_id:
                    continue

                # Get edge and node information
                b_node = record["neighbor"]
                rel = record["r"]
                edge_id = str(record["edge_id"])

                # Create neighbor node object
                neighbor_node = KnowledgeGraphNode(
                    id=neighbor_internal_id,
                    labels=[neighbor_entity_id],
                    properties=b_node["properties"],
                )

                # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
                sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))

                # Create edge object
                edge = KnowledgeGraphEdge(
                    id=edge_id,
                    type=rel["label"],
                    source=source_id,
                    target=target_id,
                    properties=rel["properties"],
                )

                if neighbor_internal_id in visited_node_ids:
                    # Add backward edge if neighbor node is already visited
                    if (
                        edge_id not in visited_edges
                        and sorted_pair not in visited_edge_pairs
                    ):
                        result.edges.append(edge)
                        visited_edges.add(edge_id)
                        visited_edge_pairs.add(sorted_pair)
                else:
                    if len(visited_node_ids) < max_nodes and current_depth < max_depth:
                        # Add new node to result and queue
                        result.nodes.append(neighbor_node)
                        visited_nodes.add(neighbor_entity_id)
                        visited_node_ids.add(neighbor_internal_id)

                        # Add node to queue with incremented depth
                        queue.append((neighbor_node, current_depth + 1))

                        # Add forward edge
                        if (
                            edge_id not in visited_edges
                            and sorted_pair not in visited_edge_pairs
                        ):
                            result.edges.append(edge)
                            visited_edges.add(edge_id)
                            visited_edge_pairs.add(sorted_pair)
                    else:
                        if current_depth < max_depth:
                            result.is_truncated = True

        return result
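
    # Complexity note: the BFS above issues two cypher() round-trips per depth
    # level (one for outgoing and one for incoming edges of the whole level),
    # so the number of queries grows with `max_depth`, not with the number of
    # nodes on each level.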

    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.

        Args:
            node_label: Label of the starting node, * means all nodes
            max_depth: Maximum depth of the subgraph, defaults to 3
            max_nodes: Maximum number of nodes to return, defaults to 1000

        Returns:
            KnowledgeGraph object containing nodes and edges, with an is_truncated flag
            indicating whether the graph was truncated due to max_nodes limit
        """
        kg = KnowledgeGraph()

        # Handle wildcard query - get all nodes
        if node_label == "*":
            # First check total node count to determine if graph should be truncated
            count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                MATCH (n:base)
                RETURN count(distinct n) AS total_nodes
            $$) AS (total_nodes bigint)"""

            count_result = await self._query(count_query)
            total_nodes = count_result[0]["total_nodes"] if count_result else 0
            is_truncated = total_nodes > max_nodes

            # Get the max_nodes nodes with the highest degrees
            query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
                MATCH (n:base)
                OPTIONAL MATCH (n)-[r]->()
                RETURN id(n) as node_id, count(r) as degree
            $$) AS (node_id BIGINT, degree BIGINT)
            ORDER BY degree DESC
            LIMIT {max_nodes}"""
            node_results = await self._query(query_nodes)

            node_ids = [str(result["node_id"]) for result in node_results]
            logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")

            if node_ids:
                formatted_ids = ", ".join(node_ids)

                # Construct batch query for subgraph within max_nodes
                query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                    WITH [{formatted_ids}] AS node_ids
                    MATCH (a)
                    WHERE id(a) IN node_ids
                    OPTIONAL MATCH (a)-[r]->(b)
                    WHERE id(b) IN node_ids
                    RETURN a, r, b
                $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
                results = await self._query(query)

                # Process query results, deduplicate nodes and edges
                nodes_dict = {}
                edges_dict = {}
                for result in results:
                    # Process node a
                    if result.get("a") and isinstance(result["a"], dict):
                        node_a = result["a"]
                        node_id = str(node_a["id"])
                        if node_id not in nodes_dict and "properties" in node_a:
                            nodes_dict[node_id] = KnowledgeGraphNode(
                                id=node_id,
                                labels=[node_a["properties"]["entity_id"]],
                                properties=node_a["properties"],
                            )

                    # Process node b
                    if result.get("b") and isinstance(result["b"], dict):
                        node_b = result["b"]
                        node_id = str(node_b["id"])
                        if node_id not in nodes_dict and "properties" in node_b:
                            nodes_dict[node_id] = KnowledgeGraphNode(
                                id=node_id,
                                labels=[node_b["properties"]["entity_id"]],
                                properties=node_b["properties"],
                            )

                    # Process edge r
                    if result.get("r") and isinstance(result["r"], dict):
                        edge = result["r"]
                        edge_id = str(edge["id"])
                        if edge_id not in edges_dict:
                            edges_dict[edge_id] = KnowledgeGraphEdge(
                                id=edge_id,
                                type=edge["label"],
                                source=str(edge["start_id"]),
                                target=str(edge["end_id"]),
                                properties=edge["properties"],
                            )

                kg = KnowledgeGraph(
                    nodes=list(nodes_dict.values()),
                    edges=list(edges_dict.values()),
                    is_truncated=is_truncated,
                )
            else:
                # No nodes matched the wildcard query; fall back to the BFS helper
                kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)

            logger.info(
                f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
            )
        else:
            # For non-wildcard queries, use the BFS algorithm
            kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
            logger.info(
                f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
            )

        return kg
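
    # Example (sketch): a wildcard query returns the highest-degree nodes up to
    # the cap, while a concrete label starts a BFS from that node, e.g.
    #   kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=100)
    #   kg = await storage.get_knowledge_graph("Apple Inc.")
    # `storage` and the label value are illustrative assumptions.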

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                MATCH (n)
                DETACH DELETE n
            $$) AS (result agtype)"""

            await self._query(drop_query, readonly=False)
            return {"status": "success", "message": "graph data dropped"}
        except Exception as e:
            logger.error(f"Error dropping graph: {e}")
            return {"status": "error", "message": str(e)}


NAMESPACE_TABLE_MAP = {
    NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
    NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
    NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
    NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
    NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
    NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
    NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
}


def namespace_to_table_name(namespace: str) -> str:
    for k, v in NAMESPACE_TABLE_MAP.items():
        if is_namespace(namespace, k):
            return v
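

# Example (sketch): namespace strings resolve to their backing table names, e.g.
#   namespace_to_table_name("text_chunks")  # -> "LIGHTRAG_DOC_CHUNKS"
# assuming the namespace value matches NameSpace.KV_STORE_TEXT_CHUNKS; unknown
# namespaces fall through and the function returns None.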


TABLES = {
    "LIGHTRAG_DOC_FULL": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    doc_name VARCHAR(1024),
                    content TEXT,
                    meta JSONB,
                    create_time TIMESTAMP(0),
                    update_time TIMESTAMP(0),
                    CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_DOC_CHUNKS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    full_doc_id VARCHAR(256),
                    chunk_order_index INTEGER,
                    tokens INTEGER,
                    content TEXT,
                    content_vector VECTOR,
                    file_path VARCHAR(256),
                    create_time TIMESTAMP(0) WITH TIME ZONE,
                    update_time TIMESTAMP(0) WITH TIME ZONE,
                    CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_ENTITY": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    entity_name VARCHAR(255),
                    content TEXT,
                    content_vector VECTOR,
                    create_time TIMESTAMP(0) WITH TIME ZONE,
                    update_time TIMESTAMP(0) WITH TIME ZONE,
                    chunk_ids VARCHAR(255)[] NULL,
                    file_path TEXT NULL,
                    CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_RELATION": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    source_id VARCHAR(256),
                    target_id VARCHAR(256),
                    content TEXT,
                    content_vector VECTOR,
                    create_time TIMESTAMP(0) WITH TIME ZONE,
                    update_time TIMESTAMP(0) WITH TIME ZONE,
                    chunk_ids VARCHAR(255)[] NULL,
                    file_path TEXT NULL,
                    CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_LLM_CACHE": {
        "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
                    mode varchar(32) NOT NULL,
                    original_prompt TEXT,
                    return_value TEXT,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP,
                    CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
                    )"""
    },
    "LIGHTRAG_DOC_STATUS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
                    content TEXT NULL,
                    content_summary varchar(255) NULL,
                    content_length int4 NULL,
                    chunks_count int4 NULL,
                    status varchar(64) NULL,
                    file_path TEXT NULL,
                    created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
                    updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
                    CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
}
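

# The DDL above is presumably applied at startup by the PostgreSQLDB helper
# (e.g. a check-and-create pass over TABLES); a minimal sketch of that idea:
#   for name, spec in TABLES.items():
#       # create the table only if `name` does not already exist
#       await db.execute(spec["ddl"])
# The loop shown here is illustrative; the real initialization logic lives
# outside this excerpt.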


SQL_TEMPLATES = {
    # SQL for KVStorage
    "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
                                FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
                           """,
    "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id, file_path
                                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
                             """,
    "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
                                    """,
    "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
                                         """,
    "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
                                FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
                            """,
    "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id, file_path
                                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
                              """,
    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode IN ({ids})
                                     """,
    "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
    "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
                          VALUES ($1, $2, $3)
                          ON CONFLICT (workspace, id) DO UPDATE
                          SET content = $2, update_time = CURRENT_TIMESTAMP
                       """,
    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode)
                                    VALUES ($1, $2, $3, $4, $5)
                                    ON CONFLICT (workspace, mode, id) DO UPDATE
                                    SET original_prompt = EXCLUDED.original_prompt,
                                    return_value = EXCLUDED.return_value,
                                    mode = EXCLUDED.mode,
                                    update_time = CURRENT_TIMESTAMP
                                 """,
    "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
                       chunk_order_index, full_doc_id, content, content_vector, file_path,
                       create_time, update_time)
                       VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                       ON CONFLICT (workspace, id) DO UPDATE
                       SET tokens = EXCLUDED.tokens,
                       chunk_order_index = EXCLUDED.chunk_order_index,
                       full_doc_id = EXCLUDED.full_doc_id,
                       content = EXCLUDED.content,
                       content_vector = EXCLUDED.content_vector,
                       file_path = EXCLUDED.file_path,
                       update_time = EXCLUDED.update_time
                    """,
    # SQL for VectorStorage
    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
                        content_vector, chunk_ids, file_path, create_time, update_time)
                        VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
                        ON CONFLICT (workspace, id) DO UPDATE
                        SET entity_name = EXCLUDED.entity_name,
                        content = EXCLUDED.content,
                        content_vector = EXCLUDED.content_vector,
                        chunk_ids = EXCLUDED.chunk_ids,
                        file_path = EXCLUDED.file_path,
                        update_time = EXCLUDED.update_time
                     """,
    "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
                              target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
                              VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
                              ON CONFLICT (workspace, id) DO UPDATE
                              SET source_id = EXCLUDED.source_id,
                              target_id = EXCLUDED.target_id,
                              content = EXCLUDED.content,
                              content_vector = EXCLUDED.content_vector,
                              chunk_ids = EXCLUDED.chunk_ids,
                              file_path = EXCLUDED.file_path,
                              update_time = EXCLUDED.update_time
                           """,
    "relationships": """
        WITH relevant_chunks AS (
            SELECT id as chunk_id
            FROM LIGHTRAG_DOC_CHUNKS
            WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
        )
        SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
        FROM (
            SELECT r.id, r.source_id, r.target_id, r.create_time, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
            FROM LIGHTRAG_VDB_RELATION r
            JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids)
            WHERE r.workspace = $1
        ) filtered
        WHERE distance > $3
        ORDER BY distance DESC
        LIMIT $4
    """,
    "entities": """
        WITH relevant_chunks AS (
            SELECT id as chunk_id
            FROM LIGHTRAG_DOC_CHUNKS
            WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
        )
        SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
            (
                SELECT e.id, e.entity_name, e.create_time, 1 - (e.content_vector <=> '[{embedding_string}]'::vector) as distance
                FROM LIGHTRAG_VDB_ENTITY e
                JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids)
                WHERE e.workspace = $1
            ) as chunk_distances
            WHERE distance > $3
            ORDER BY distance DESC
            LIMIT $4
    """,
    "chunks": """
        WITH relevant_chunks AS (
            SELECT id as chunk_id
            FROM LIGHTRAG_DOC_CHUNKS
            WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
        )
        SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
            (
                SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
                FROM LIGHTRAG_DOC_CHUNKS
                WHERE workspace = $1
                AND id IN (SELECT chunk_id FROM relevant_chunks)
            ) as chunk_distances
            WHERE distance > $3
            ORDER BY distance DESC
            LIMIT $4
    """,
    # DROP tables
    "drop_specifiy_table_workspace": """
        DELETE FROM {table_name} WHERE workspace = $1
    """,
}
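
# Usage note (sketch): templates with `{ids}`, `{table_name}` or `{embedding_string}`
# are filled in with str.format() before execution, while `$1..$n` placeholders are
# bound as query parameters by the PostgreSQLDB helper, e.g.
#   sql = SQL_TEMPLATES["filter_keys"].format(table_name="LIGHTRAG_DOC_FULL", ids="'a','b'")
# The formatted values shown here are illustrative only.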