mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-07-19 06:53:02 +00:00
Oracle Database support
Add oracle 23ai database as the KV/vector/graph storage
This commit is contained in:
parent
d86aed734d
commit
1bc4e2382b
127
examples/lightrag_oracle_demo.py
Normal file
127
examples/lightrag_oracle_demo.py
Normal file
@ -0,0 +1,127 @@
|
||||
|
||||
|
||||
import sys, os
|
||||
print(os.getcwd())
|
||||
from pathlib import Path
|
||||
script_directory = Path(__file__).resolve().parent.parent
|
||||
sys.path.append(os.path.abspath(script_directory))
|
||||
|
||||
import asyncio
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
from lightrag.kg.oracle_impl import OracleDB
|
||||
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
# We use OpenAI compatible API to call LLM on Oracle Cloud
|
||||
# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
|
||||
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
|
||||
APIKEY = "ocigenerativeai"
|
||||
CHATMODEL = "cohere.command-r-plus"
|
||||
EMBEDMODEL = "cohere.embed-multilingual-v3.0"
|
||||
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
CHATMODEL,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key=APIKEY,
|
||||
base_url=BASE_URL,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await openai_embedding(
|
||||
texts,
|
||||
model=EMBEDMODEL,
|
||||
api_key=APIKEY,
|
||||
base_url=BASE_URL,
|
||||
)
|
||||
|
||||
|
||||
async def get_embedding_dim():
|
||||
test_text = ["This is a test sentence."]
|
||||
embedding = await embedding_func(test_text)
|
||||
embedding_dim = embedding.shape[1]
|
||||
return embedding_dim
|
||||
|
||||
|
||||
async def main():
|
||||
try:
|
||||
# Detect embedding dimension
|
||||
embedding_dimension = await get_embedding_dim()
|
||||
print(f"Detected embedding dimension: {embedding_dimension}")
|
||||
|
||||
# Create Oracle DB connection
|
||||
# The `config` parameter is the connection configuration of Oracle DB
|
||||
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
|
||||
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
||||
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
||||
oracle_db = OracleDB(config={
|
||||
"user":"RAG",
|
||||
"password":"xxxxxxxxx",
|
||||
"dsn":"xxxxxxx_medium",
|
||||
"config_dir":"dir/path/to/oracle/config",
|
||||
"wallet_location":"dir/path/to/oracle/wallet",
|
||||
"wallet_password":"xxxxxxxxx",
|
||||
"workspace":"company" # specify which docs we want to store and query
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Check if Oracle DB tables exist, if not, tables will be created
|
||||
await oracle_db.check_tables()
|
||||
|
||||
|
||||
# Initialize LightRAG
|
||||
# We use Oracle DB as the KV/vector/graph storage
|
||||
rag = LightRAG(
|
||||
enable_llm_cache=False,
|
||||
working_dir=WORKING_DIR,
|
||||
chunk_token_size=512,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=embedding_dimension,
|
||||
max_token_size=512,
|
||||
func=embedding_func,
|
||||
),
|
||||
graph_storage = "OracleGraphStorage",
|
||||
kv_storage="OracleKVStorage",
|
||||
vector_storage="OracleVectorDBStorage"
|
||||
)
|
||||
|
||||
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
||||
rag.graph_storage_cls.db = oracle_db
|
||||
rag.key_string_value_json_storage_cls.db = oracle_db
|
||||
rag.vector_db_storage_cls.db = oracle_db
|
||||
|
||||
# Extract and Insert into LightRAG storage
|
||||
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
|
||||
await rag.ainsert(f.read())
|
||||
|
||||
# Perform search in different modes
|
||||
modes = ["naive", "local", "global", "hybrid"]
|
||||
for mode in modes:
|
||||
print("="*20, mode, "="*20)
|
||||
print(await rag.aquery("这个文章讲了什么?", param=QueryParam(mode=mode)))
|
||||
print("-"*100, "\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -59,6 +59,7 @@ class BaseVectorStorage(StorageNameSpace):
|
||||
|
||||
@dataclass
|
||||
class BaseKVStorage(Generic[T], StorageNameSpace):
|
||||
embedding_func: EmbeddingFunc
|
||||
async def all_keys(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -83,6 +84,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
|
||||
|
||||
@dataclass
|
||||
class BaseGraphStorage(StorageNameSpace):
|
||||
embedding_func: EmbeddingFunc
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
767
lightrag/kg/oracle_impl.py
Normal file
767
lightrag/kg/oracle_impl.py
Normal file
@ -0,0 +1,767 @@
|
||||
import asyncio
|
||||
#import html
|
||||
#import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import array
|
||||
|
||||
from ..utils import logger
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
BaseKVStorage,
|
||||
BaseVectorStorage,
|
||||
)
|
||||
|
||||
import oracledb
|
||||
|
||||
class OracleDB:
|
||||
def __init__(self,config,**kwargs):
|
||||
self.host = config.get("host", None)
|
||||
self.port = config.get("port", None)
|
||||
self.user = config.get("user", None)
|
||||
self.password = config.get("password", None)
|
||||
self.dsn = config.get("dsn", None)
|
||||
self.config_dir = config.get("config_dir", None)
|
||||
self.wallet_location = config.get("wallet_location", None)
|
||||
self.wallet_password = config.get("wallet_password", None)
|
||||
self.workspace = config.get("workspace", None)
|
||||
self.max = 12
|
||||
self.increment = 1
|
||||
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
|
||||
if self.user is None or self.password is None:
|
||||
raise ValueError("Missing database user or password in addon_params")
|
||||
|
||||
try:
|
||||
oracledb.defaults.fetch_lobs = False
|
||||
|
||||
self.pool = oracledb.create_pool_async(
|
||||
user = self.user,
|
||||
password = self.password,
|
||||
dsn = self.dsn,
|
||||
config_dir = self.config_dir,
|
||||
wallet_location = self.wallet_location,
|
||||
wallet_password = self.wallet_password,
|
||||
min = 1,
|
||||
max = self.max,
|
||||
increment = self.increment
|
||||
)
|
||||
logger.info(f"Connected to Oracle database at {self.dsn}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Oracle database at {self.dsn}")
|
||||
logger.error(f"Oracle database error: {e}")
|
||||
raise
|
||||
|
||||
def numpy_converter_in(self, value):
|
||||
"""Convert numpy array to array.array"""
|
||||
if value.dtype == np.float64:
|
||||
dtype = "d"
|
||||
elif value.dtype == np.float32:
|
||||
dtype = "f"
|
||||
else:
|
||||
dtype = "b"
|
||||
return array.array(dtype, value)
|
||||
|
||||
def input_type_handler(self, cursor, value, arraysize):
|
||||
"""Set the type handler for the input data"""
|
||||
if isinstance(value, np.ndarray):
|
||||
return cursor.var(
|
||||
oracledb.DB_TYPE_VECTOR,
|
||||
arraysize=arraysize,
|
||||
inconverter=self.numpy_converter_in,
|
||||
)
|
||||
|
||||
def numpy_converter_out(self, value):
|
||||
"""Convert array.array to numpy array"""
|
||||
if value.typecode == "b":
|
||||
dtype = np.int8
|
||||
elif value.typecode == "f":
|
||||
dtype = np.float32
|
||||
else:
|
||||
dtype = np.float64
|
||||
return np.array(value, copy=False, dtype=dtype)
|
||||
|
||||
def output_type_handler(self, cursor, metadata):
|
||||
"""Set the type handler for the output data"""
|
||||
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
|
||||
return cursor.var(
|
||||
metadata.type_code,
|
||||
arraysize=cursor.arraysize,
|
||||
outconverter=self.numpy_converter_out,
|
||||
)
|
||||
|
||||
async def check_tables(self):
|
||||
for k,v in TABLES.items():
|
||||
try:
|
||||
if k.lower() == "lightrag_graph":
|
||||
await self.query("SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only")
|
||||
else:
|
||||
await self.query("SELECT 1 FROM {k}".format(k=k))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check table {k} in Oracle database")
|
||||
logger.error(f"Oracle database error: {e}")
|
||||
try:
|
||||
# print(v["ddl"])
|
||||
await self.execute(v["ddl"])
|
||||
logger.info(f"Created table {k} in Oracle database")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create table {k} in Oracle database")
|
||||
logger.error(f"Oracle database error: {e}")
|
||||
|
||||
logger.info(f"Finished check all tables in Oracle database")
|
||||
|
||||
|
||||
async def query(self,sql: str, multirows: bool = False) -> Union[dict, None]:
|
||||
async with self.pool.acquire() as connection:
|
||||
connection.inputtypehandler = self.input_type_handler
|
||||
connection.outputtypehandler = self.output_type_handler
|
||||
with connection.cursor() as cursor:
|
||||
try:
|
||||
await cursor.execute(sql)
|
||||
except Exception as e:
|
||||
logger.error(f"Oracle database error: {e}")
|
||||
print(sql)
|
||||
raise
|
||||
columns = [column[0].lower() for column in cursor.description]
|
||||
if multirows:
|
||||
rows = await cursor.fetchall()
|
||||
if rows:
|
||||
data = [dict(zip(columns, row)) for row in rows]
|
||||
else:
|
||||
data = []
|
||||
else:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
data = dict(zip(columns, row))
|
||||
else:
|
||||
data = None
|
||||
return data
|
||||
|
||||
async def execute(self,sql: str, data: list = None):
|
||||
# logger.info("go into OracleDB execute method")
|
||||
try:
|
||||
async with self.pool.acquire() as connection:
|
||||
connection.inputtypehandler = self.input_type_handler
|
||||
connection.outputtypehandler = self.output_type_handler
|
||||
with connection.cursor() as cursor:
|
||||
if data is None:
|
||||
await cursor.execute(sql)
|
||||
else:
|
||||
#print(data)
|
||||
#print(sql)
|
||||
await cursor.execute(sql,data)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Oracle database error: {e}")
|
||||
print(sql)
|
||||
print(data)
|
||||
raise
|
||||
|
||||
@dataclass
|
||||
class OracleKVStorage(BaseKVStorage):
|
||||
|
||||
# should pass db object to self.db
|
||||
def __post_init__(self):
|
||||
self._data = {}
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
################ QUERY METHODS ################
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||
"""根据 id 获取 doc_full 数据."""
|
||||
SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(workspace=self.db.workspace,id=id)
|
||||
#print("get_by_id:"+SQL)
|
||||
res = await self.db.query(SQL)
|
||||
if res:
|
||||
data = res #{"data":res}
|
||||
#print (data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]:
|
||||
"""根据 id 获取 doc_chunks 数据"""
|
||||
SQL = SQL_TEMPLATES["get_by_ids_"+self.namespace].format(workspace=self.db.workspace,
|
||||
ids=",".join([f"'{id}'" for id in ids]))
|
||||
#print("get_by_ids:"+SQL)
|
||||
res = await self.db.query(SQL,multirows=True)
|
||||
if res:
|
||||
data = res # [{"data":i} for i in res]
|
||||
#print(data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||
"""过滤掉重复内容"""
|
||||
SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
|
||||
workspace=self.db.workspace,
|
||||
ids=",".join([f"'{k}'" for k in keys]))
|
||||
res = await self.db.query(SQL,multirows=True)
|
||||
data = None
|
||||
if res:
|
||||
exist_keys = [key["id"] for key in res]
|
||||
data = set([s for s in keys if s not in exist_keys])
|
||||
else:
|
||||
exist_keys = []
|
||||
data = set([s for s in keys if s not in exist_keys])
|
||||
return data
|
||||
|
||||
|
||||
################ INSERT METHODS ################
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
#print(self._data)
|
||||
#values = []
|
||||
if self.namespace == "text_chunks":
|
||||
list_data = [
|
||||
{
|
||||
"__id__": k,
|
||||
**{k1: v1 for k1, v1 in v.items()},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i: i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
#print(list_data)
|
||||
for item in list_data:
|
||||
merge_sql = SQL_TEMPLATES["merge_chunk"].format(
|
||||
check_id=item["__id__"]
|
||||
)
|
||||
|
||||
values = [item["__id__"], item["content"], self.db.workspace, item["tokens"],
|
||||
item["chunk_order_index"], item["full_doc_id"], item["__vector__"]]
|
||||
#print(merge_sql)
|
||||
await self.db.execute(merge_sql, values)
|
||||
|
||||
if self.namespace == "full_docs":
|
||||
for k, v in self._data.items():
|
||||
#values.clear()
|
||||
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
|
||||
check_id=k,
|
||||
)
|
||||
values = [k, self._data[k]["content"], self.db.workspace]
|
||||
#print(merge_sql)
|
||||
await self.db.execute(merge_sql, values)
|
||||
return left_data
|
||||
|
||||
|
||||
async def index_done_callback(self):
|
||||
if self.namespace in ["full_docs", "text_chunks"]:
|
||||
logger.info("full doc and chunk data had been saved into oracle db!")
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class OracleVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = 0.2
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
"""向向量数据库中插入数据"""
|
||||
pass
|
||||
|
||||
async def index_done_callback(self):
|
||||
pass
|
||||
|
||||
|
||||
#################### query method ################
|
||||
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
||||
"""从向量数据库中查询数据"""
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
# 转换精度
|
||||
dtype = str(embedding.dtype).upper()
|
||||
dimension = embedding.shape[0]
|
||||
embedding_string = ', '.join(map(str, embedding.tolist()))
|
||||
|
||||
SQL = SQL_TEMPLATES[self.namespace].format(
|
||||
embedding_string=embedding_string,
|
||||
dimension=dimension,
|
||||
dtype=dtype,
|
||||
workspace=self.db.workspace,
|
||||
top_k=top_k,
|
||||
better_than_threshold=self.cosine_better_than_threshold,
|
||||
)
|
||||
# print(SQL)
|
||||
results = await self.db.query(SQL, multirows=True)
|
||||
#print("vector search result:",results)
|
||||
return results
|
||||
|
||||
|
||||
@dataclass
|
||||
class OracleGraphStorage(BaseGraphStorage):
|
||||
"""基于Oracle的图存储模块"""
|
||||
# @staticmethod
|
||||
# def load_graph(file_name) -> nx.Graph:
|
||||
# """读取graphhml图文件"""
|
||||
|
||||
# @staticmethod
|
||||
# def write_graph(graph: nx.Graph, file_name):
|
||||
# # """写入graphhml图文件"""
|
||||
|
||||
# @staticmethod
|
||||
# def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
|
||||
# """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
||||
# Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
|
||||
# 用于产生稳定的最大连通分量的模块,即相同的输入图==相同的输出lcc。
|
||||
# """
|
||||
|
||||
|
||||
# @staticmethod
|
||||
# def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
||||
# """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
||||
# Ensure an undirected graph with the same relationships will always be read the same way.
|
||||
# 确保具有相同关系的无向图始终以相同的方式读取。
|
||||
# """
|
||||
|
||||
def __post_init__(self):
|
||||
"""从graphml文件加载图"""
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
|
||||
#################### insert method ################
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
"""插入或更新节点"""
|
||||
#print("go into upsert node method")
|
||||
entity_name = node_id
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
source_id = node_data["source_id"]
|
||||
content = entity_name+description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i: i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
||||
workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
|
||||
)
|
||||
#print(merge_sql)
|
||||
await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
|
||||
#self._graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
"""插入或更新边"""
|
||||
#print("go into upsert edge method")
|
||||
source_name = source_node_id
|
||||
target_name = target_node_id
|
||||
weight = edge_data["weight"]
|
||||
keywords = edge_data["keywords"]
|
||||
description = edge_data["description"]
|
||||
source_chunk_id = edge_data["source_id"]
|
||||
content = keywords+source_name+target_name+description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i: i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
||||
workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
|
||||
)
|
||||
#print(merge_sql)
|
||||
await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
|
||||
#self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
"""为节点生成向量"""
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
"""为节点生成向量"""
|
||||
from graspologic import embed
|
||||
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
self._graph,
|
||||
**self.config["node2vec_params"],
|
||||
)
|
||||
|
||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
return embeddings, nodes_ids
|
||||
|
||||
|
||||
async def index_done_callback(self):
|
||||
"""写入graphhml图文件"""
|
||||
logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!")
|
||||
|
||||
#################### query method ################
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
"""根据节点id检查节点是否存在"""
|
||||
SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id)
|
||||
# print(SQL)
|
||||
#print(self.db.workspace, node_id)
|
||||
res = await self.db.query(SQL)
|
||||
if res:
|
||||
#print("Node exist!",res)
|
||||
return True
|
||||
else:
|
||||
#print("Node not exist!")
|
||||
return False
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
"""根据源和目标节点id检查边是否存在"""
|
||||
SQL = SQL_TEMPLATES["has_edge"].format(workspace=self.db.workspace,
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=target_node_id)
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL)
|
||||
if res:
|
||||
#print("Edge exist!",res)
|
||||
return True
|
||||
else:
|
||||
#print("Edge not exist!")
|
||||
return False
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""根据节点id获取节点的度"""
|
||||
SQL = SQL_TEMPLATES["node_degree"].format(workspace=self.db.workspace, node_id=node_id)
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL)
|
||||
if res:
|
||||
#print("Node degree",res["degree"])
|
||||
return res["degree"]
|
||||
else:
|
||||
#print("Edge not exist!")
|
||||
return 0
|
||||
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
"""根据源和目标节点id获取边的度"""
|
||||
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
||||
#print("Edge degree",degree)
|
||||
return degree
|
||||
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
"""根据节点id获取节点数据"""
|
||||
SQL = SQL_TEMPLATES["get_node"].format(workspace=self.db.workspace, node_id=node_id)
|
||||
# print(self.db.workspace, node_id)
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL)
|
||||
if res:
|
||||
#print("Get node!",self.db.workspace, node_id,res)
|
||||
return res
|
||||
else:
|
||||
#print("Can't get node!",self.db.workspace, node_id)
|
||||
return None
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""根据源和目标节点id获取边"""
|
||||
SQL = SQL_TEMPLATES["get_edge"].format(workspace=self.db.workspace,
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=target_node_id)
|
||||
res = await self.db.query(SQL)
|
||||
if res:
|
||||
#print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
||||
return res
|
||||
else:
|
||||
#print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
||||
return None
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
"""根据节点id获取节点的所有边"""
|
||||
if await self.has_node(source_node_id):
|
||||
SQL = SQL_TEMPLATES["get_node_edges"].format(workspace=self.db.workspace,
|
||||
source_node_id=source_node_id)
|
||||
res = await self.db.query(sql=SQL, multirows=True)
|
||||
if res:
|
||||
data = [(i["source_name"],i["target_name"]) for i in res]
|
||||
#print("Get node edge!",self.db.workspace, source_node_id,data)
|
||||
return data
|
||||
else:
|
||||
#print("Node Edge not exist!",self.db.workspace, source_node_id)
|
||||
return []
|
||||
|
||||
#################### INSERT method ################
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
"""插入或更新节点"""
|
||||
#print("go into upsert node method")
|
||||
entity_name = node_id
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
source_id = node_data["source_id"]
|
||||
content = entity_name+description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i: i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
||||
workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
|
||||
)
|
||||
#print(merge_sql)
|
||||
await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
|
||||
#self._graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
"""插入或更新边"""
|
||||
#print("go into upsert edge method")
|
||||
source_name = source_node_id
|
||||
target_name = target_node_id
|
||||
weight = edge_data["weight"]
|
||||
keywords = edge_data["keywords"]
|
||||
description = edge_data["description"]
|
||||
source_chunk_id = edge_data["source_id"]
|
||||
content = keywords+source_name+target_name+description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i: i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
||||
workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
|
||||
)
|
||||
#print(merge_sql)
|
||||
await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
|
||||
#self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
"""为节点生成向量"""
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
"""为节点生成向量"""
|
||||
from graspologic import embed
|
||||
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
self._graph,
|
||||
**self.config["node2vec_params"],
|
||||
)
|
||||
|
||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
return embeddings, nodes_ids
|
||||
|
||||
|
||||
N_T = {
|
||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
"entities": "LIGHTRAG_GRAPH_NODES",
|
||||
"relationships": "LIGHTRAG_GRAPH_EDGES"
|
||||
}
|
||||
|
||||
TABLES = {
|
||||
"LIGHTRAG_DOC_FULL":
|
||||
{"ddl":"""CREATE TABLE LIGHTRAG_DOC_FULL (
|
||||
id varchar(256)PRIMARY KEY,
|
||||
workspace varchar(1024),
|
||||
doc_name varchar(1024),
|
||||
content CLOB,
|
||||
meta JSON
|
||||
)"""},
|
||||
|
||||
"LIGHTRAG_DOC_CHUNKS":
|
||||
{"ddl":"""CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
||||
id varchar(256) PRIMARY KEY,
|
||||
workspace varchar(1024),
|
||||
full_doc_id varchar(256),
|
||||
chunk_order_index NUMBER,
|
||||
tokens NUMBER,
|
||||
content CLOB,
|
||||
content_vector VECTOR
|
||||
)"""},
|
||||
|
||||
"LIGHTRAG_GRAPH_NODES":
|
||||
{"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
||||
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
||||
workspace varchar(1024),
|
||||
name varchar(2048),
|
||||
entity_type varchar(1024),
|
||||
description CLOB,
|
||||
source_chunk_id varchar(256),
|
||||
content CLOB,
|
||||
content_vector VECTOR
|
||||
)"""},
|
||||
"LIGHTRAG_GRAPH_EDGES":
|
||||
{"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
||||
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
||||
workspace varchar(1024),
|
||||
source_name varchar(2048),
|
||||
target_name varchar(2048),
|
||||
weight NUMBER,
|
||||
keywords CLOB,
|
||||
description CLOB,
|
||||
source_chunk_id varchar(256),
|
||||
content CLOB,
|
||||
content_vector VECTOR
|
||||
)"""},
|
||||
"LIGHTRAG_LLM_CACHE":
|
||||
{"ddl":"""CREATE TABLE LIGHTRAG_LLM_CACHE (
|
||||
id varchar(256) PRIMARY KEY,
|
||||
return clob,
|
||||
model varchar(1024)
|
||||
)"""},
|
||||
|
||||
"LIGHTRAG_GRAPH":
|
||||
{"ddl":"""CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
|
||||
VERTEX TABLES (
|
||||
lightrag_graph_nodes KEY (id)
|
||||
LABEL entity
|
||||
PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id)
|
||||
)
|
||||
EDGE TABLES (
|
||||
lightrag_graph_edges KEY (id)
|
||||
SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
|
||||
DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
|
||||
LABEL has_relation
|
||||
PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
|
||||
) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""},
|
||||
}
|
||||
|
||||
|
||||
SQL_TEMPLATES = {
|
||||
# SQL for KVStorage
|
||||
"get_by_id_full_docs":
|
||||
"select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
|
||||
|
||||
"get_by_id_text_chunks":
|
||||
"select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
|
||||
|
||||
"get_by_ids_full_docs":
|
||||
"select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
|
||||
|
||||
"get_by_ids_text_chunks":
|
||||
"select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
|
||||
|
||||
"filter_keys":
|
||||
"select id from {table_name} where workspace='{workspace}' and id in ({ids})",
|
||||
|
||||
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
||||
USING DUAL
|
||||
ON (a.id = '{check_id}')
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(id,content,workspace) values(:1,:2,:3)
|
||||
""",
|
||||
|
||||
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
||||
USING DUAL
|
||||
ON (a.id = '{check_id}')
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
||||
values (:1,:2,:3,:4,:5,:6,:7) """,
|
||||
|
||||
# SQL for VectorStorage
|
||||
"entities":
|
||||
"""SELECT name as entity_name FROM
|
||||
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
|
||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||
|
||||
"relationships":
|
||||
"""SELECT source_name as src_id, target_name as tgt_id FROM
|
||||
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
|
||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||
|
||||
"chunks":
|
||||
"""SELECT id FROM
|
||||
(SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
|
||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||
|
||||
# SQL for GraphStorage
|
||||
"has_node":
|
||||
"""SELECT * FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)
|
||||
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
||||
COLUMNS (a.name))""",
|
||||
|
||||
"has_edge":
|
||||
"""SELECT * FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a) -[e]-> (b)
|
||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||
AND a.name='{source_node_id}' AND b.name='{target_node_id}'
|
||||
COLUMNS (e.source_name,e.target_name) )""",
|
||||
|
||||
"node_degree":
|
||||
"""SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)-[e]->(b)
|
||||
WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||
AND a.name='{node_id}' or b.name = '{node_id}'
|
||||
COLUMNS (a.name))""",
|
||||
|
||||
"get_node":
|
||||
"""SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
||||
FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)
|
||||
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
||||
COLUMNS (a.name)
|
||||
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
||||
WHERE t2.workspace='{workspace}'""",
|
||||
|
||||
"get_edge":
|
||||
"""SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
||||
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
||||
FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)-[e]->(b)
|
||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||
AND a.name='{source_node_id}' and b.name = '{target_node_id}'
|
||||
COLUMNS (e.id,a.name as source_id)
|
||||
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
||||
|
||||
"get_node_edges":
|
||||
"""SELECT source_name,target_name
|
||||
FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)-[e]->(b)
|
||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||
AND a.name='{source_node_id}'
|
||||
COLUMNS (a.name as source_name,b.name as target_name))""",
|
||||
|
||||
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
||||
USING DUAL
|
||||
ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
|
||||
values (:1,:2,:3,:4,:5,:6,:7) """,
|
||||
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
|
||||
USING DUAL
|
||||
ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
||||
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """
|
||||
}
|
@ -18,20 +18,6 @@ from .operate import (
|
||||
naive_query,
|
||||
)
|
||||
|
||||
from .storage import (
|
||||
JsonKVStorage,
|
||||
NanoVectorDBStorage,
|
||||
NetworkXStorage,
|
||||
)
|
||||
|
||||
from .kg.neo4j_impl import Neo4JStorage
|
||||
# future KG integrations
|
||||
|
||||
# from .kg.ArangoDB_impl import (
|
||||
# GraphStorage as ArangoDBStorage
|
||||
# )
|
||||
|
||||
|
||||
from .utils import (
|
||||
EmbeddingFunc,
|
||||
compute_mdhash_id,
|
||||
@ -49,6 +35,26 @@ from .base import (
|
||||
)
|
||||
|
||||
|
||||
from .storage import (
|
||||
JsonKVStorage,
|
||||
NanoVectorDBStorage,
|
||||
NetworkXStorage,
|
||||
)
|
||||
|
||||
from .kg.neo4j_impl import Neo4JStorage
|
||||
|
||||
from .kg.oracle_impl import (
|
||||
OracleKVStorage,
|
||||
OracleGraphStorage,
|
||||
OracleVectorDBStorage
|
||||
)
|
||||
|
||||
# future KG integrations
|
||||
|
||||
# from .kg.ArangoDB_impl import (
|
||||
# GraphStorage as ArangoDBStorage
|
||||
# )
|
||||
|
||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
return asyncio.get_event_loop()
|
||||
@ -68,7 +74,9 @@ class LightRAG:
|
||||
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
||||
)
|
||||
|
||||
kg: str = field(default="NetworkXStorage")
|
||||
kv_storage : str = field(default="JsonKVStorage")
|
||||
vector_storage: str = field(default="NanoVectorDBStorage")
|
||||
graph_storage: str = field(default="NetworkXStorage")
|
||||
|
||||
current_log_level = logger.level
|
||||
log_level: str = field(default=current_log_level)
|
||||
@ -108,9 +116,16 @@ class LightRAG:
|
||||
llm_model_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
# storage
|
||||
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
|
||||
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
|
||||
|
||||
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
||||
# if DATABASE_TYPE is None:
|
||||
# key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
|
||||
# vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
|
||||
# vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
||||
# elif DATABASE_TYPE == "oracle":
|
||||
# key_string_value_json_storage_cls: Type[BaseKVStorage] = OracleKVStorage,
|
||||
# vector_db_storage_cls: Type[BaseVectorStorage] = OracleVectorDBStorage,
|
||||
|
||||
enable_llm_cache: bool = True
|
||||
|
||||
# extension
|
||||
@ -128,21 +143,16 @@ class LightRAG:
|
||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||
|
||||
# @TODO: should move all storage setup here to leverage initial start params attached to self.
|
||||
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
|
||||
self.kg
|
||||
]
|
||||
|
||||
self. key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage]
|
||||
|
||||
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage]
|
||||
|
||||
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage]
|
||||
if not os.path.exists(self.working_dir):
|
||||
logger.info(f"Creating working directory {self.working_dir}")
|
||||
os.makedirs(self.working_dir)
|
||||
|
||||
self.full_docs = self.key_string_value_json_storage_cls(
|
||||
namespace="full_docs", global_config=asdict(self)
|
||||
)
|
||||
|
||||
self.text_chunks = self.key_string_value_json_storage_cls(
|
||||
namespace="text_chunks", global_config=asdict(self)
|
||||
)
|
||||
|
||||
self.llm_response_cache = (
|
||||
self.key_string_value_json_storage_cls(
|
||||
@ -151,14 +161,27 @@ class LightRAG:
|
||||
if self.enable_llm_cache
|
||||
else None
|
||||
)
|
||||
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
||||
namespace="chunk_entity_relation", global_config=asdict(self)
|
||||
)
|
||||
|
||||
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
|
||||
self.embedding_func
|
||||
)
|
||||
|
||||
####
|
||||
# add embedding func by walter
|
||||
####
|
||||
self.full_docs = self.key_string_value_json_storage_cls(
|
||||
namespace="full_docs", global_config=asdict(self), embedding_func=self.embedding_func
|
||||
)
|
||||
self.text_chunks = self.key_string_value_json_storage_cls(
|
||||
namespace="text_chunks", global_config=asdict(self), embedding_func=self.embedding_func
|
||||
)
|
||||
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
||||
namespace="chunk_entity_relation", global_config=asdict(self), embedding_func=self.embedding_func
|
||||
)
|
||||
####
|
||||
# add embedding func by walter over
|
||||
####
|
||||
|
||||
self.entities_vdb = self.vector_db_storage_cls(
|
||||
namespace="entities",
|
||||
global_config=asdict(self),
|
||||
@ -187,8 +210,15 @@ class LightRAG:
|
||||
|
||||
def _get_storage_class(self) -> Type[BaseGraphStorage]:
|
||||
return {
|
||||
"JsonKVStorage":JsonKVStorage,
|
||||
"OracleKVStorage":OracleKVStorage,
|
||||
|
||||
"NanoVectorDBStorage":NanoVectorDBStorage,
|
||||
"OracleVectorDBStorage":OracleVectorDBStorage,
|
||||
|
||||
"Neo4JStorage": Neo4JStorage,
|
||||
"NetworkXStorage": NetworkXStorage,
|
||||
"OracleGraphStorage": OracleGraphStorage,
|
||||
# "ArangoDBStorage": ArangoDBStorage
|
||||
}
|
||||
|
||||
|
@ -222,14 +222,24 @@ Output:
|
||||
|
||||
"""
|
||||
|
||||
PROMPTS["naive_rag_response"] = """You're a helpful assistant
|
||||
Below are the knowledge you know:
|
||||
{content_data}
|
||||
---
|
||||
If you don't know the answer or if the provided knowledge do not contain sufficient information to provide an answer, just say so. Do not make anything up.
|
||||
PROMPTS["naive_rag_response"] = """---Role---
|
||||
|
||||
You are a helpful assistant responding to questions about documents provided.
|
||||
|
||||
|
||||
---Goal---
|
||||
|
||||
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
|
||||
If you don't know the answer, just say so. Do not make anything up.
|
||||
Do not include information where the supporting evidence for it is not provided.
|
||||
|
||||
---Target response length and format---
|
||||
|
||||
{response_type}
|
||||
|
||||
---Documents---
|
||||
|
||||
{content_data}
|
||||
|
||||
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
|
||||
"""
|
||||
|
@ -15,3 +15,4 @@ torch
|
||||
transformers
|
||||
xxhash
|
||||
# lmdeploy[all]
|
||||
oracledb
|
||||
|
Loading…
x
Reference in New Issue
Block a user