LightRAG/lightrag/api/routers/query_routes.py

226 lines
8.2 KiB
Python
Raw Normal View History

"""
This module contains all query-related routes for the LightRAG API.
"""
import json
import logging
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from ..utils_api import get_api_key_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
# Module-level router shared by every call to create_query_routes(); the query
# handlers are attached to it there and the same object is returned to callers.
router = APIRouter(tags=["query"])
2025-02-20 04:12:21 +08:00
class QueryRequest(BaseModel):
    """Request body accepted by the ``/query`` and ``/query/stream`` endpoints.

    Every optional field defaults to ``None``; ``to_query_params`` drops
    ``None`` values so that the corresponding ``QueryParam`` defaults apply.
    """

    query: str = Field(
        min_length=1,
        description="The query text",
    )
    mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
        default="hybrid",
        description="Query mode",
    )
    only_need_context: Optional[bool] = Field(
        default=None,
        description="If True, only returns the retrieved context without generating a response.",
    )
    only_need_prompt: Optional[bool] = Field(
        default=None,
        description="If True, only returns the generated prompt without producing a response.",
    )
    response_type: Optional[str] = Field(
        min_length=1,
        default=None,
        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
    )
    top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
    )
    max_token_for_text_unit: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allowed for each retrieved text chunk.",
    )
    max_token_for_global_context: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
    )
    max_token_for_local_context: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
    )
    hl_keywords: Optional[List[str]] = Field(
        default=None,
        description="List of high-level keywords to prioritize in retrieval.",
    )
    ll_keywords: Optional[List[str]] = Field(
        default=None,
        description="List of low-level keywords to refine retrieval focus.",
    )
    conversation_history: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
    )
    history_turns: Optional[int] = Field(
        ge=0,
        default=None,
        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
    )

    @field_validator("query", mode="after")
    @classmethod
    def query_strip_after(cls, query: str) -> str:
        """Strip surrounding whitespace and reject whitespace-only queries.

        This validator runs after the ``min_length=1`` constraint, so without
        the explicit check a value such as ``"   "`` would pass validation and
        reach the RAG pipeline as an empty string.

        Raises:
            ValueError: If the query contains only whitespace (reported to the
            client as a validation error).
        """
        stripped = query.strip()
        if not stripped:
            raise ValueError(
                "Query must contain at least one non-whitespace character."
            )
        return stripped

    @field_validator("hl_keywords", mode="after")
    @classmethod
    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
        """Strip surrounding whitespace from each high-level keyword."""
        if hl_keywords is None:
            return None
        return [keyword.strip() for keyword in hl_keywords]

    @field_validator("ll_keywords", mode="after")
    @classmethod
    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
        """Strip surrounding whitespace from each low-level keyword."""
        if ll_keywords is None:
            return None
        return [keyword.strip() for keyword in ll_keywords]

    @field_validator("conversation_history", mode="after")
    @classmethod
    def conversation_history_role_check(
        cls, conversation_history: List[Dict[str, Any]] | None
    ) -> List[Dict[str, Any]] | None:
        """Require every message to carry a ``role`` of ``user`` or ``assistant``.

        Only the ``role`` key is checked; other keys (such as ``content``) are
        passed through unvalidated.

        Raises:
            ValueError: If any message lacks a ``role`` or has an unknown one.
        """
        if conversation_history is None:
            return None
        for msg in conversation_history:
            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
                raise ValueError(
                    "Each message must have a 'role' key with value 'user' or 'assistant'."
                )
        return conversation_history

    def to_query_params(self, is_stream: bool) -> "QueryParam":
        """Convert this request into a ``QueryParam``.

        Args:
            is_stream: Value assigned to ``QueryParam.stream``.

        Returns:
            A ``QueryParam`` built from every non-``None`` field except
            ``query`` (the query text is passed to ``rag.aquery`` separately).
        """
        # exclude_none=True leaves unset options out so QueryParam's own
        # defaults apply; `mode` is always included since it defaults to "hybrid".
        request_data = self.model_dump(exclude_none=True, exclude={"query"})
        param = QueryParam(**request_data)
        # `stream` is not a request field; the calling endpoint decides it.
        param.stream = is_stream
        return param
2025-02-20 04:12:21 +08:00
class QueryResponse(BaseModel):
    # Body returned by the non-streaming /query endpoint. The single field is a
    # required string; callers render non-string results to text before
    # building this model.
    response: str = Field(
        ...,  # no default: the field is required
        description="The generated response",
    )
2025-02-20 04:12:21 +08:00
def _stringify_response(response: Any) -> str:
    """Render a complete (non-streamed) ``rag.aquery`` result as text.

    Strings (e.g. cache hits) pass through unchanged, dicts are serialized as
    indented JSON, and anything else falls back to ``str()``.
    """
    if isinstance(response, str):
        return response
    if isinstance(response, dict):
        return json.dumps(response, indent=2)
    return str(response)


def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
    """Attach the query endpoints to the module-level ``router`` and return it.

    Args:
        rag: Object exposing an async ``aquery(query, param=...)`` method that
            both endpoints delegate to (presumably a ``LightRAG`` instance).
        api_key: Passed to ``get_api_key_dependency``; the resulting
            dependency guards both endpoints.
        top_k: NOTE(review): accepted but not used by this function — confirm
            whether it was meant to seed ``QueryParam.top_k``.

    Returns:
        The module-level ``router`` with ``POST /query`` and
        ``POST /query/stream`` attached. Because the router is shared, calling
        this function more than once registers the handlers again.
    """
    optional_api_key = get_api_key_dependency(api_key)

    @router.post(
        "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
    )
    async def query_text(request: QueryRequest):
        """
        Handle a POST request at the /query endpoint to process user queries using RAG capabilities.

        Parameters:
            request (QueryRequest): The request object containing the query parameters.

        Returns:
            QueryResponse: The query result as text. String results (e.g. cache
            hits) are returned as-is, dict results are serialized as indented
            JSON, and any other result is converted with str().

        Raises:
            HTTPException: Raised when an error occurs during the request handling process,
            with status code 500 and detail containing the exception message.
        """
        try:
            param = request.to_query_params(False)
            response = await rag.aquery(request.query, param=param)
            return QueryResponse(response=_stringify_response(response))
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @router.post("/query/stream", dependencies=[Depends(optional_api_key)])
    async def query_text_stream(request: QueryRequest):
        """
        This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.

        Args:
            request (QueryRequest): The request object containing the query parameters.

        Returns:
            StreamingResponse: An application/x-ndjson stream in which each line
            is a JSON object {"response": <text>}. If iteration fails part-way,
            a final {"error": <message>} line is emitted instead of raising.

        Raises:
            HTTPException: Status 500 if building the parameters or calling
            rag.aquery fails before streaming starts.
        """
        try:
            param = request.to_query_params(True)
            response = await rag.aquery(request.query, param=param)

            from fastapi.responses import StreamingResponse

            async def stream_generator():
                if not hasattr(response, "__aiter__"):
                    # A complete result (e.g. a cache hit returns a plain
                    # string): send it as one line, formatted like /query.
                    yield f"{json.dumps({'response': _stringify_response(response)})}\n"
                    return
                # An async iterator: forward chunks one by one as they arrive.
                try:
                    async for chunk in response:
                        if chunk:  # Only send non-empty content
                            yield f"{json.dumps({'response': chunk})}\n"
                except Exception as e:
                    # The response status has presumably already been sent once
                    # streaming has begun, so the failure is reported in-band.
                    logging.exception("Streaming error: %s", e)
                    yield f"{json.dumps({'error': str(e)})}\n"

            return StreamingResponse(
                stream_generator(),
                media_type="application/x-ndjson",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Content-Type": "application/x-ndjson",
                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
                },
            )
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    return router