LightRAG/lightrag/rerank.py
2025-07-15 12:17:27 +08:00

325 lines
9.4 KiB
Python

from __future__ import annotations

import os
from typing import Any, Awaitable, Callable, Dict, List, Optional

import aiohttp
from pydantic import BaseModel, Field

from .utils import logger
class RerankModel(BaseModel):
    """
    Wrapper for rerank functions that can be used with LightRAG.

    ``rerank_func`` is an async callable. It is invoked with the keyword
    arguments ``query``, ``documents`` and ``top_k``, plus everything stored in
    ``kwargs``. Per-call keyword arguments override stored ones.

    Example usage:

    ```python
    from lightrag.rerank import RerankModel, jina_rerank

    # Create rerank model
    rerank_model = RerankModel(
        rerank_func=jina_rerank,
        kwargs={
            "model": "BAAI/bge-reranker-v2-m3",
            "api_key": "your_api_key_here",
            "base_url": "https://api.jina.ai/v1/rerank"
        }
    )

    # Use in LightRAG
    rag = LightRAG(
        rerank_model_func=rerank_model.rerank,
        # ... other configurations
    )

    # Query with rerank enabled (default)
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True)
    )
    ```

    Or define a custom function directly:

    ```python
    async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
        return await jina_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            api_key="your_api_key_here",
            top_k=top_k or 10,
            **kwargs
        )

    rag = LightRAG(
        rerank_model_func=my_rerank_func,
        # ... other configurations
    )

    # Control rerank per query
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True)  # Enable rerank for this query
    )
    ```
    """

    # Async callable, called as rerank_func(query=..., documents=..., top_k=..., **kwargs)
    # and awaited; its result is returned from ``rerank`` as-is.
    rerank_func: Callable[..., Awaitable[List[Dict[str, Any]]]]
    # Default keyword arguments (e.g. model, api_key, base_url) forwarded to
    # rerank_func on every call. A "top_k" entry acts as the default result limit.
    kwargs: Dict[str, Any] = Field(default_factory=dict)

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        top_k: Optional[int] = None,
        **extra_kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank documents using the configured model function.

        Args:
            query: The search query.
            documents: Documents to rerank.
            top_k: Maximum number of results. When None, a ``top_k`` entry in
                the stored ``kwargs`` (if any) is used instead.
            **extra_kwargs: Per-call keyword arguments; they override entries
                of ``self.kwargs`` with the same name.

        Returns:
            Whatever ``rerank_func`` returns (the reranked documents).
        """
        # Merge extra kwargs with model kwargs (per-call values win).
        kwargs = {**self.kwargs, **extra_kwargs}
        # top_k is passed explicitly below, so a "top_k" key left in kwargs would
        # raise "got multiple values for keyword argument"; use it as a fallback.
        configured_top_k = kwargs.pop("top_k", None)
        if top_k is None:
            top_k = configured_top_k
        return await self.rerank_func(
            query=query, documents=documents, top_k=top_k, **kwargs
        )
class MultiRerankModel(BaseModel):
    """Multiple rerank models for different modes/scenarios.

    For the modes "entity", "relation" and "chunk", the matching mode-specific
    model is used when it is configured. In every other case ``rerank_model``
    acts as the fallback. When no suitable model exists, a warning is logged
    and the documents are returned unchanged.
    """

    # Fallback model, used when no mode-specific model applies.
    rerank_model: Optional[RerankModel] = None
    # Mode-specific models, chosen via the ``mode`` argument of ``rerank``.
    entity_rerank_model: Optional[RerankModel] = None
    relation_rerank_model: Optional[RerankModel] = None
    chunk_rerank_model: Optional[RerankModel] = None

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        mode: str = "default",
        top_k: Optional[int] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank ``documents`` with the model registered for ``mode``.

        Args:
            query: The search query.
            documents: Documents to rerank.
            mode: "entity", "relation", "chunk", or anything else to use the
                fallback ``rerank_model`` directly.
            top_k: Maximum number of results, forwarded to the chosen model.
            **kwargs: Extra keyword arguments forwarded to the chosen model.

        Returns:
            The chosen model's result, or ``documents`` itself when no model
            is available.
        """
        candidates = (
            ("entity", self.entity_rerank_model),
            ("relation", self.relation_rerank_model),
            ("chunk", self.chunk_rerank_model),
        )
        # First configured model whose mode matches; otherwise the fallback.
        selected = next(
            (candidate for name, candidate in candidates if mode == name and candidate),
            self.rerank_model,
        )
        if not selected:
            logger.warning(f"No rerank model available for mode: {mode}")
            return documents
        return await selected.rerank(query, documents, top_k, **kwargs)
def _rerank_document_text(doc: Any) -> str:
    """Return the text submitted to the rerank API for one document.

    Dict documents contribute their ``content`` value, then ``text``, and
    finally ``str(doc)`` when both are missing or empty; any other value is
    converted with ``str()``.
    """
    if isinstance(doc, dict):
        return doc.get("content") or doc.get("text") or str(doc)
    return str(doc)


def _apply_rerank_results(
    results: Any,
    documents: List[Any],
    limit: Optional[int],
) -> List[Any]:
    """Reorder ``documents`` following the API's ``results`` list.

    Each result item is expected to be a dict with an integer ``index`` into
    ``documents`` and, optionally, a ``relevance_score``. Items without a
    usable in-range index are skipped. Dict documents are shallow-copied and
    annotated with ``rerank_score``. Other documents are passed through
    unchanged because they cannot hold a score. The output is truncated to
    ``limit`` items when ``limit`` is not None.
    """
    reranked_docs: List[Any] = []
    for item in results:
        if not isinstance(item, dict):
            continue
        doc_idx = item.get("index")
        if not isinstance(doc_idx, int) or not 0 <= doc_idx < len(documents):
            continue
        doc = documents[doc_idx]
        if isinstance(doc, dict):
            doc = doc.copy()
            if "relevance_score" in item:
                doc["rerank_score"] = item["relevance_score"]
        reranked_docs.append(doc)
    return reranked_docs if limit is None else reranked_docs[:limit]


async def generic_rerank_api(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_k: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Generic rerank function that works with Jina/Cohere compatible APIs.

    The request body is ``{"model", "query", "documents", "top_n", **kwargs}``.
    The JSON response is expected to contain a ``results`` list of
    ``{"index": int, "relevance_score": float}`` items.

    Args:
        query: The search query
        documents: List of documents to rerank. Dicts are sent using their
            ``content`` (or ``text``) field; other values are sent via ``str()``.
        model: Model identifier
        base_url: API endpoint URL
        api_key: API authentication key (sent as a Bearer token)
        top_k: Number of top results to return; ``None`` or a non-positive
            value means no limit
        **kwargs: Additional API-specific parameters, merged into the request body

    Returns:
        List of reranked documents. Dict documents are copies that carry a
        ``rerank_score`` key when the API reports a relevance score. If the API
        key is missing, the input is empty, the HTTP status is not 200, the
        response format is unexpected, or any exception occurs, the input
        ``documents`` list is returned unchanged.
    """
    if not api_key:
        logger.warning("No API key provided for rerank service")
        return documents
    if not documents:
        return documents

    # Prepare documents for reranking - handle both text and dict formats.
    prepared_docs = [_rerank_document_text(doc) for doc in documents]
    # Treat None and non-positive values as "no limit".
    limit = top_k if top_k is not None and top_k > 0 else None

    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
    if limit is not None:
        # The Jina and Cohere rerank APIs name this request field "top_n", not "top_k".
        data["top_n"] = min(limit, len(prepared_docs))

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(base_url, headers=headers, json=data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Rerank API error {response.status}: {error_text}")
                    return documents
                result = await response.json()

        if not isinstance(result, dict) or "results" not in result:
            logger.warning("Unexpected rerank API response format")
            return documents
        # The limit is also enforced client-side in case the endpoint ignores top_n.
        return _apply_rerank_results(result["results"], documents, limit)
    except Exception as e:
        # Reranking is best-effort: any failure falls back to the original order.
        logger.error(f"Error during reranking: {e}")
        return documents
async def jina_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "BAAI/bge-reranker-v2-m3",
    top_k: Optional[int] = None,
    base_url: str = "https://api.jina.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents through the Jina AI rerank endpoint.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Rerank model name sent to the endpoint
        top_k: Number of top results to return
        base_url: Jina API endpoint
        api_key: Jina API key. When None, the JINA_API_KEY environment
            variable is used, falling back to RERANK_API_KEY.
        **kwargs: Additional request parameters forwarded to generic_rerank_api

    Returns:
        The result of generic_rerank_api: reranked documents with relevance
        scores, or the input documents when reranking could not be performed.
    """
    # NOTE(review): the default model id is a BAAI name rather than a Jina
    # model id — confirm the Jina endpoint accepts it.
    resolved_key = (
        api_key
        if api_key is not None
        else (os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY"))
    )
    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=resolved_key,
        top_k=top_k,
        **kwargs,
    )
async def cohere_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "rerank-english-v2.0",
    top_k: Optional[int] = None,
    base_url: str = "https://api.cohere.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents through the Cohere rerank endpoint.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Cohere rerank model name
        top_k: Number of top results to return
        base_url: Cohere API endpoint
        api_key: Cohere API key. When None, the COHERE_API_KEY environment
            variable is used, falling back to RERANK_API_KEY.
        **kwargs: Additional request parameters forwarded to generic_rerank_api

    Returns:
        The result of generic_rerank_api: reranked documents with relevance
        scores, or the input documents when reranking could not be performed.
    """
    resolved_key = (
        api_key
        if api_key is not None
        else (os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY"))
    )
    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=resolved_key,
        top_k=top_k,
        **kwargs,
    )
# Convenience wrapper for custom (e.g. self-hosted) rerank endpoints.
async def custom_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_k: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using a custom Jina/Cohere-compatible API endpoint.

    This is useful for self-hosted or custom rerank services. Every argument
    is forwarded unchanged to generic_rerank_api. Unlike jina_rerank and
    cohere_rerank, no environment-variable fallback is applied to ``api_key``.
    """
    reranked = await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_k=top_k,
        **kwargs,
    )
    return reranked
if __name__ == "__main__":
    import asyncio

    async def main():
        """Demo: rerank three sample sentences for a sample question and print them."""
        sample_docs = [
            {"content": "The capital of France is Paris."},
            {"content": "Tokyo is the capital of Japan."},
            {"content": "London is the capital of England."},
        ]
        # Placeholder key; replace it with a real Jina key to get reranked output.
        ranked = await jina_rerank(
            query="What is the capital of France?",
            documents=sample_docs,
            top_k=2,
            api_key="your-api-key-here",
        )
        print(ranked)

    asyncio.run(main())