Centralize query parameters into LightRAG class

This commit refactors query parameter management by consolidating settings such as `top_k`, token limits, and similarity thresholds into the `LightRAG` class, so that all query parameters are sourced from a single location.
yangdx 2025-07-15 23:56:49 +08:00
parent 3ead0489b8
commit 5f7cb437e8
5 changed files with 78 additions and 50 deletions

View File

@@ -62,6 +62,8 @@ ENABLE_LLM_CACHE=true
# MAX_RELATION_TOKENS=10000
### control the maximum tokens sent to the LLM (includes entities, relations and chunks)
# MAX_TOTAL_TOKENS=32000
### maximum number of related chunks retrieved from a single entity or relation
# RELATED_CHUNK_NUMBER=5
### Reranker configuration (set ENABLE_RERANK to true if a reranking model is configured)
ENABLE_RERANK=False
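These environment variables feed the new centralized defaults through `get_env_value` from `lightrag.utils`. The helper's implementation is not shown in this diff; the following is a minimal sketch, assuming it reads the variable and coerces it to the requested type:

import os

def get_env_value(name, default, value_type=str):
    # Hypothetical stand-in for lightrag.utils.get_env_value: return the
    # coerced environment value, or the default when the variable is unset.
    raw = os.environ.get(name)
    if raw is None:
        return default
    if value_type is bool:
        return raw.strip().lower() in ("true", "1", "yes")
    return value_type(raw)

# e.g. with MAX_TOTAL_TOKENS=32000 in the environment:
max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", 32000, int)  # -> 32000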

View File

@@ -14,6 +14,11 @@ from lightrag.constants import (
DEFAULT_TOP_K,
DEFAULT_CHUNK_TOP_K,
DEFAULT_HISTORY_TURNS,
DEFAULT_MAX_ENTITY_TOKENS,
DEFAULT_MAX_RELATION_TOKENS,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_COSINE_THRESHOLD,
DEFAULT_RELATED_CHUNK_NUMBER,
)
# use the .env that is inside the current folder
@@ -154,33 +159,6 @@ def parse_args() -> argparse.Namespace:
help="Path to SSL private key file (required if --ssl is enabled)",
)
parser.add_argument(
"--history-turns",
type=int,
default=get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int),
help="Number of conversation history turns to include (default: from env or 3)",
)
# Search parameters
parser.add_argument(
"--top-k",
type=int,
default=get_env_value("TOP_K", DEFAULT_TOP_K, int),
help="Number of most similar results to return (default: from env or 60)",
)
parser.add_argument(
"--chunk-top-k",
type=int,
default=get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int),
help="Number of text chunks to retrieve initially from vector search and keep after reranking (default: from env or 5)",
)
parser.add_argument(
"--cosine-threshold",
type=float,
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
help="Cosine similarity threshold (default: from env or 0.4)",
)
# Ollama model name
parser.add_argument(
"--simulated-model-name",
@@ -312,6 +290,26 @@ def parse_args() -> argparse.Namespace:
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
# Query configuration
args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int)
args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int)
args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
args.max_entity_tokens = get_env_value(
"MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int
)
args.max_relation_tokens = get_env_value(
"MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int
)
args.max_total_tokens = get_env_value(
"MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int
)
args.cosine_threshold = get_env_value(
"COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float
)
args.related_chunk_number = get_env_value(
"RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int
)
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
return args
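With the `--history-turns`, `--top-k`, `--chunk-top-k`, and `--cosine-threshold` flags removed, these settings are now environment-only. A usage sketch (the variable names match the diff; the concrete values are illustrative):

import os

# Set query parameters via the environment before parsing;
# there are no longer dedicated CLI flags for them.
os.environ["TOP_K"] = "80"
os.environ["COSINE_THRESHOLD"] = "0.25"

args = parse_args()
assert args.top_k == 80
assert args.cosine_threshold == 0.25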

View File

@@ -20,6 +20,8 @@ DEFAULT_MAX_RELATION_TOKENS = 10000
DEFAULT_MAX_TOTAL_TOKENS = 32000
DEFAULT_HISTORY_TURNS = 3
DEFAULT_ENABLE_RERANK = True
DEFAULT_COSINE_THRESHOLD = 0.2
DEFAULT_RELATED_CHUNK_NUMBER = 5
# Separator for graph fields
GRAPH_FIELD_SEP = "<SEP>"
@@ -28,6 +30,3 @@ GRAPH_FIELD_SEP = "<SEP>"
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups
DEFAULT_LOG_FILENAME = "lightrag.log" # Default log filename
# Related Chunk Number for Single Entity or Relation
DEFAULT_RELATED_CHUNK_NUMBER = 5

View File

@@ -24,6 +24,13 @@ from typing import (
from lightrag.constants import (
DEFAULT_MAX_GLEANING,
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
DEFAULT_TOP_K,
DEFAULT_CHUNK_TOP_K,
DEFAULT_MAX_ENTITY_TOKENS,
DEFAULT_MAX_RELATION_TOKENS,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_COSINE_THRESHOLD,
DEFAULT_RELATED_CHUNK_NUMBER,
)
from lightrag.utils import get_env_value
@@ -125,6 +132,42 @@ class LightRAG:
log_level: int | None = field(default=None)
log_file_path: str | None = field(default=None)
# Query parameters
# ---
top_k: int = field(default=get_env_value("TOP_K", DEFAULT_TOP_K, int))
"""Number of entities/relations to retrieve for each query."""
chunk_top_k: int = field(
default=get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
)
"""Maximum number of chunks in context."""
max_entity_tokens: int = field(
default=get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int)
)
"""Maximum number of tokens for entity in context."""
max_relation_tokens: int = field(
default=get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int)
)
"""Maximum number of tokens for relation in context."""
max_total_tokens: int = field(
default=get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int)
)
"""Maximum total tokens in context (including system prompt, entities, relations and chunks)."""
cosine_threshold: float = field(
default=get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float)
)
"""Cosine similarity threshold for vector DB retrieval of entities, relations and chunks."""
related_chunk_number: int = field(
default=get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int)
)
"""Number of related chunks to grab from single entity or relation."""
# Entity extraction
# ---
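Because the query settings above are plain dataclass fields, the env-derived values act only as defaults and can be overridden per instance. Note that `field(default=get_env_value(...))` is evaluated when the class body is executed, so environment variables must be set before `lightrag` is imported. A hypothetical construction sketch (`working_dir` and the omitted LLM/embedding/storage arguments are assumptions, not shown in this diff):

rag = LightRAG(
    working_dir="./rag_storage",  # assumed parameter; other required args omitted
    top_k=100,                    # overrides TOP_K / DEFAULT_TOP_K
    max_total_tokens=24000,       # tighter overall context budget
    cosine_threshold=0.3,
)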

View File

@@ -1908,7 +1908,6 @@ async def _build_query_context(
ll_keywords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param,
)
original_node_datas = node_datas
@@ -1924,7 +1923,6 @@
hl_keywords,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param,
)
original_edge_datas = edge_datas
@@ -1935,14 +1933,12 @@
ll_keywords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param,
)
hl_data = await _get_edge_data(
hl_keywords,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param,
)
@@ -1985,23 +1981,17 @@ async def _build_query_context(
max_entity_tokens = getattr(
query_param,
"max_entity_tokens",
text_chunks_db.global_config.get(
"MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS
),
text_chunks_db.global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS),
)
max_relation_tokens = getattr(
query_param,
"max_relation_tokens",
text_chunks_db.global_config.get(
"MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS
),
text_chunks_db.global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS),
)
max_total_tokens = getattr(
query_param,
"max_total_tokens",
text_chunks_db.global_config.get(
"MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS
),
text_chunks_db.global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
)
# Truncate entities based on complete JSON serialization
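The three lookups above share one resolution order: a per-query attribute on `QueryParam`, then the instance-level `global_config` (now keyed by the lowercase dataclass field names instead of the env-style uppercase keys), then the library default. A minimal sketch of that pattern, not library API:

def resolve(query_param, global_config, name, default):
    # 1. per-query override, 2. instance config, 3. library default
    return getattr(query_param, name, global_config.get(name, default))

max_entity_tokens = resolve(
    query_param, text_chunks_db.global_config,
    "max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS,
)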
@@ -2095,7 +2085,6 @@ async def _build_query_context(
final_edge_datas,
query_param,
text_chunks_db,
knowledge_graph_inst,
)
)
@@ -2255,7 +2244,6 @@ async def _get_node_data(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
):
# get similar entities
@@ -2362,7 +2350,7 @@ async def _find_most_related_text_unit_from_entities(
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])[
:DEFAULT_RELATED_CHUNK_NUMBER
: text_chunks_db.global_config.get("related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER)
]
for dp in node_datas
if dp["source_id"] is not None
@@ -2519,7 +2507,6 @@ async def _get_edge_data(
keywords,
knowledge_graph_inst: BaseGraphStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
):
logger.info(
@@ -2668,13 +2655,12 @@ async def _find_related_text_unit_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
):
logger.debug(f"Searching text chunks for {len(edge_datas)} relationships")
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])[
:DEFAULT_RELATED_CHUNK_NUMBER
: text_chunks_db.global_config.get("related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER)
]
for dp in edge_datas
if dp["source_id"] is not None
@@ -2761,7 +2747,7 @@ async def naive_query(
max_total_tokens = getattr(
query_param,
"max_total_tokens",
global_config.get("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS),
global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
)
# Calculate conversation history tokens
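Per the field docstring, `max_total_tokens` covers the system prompt, conversation history, and retrieved content together. In naive mode, which presumably retrieves chunks only (no graph entities or relations), the chunk budget is whatever remains after the fixed parts are measured. A purely illustrative arithmetic sketch (all token counts below are assumed values, not taken from the code):

max_total_tokens = 32000     # MAX_TOTAL_TOKENS default
system_prompt_tokens = 1200  # assumed measured value
history_tokens = 800         # assumed measured value

chunk_budget = max_total_tokens - (system_prompt_tokens + history_tokens)
print(chunk_budget)  # 30000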