mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-12-28 15:22:22 +00:00
Fix: rename rerank parameter from top_k to top_n
The change aligns with the API parameter naming used by Jina and Cohere rerank services, ensuring consistency and clarity.
This commit is contained in:
parent
4d8eda5ce3
commit
cb3bf3291c
@ -57,7 +57,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
)
|
||||
|
||||
|
||||
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
||||
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
||||
"""Custom rerank function with all settings included"""
|
||||
return await custom_rerank(
|
||||
query=query,
|
||||
@ -65,7 +65,7 @@ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwarg
|
||||
model="BAAI/bge-reranker-v2-m3",
|
||||
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
||||
api_key="your_rerank_api_key_here",
|
||||
top_k=top_k or 10, # Default top_k if not provided
|
||||
top_n=top_n or 10,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -217,7 +217,7 @@ async def test_direct_rerank():
|
||||
model="BAAI/bge-reranker-v2-m3",
|
||||
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
||||
api_key="your_rerank_api_key_here",
|
||||
top_k=3,
|
||||
top_n=3,
|
||||
)
|
||||
|
||||
print("\n✅ Rerank Results:")
|
||||
|
||||
@ -298,7 +298,7 @@ def create_app(args):
|
||||
from lightrag.rerank import custom_rerank
|
||||
|
||||
async def server_rerank_func(
|
||||
query: str, documents: list, top_k: int = None, **kwargs
|
||||
query: str, documents: list, top_n: int = None, **kwargs
|
||||
):
|
||||
"""Server rerank function with configuration from environment variables"""
|
||||
return await custom_rerank(
|
||||
@ -307,7 +307,7 @@ def create_app(args):
|
||||
model=args.rerank_model,
|
||||
base_url=args.rerank_binding_host,
|
||||
api_key=args.rerank_binding_api_key,
|
||||
top_k=top_k,
|
||||
top_n=top_n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -3165,7 +3165,7 @@ async def apply_rerank_if_enabled(
|
||||
retrieved_docs: list[dict],
|
||||
global_config: dict,
|
||||
enable_rerank: bool = True,
|
||||
top_k: int = None,
|
||||
top_n: int = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Apply reranking to retrieved documents if rerank is enabled.
|
||||
@ -3175,7 +3175,7 @@ async def apply_rerank_if_enabled(
|
||||
retrieved_docs: List of retrieved documents
|
||||
global_config: Global configuration containing rerank settings
|
||||
enable_rerank: Whether to enable reranking from query parameter
|
||||
top_k: Number of top documents to return after reranking
|
||||
top_n: Number of top documents to return after reranking
|
||||
|
||||
Returns:
|
||||
Reranked documents if rerank is enabled, otherwise original documents
|
||||
@ -3192,18 +3192,18 @@ async def apply_rerank_if_enabled(
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}"
|
||||
f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_n}"
|
||||
)
|
||||
|
||||
# Apply reranking - let rerank_model_func handle top_k internally
|
||||
reranked_docs = await rerank_func(
|
||||
query=query,
|
||||
documents=retrieved_docs,
|
||||
top_k=top_k,
|
||||
top_n=top_n,
|
||||
)
|
||||
if reranked_docs and len(reranked_docs) > 0:
|
||||
if len(reranked_docs) > top_k:
|
||||
reranked_docs = reranked_docs[:top_k]
|
||||
if len(reranked_docs) > top_n:
|
||||
reranked_docs = reranked_docs[:top_n]
|
||||
logger.info(
|
||||
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
||||
)
|
||||
@ -3263,7 +3263,7 @@ async def process_chunks_unified(
|
||||
retrieved_docs=unique_chunks,
|
||||
global_config=global_config,
|
||||
enable_rerank=query_param.enable_rerank,
|
||||
top_k=rerank_top_k,
|
||||
top_n=rerank_top_k,
|
||||
)
|
||||
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
||||
|
||||
|
||||
@ -41,13 +41,13 @@ class RerankModel(BaseModel):
|
||||
|
||||
Or define a custom function directly:
|
||||
```python
|
||||
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
||||
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
||||
return await jina_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
model="BAAI/bge-reranker-v2-m3",
|
||||
api_key="your_api_key_here",
|
||||
top_k=top_k or 10,
|
||||
top_n=top_n or 10,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@ -71,14 +71,14 @@ class RerankModel(BaseModel):
|
||||
self,
|
||||
query: str,
|
||||
documents: List[Dict[str, Any]],
|
||||
top_k: Optional[int] = None,
|
||||
top_n: Optional[int] = None,
|
||||
**extra_kwargs,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Rerank documents using the configured model function."""
|
||||
# Merge extra kwargs with model kwargs
|
||||
kwargs = {**self.kwargs, **extra_kwargs}
|
||||
return await self.rerank_func(
|
||||
query=query, documents=documents, top_k=top_k, **kwargs
|
||||
query=query, documents=documents, top_n=top_n, **kwargs
|
||||
)
|
||||
|
||||
|
||||
@ -98,7 +98,7 @@ class MultiRerankModel(BaseModel):
|
||||
query: str,
|
||||
documents: List[Dict[str, Any]],
|
||||
mode: str = "default",
|
||||
top_k: Optional[int] = None,
|
||||
top_n: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Rerank using the appropriate model based on mode."""
|
||||
@ -116,7 +116,7 @@ class MultiRerankModel(BaseModel):
|
||||
logger.warning(f"No rerank model available for mode: {mode}")
|
||||
return documents
|
||||
|
||||
return await model.rerank(query, documents, top_k, **kwargs)
|
||||
return await model.rerank(query, documents, top_n, **kwargs)
|
||||
|
||||
|
||||
async def generic_rerank_api(
|
||||
@ -125,7 +125,7 @@ async def generic_rerank_api(
|
||||
model: str,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
top_k: Optional[int] = None,
|
||||
top_n: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@ -137,7 +137,7 @@ async def generic_rerank_api(
|
||||
model: Model identifier
|
||||
base_url: API endpoint URL
|
||||
api_key: API authentication key
|
||||
top_k: Number of top results to return
|
||||
top_n: Number of top results to return
|
||||
**kwargs: Additional API-specific parameters
|
||||
|
||||
Returns:
|
||||
@ -165,8 +165,8 @@ async def generic_rerank_api(
|
||||
|
||||
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
|
||||
|
||||
if top_k is not None:
|
||||
data["top_k"] = min(top_k, len(prepared_docs))
|
||||
if top_n is not None:
|
||||
data["top_n"] = min(top_n, len(prepared_docs))
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@ -206,7 +206,7 @@ async def jina_rerank(
|
||||
query: str,
|
||||
documents: List[Dict[str, Any]],
|
||||
model: str = "BAAI/bge-reranker-v2-m3",
|
||||
top_k: Optional[int] = None,
|
||||
top_n: Optional[int] = None,
|
||||
base_url: str = "https://api.jina.ai/v1/rerank",
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs,
|
||||
@ -218,7 +218,7 @@ async def jina_rerank(
|
||||
query: The search query
|
||||
documents: List of documents to rerank
|
||||
model: Jina rerank model name
|
||||
top_k: Number of top results to return
|
||||
top_n: Number of top results to return
|
||||
base_url: Jina API endpoint
|
||||
api_key: Jina API key
|
||||
**kwargs: Additional parameters
|
||||
@ -235,7 +235,7 @@ async def jina_rerank(
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
top_k=top_k,
|
||||
top_n=top_n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -244,7 +244,7 @@ async def cohere_rerank(
|
||||
query: str,
|
||||
documents: List[Dict[str, Any]],
|
||||
model: str = "rerank-english-v2.0",
|
||||
top_k: Optional[int] = None,
|
||||
top_n: Optional[int] = None,
|
||||
base_url: str = "https://api.cohere.ai/v1/rerank",
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs,
|
||||
@ -256,7 +256,7 @@ async def cohere_rerank(
|
||||
query: The search query
|
||||
documents: List of documents to rerank
|
||||
model: Cohere rerank model name
|
||||
top_k: Number of top results to return
|
||||
top_n: Number of top results to return
|
||||
base_url: Cohere API endpoint
|
||||
api_key: Cohere API key
|
||||
**kwargs: Additional parameters
|
||||
@ -273,7 +273,7 @@ async def cohere_rerank(
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
top_k=top_k,
|
||||
top_n=top_n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -285,7 +285,7 @@ async def custom_rerank(
|
||||
model: str,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
top_k: Optional[int] = None,
|
||||
top_n: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@ -298,7 +298,7 @@ async def custom_rerank(
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
top_k=top_k,
|
||||
top_n=top_n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -317,7 +317,7 @@ if __name__ == "__main__":
|
||||
query = "What is the capital of France?"
|
||||
|
||||
result = await jina_rerank(
|
||||
query=query, documents=docs, top_k=2, api_key="your-api-key-here"
|
||||
query=query, documents=docs, top_n=2, api_key="your-api-key-here"
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user