mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-07-23 17:00:50 +00:00
325 lines
9.4 KiB
Python
325 lines
9.4 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import aiohttp
|
|
from typing import Callable, Any, List, Dict, Optional
|
|
from pydantic import BaseModel, Field
|
|
|
|
from .utils import logger
|
|
|
|
|
|
class RerankModel(BaseModel):
    """
    Wrapper for rerank functions that can be used with LightRAG.

    The wrapped ``rerank_func`` is awaited with ``query``, ``documents`` and
    ``top_k`` as keyword arguments, plus every entry of ``kwargs`` (call-time
    extra keyword arguments override the configured ones).

    Example usage:
    ```python
    from lightrag.rerank import RerankModel, jina_rerank

    # Create rerank model
    rerank_model = RerankModel(
        rerank_func=jina_rerank,
        kwargs={
            "model": "BAAI/bge-reranker-v2-m3",
            "api_key": "your_api_key_here",
            "base_url": "https://api.jina.ai/v1/rerank",
        },
    )

    # Use in LightRAG
    rag = LightRAG(
        rerank_model_func=rerank_model.rerank,
        # ... other configurations
    )

    # Query with rerank enabled (default)
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True),
    )
    ```

    Or define a custom function directly:
    ```python
    async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
        return await jina_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            api_key="your_api_key_here",
            top_k=top_k or 10,
            **kwargs,
        )

    rag = LightRAG(
        rerank_model_func=my_rerank_func,
        # ... other configurations
    )

    # Control rerank per query
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True),  # Enable rerank for this query
    )
    ```
    """

    # Async callable awaited as
    # ``rerank_func(query=..., documents=..., top_k=..., **kwargs)``.
    rerank_func: Callable[..., Any]
    # Default keyword arguments forwarded to ``rerank_func`` on every call
    # (for example model name, api_key, base_url).
    kwargs: Dict[str, Any] = Field(default_factory=dict)

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        top_k: Optional[int] = None,
        **extra_kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank documents using the configured model function.

        Args:
            query: The search query.
            documents: Documents to rerank.
            top_k: Number of results requested. When ``None``, a ``top_k``
                stored in the configured ``kwargs`` (if any) is used instead.
            **extra_kwargs: Per-call keyword arguments; they override entries
                of the configured ``kwargs`` with the same name.

        Returns:
            Whatever ``rerank_func`` returns (the reranked documents).
        """
        # Merge extra kwargs with model kwargs (call-time values win).
        kwargs = {**self.kwargs, **extra_kwargs}
        # A ``top_k`` kept in the configured kwargs would otherwise collide
        # with the explicit ``top_k=`` keyword below and raise TypeError.
        configured_top_k = kwargs.pop("top_k", None)
        if top_k is None:
            top_k = configured_top_k
        return await self.rerank_func(
            query=query, documents=documents, top_k=top_k, **kwargs
        )
|
|
|
|
|
|
class MultiRerankModel(BaseModel):
    """Route rerank requests to a per-mode model, falling back to a primary one."""

    # Fallback model, used whenever the requested mode has no dedicated model.
    rerank_model: Optional[RerankModel] = None

    # Optional models dedicated to one retrieval mode each.
    entity_rerank_model: Optional[RerankModel] = None
    relation_rerank_model: Optional[RerankModel] = None
    chunk_rerank_model: Optional[RerankModel] = None

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        mode: str = "default",
        top_k: Optional[int] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Dispatch the rerank call to the model registered for ``mode``.

        ``mode`` values ``"entity"``, ``"relation"`` and ``"chunk"`` select the
        matching dedicated model when it is set; any other mode, or an unset
        dedicated model, falls back to ``rerank_model``. When no usable model
        exists, a warning is logged and ``documents`` is returned unchanged.
        """
        dedicated = (
            ("entity", self.entity_rerank_model),
            ("relation", self.relation_rerank_model),
            ("chunk", self.chunk_rerank_model),
        )
        chosen = next(
            (
                candidate
                for label, candidate in dedicated
                if mode == label and candidate
            ),
            self.rerank_model,
        )

        if not chosen:
            logger.warning(f"No rerank model available for mode: {mode}")
            return documents

        return await chosen.rerank(query, documents, top_k, **kwargs)
|
|
|
|
|
|
def _extract_rerank_text(doc: Any) -> str:
    """Return the text sent to the rerank API for a single document."""
    if isinstance(doc, dict):
        # Prefer 'content', then 'text'; otherwise fall back to the dict's str().
        return doc.get("content") or doc.get("text") or str(doc)
    return str(doc)


def _apply_rerank_results(
    documents: List[Any],
    results: List[Any],
    limit: Optional[int],
) -> List[Any]:
    """Reorder ``documents`` following the API ``results`` (index + relevance_score)."""
    reranked_docs: List[Any] = []
    for item in results:
        if not isinstance(item, dict) or "index" not in item:
            continue
        doc_idx = item["index"]
        # Ignore malformed or out-of-range indices instead of failing the call.
        if not isinstance(doc_idx, int) or not 0 <= doc_idx < len(documents):
            continue
        source = documents[doc_idx]
        if isinstance(source, dict):
            reranked_doc = source.copy()
            if "relevance_score" in item:
                reranked_doc["rerank_score"] = item["relevance_score"]
        else:
            # Plain-text documents cannot carry a score field; keep them as-is
            # so they are still returned in reranked order.
            reranked_doc = source
        reranked_docs.append(reranked_doc)

    # Enforce the limit client-side too, in case the endpoint ignores top_n.
    if limit is not None:
        reranked_docs = reranked_docs[:limit]
    return reranked_docs


async def generic_rerank_api(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_k: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Generic rerank function that works with Jina/Cohere compatible APIs.

    The request body is ``{"model", "query", "documents", **kwargs}``. When
    ``top_k`` is given, it is also sent as ``top_n``, the field name used by
    the Jina and Cohere rerank APIs. Documents may be dicts (their ``content``
    or ``text`` field is sent) or plain values (sent via ``str()``).

    This is best-effort. A missing API key, a non-200 response, an unexpected
    response shape, or any exception during the request all log a message
    and return ``documents`` unchanged.

    Args:
        query: The search query.
        documents: List of documents to rerank.
        model: Model identifier.
        base_url: API endpoint URL.
        api_key: API authentication key.
        top_k: Maximum number of results to return. ``None`` or a
            non-positive value means no limit.
        **kwargs: Additional API-specific parameters merged into the request body.

    Returns:
        Reranked documents, truncated to ``top_k``. Dict documents are shallow
        copies with a ``rerank_score`` key added when the API reports
        ``relevance_score``. Non-dict documents are returned as-is.
    """
    if not documents:
        return documents

    if not api_key:
        logger.warning("No API key provided for rerank service")
        return documents

    # Prepare documents for reranking - handle both text and dict formats.
    prepared_docs = [_extract_rerank_text(doc) for doc in documents]
    limit = top_k if top_k is not None and top_k > 0 else None

    # Prepare request
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
    if limit is not None:
        # Jina and Cohere both call this field ``top_n``; ``top_k`` is ignored by them.
        data["top_n"] = min(limit, len(prepared_docs))

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(base_url, headers=headers, json=data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Rerank API error {response.status}: {error_text}")
                    return documents

                result = await response.json()

        # Standard format: {"results": [{"index": ..., "relevance_score": ...}, ...]}
        if not isinstance(result, dict) or not isinstance(result.get("results"), list):
            logger.warning("Unexpected rerank API response format")
            return documents

        return _apply_rerank_results(documents, result["results"], limit)

    except Exception as e:
        # Deliberate best-effort: reranking failures must not break retrieval.
        logger.error(f"Error during reranking: {e}")
        return documents
|
|
|
|
|
|
async def jina_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "BAAI/bge-reranker-v2-m3",
    top_k: Optional[int] = None,
    base_url: str = "https://api.jina.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents through the Jina AI rerank endpoint.

    This is a thin wrapper around ``generic_rerank_api`` with Jina defaults.
    When ``api_key`` is ``None``, the key is read from the ``JINA_API_KEY``
    environment variable, falling back to ``RERANK_API_KEY``.

    NOTE(review): the default model id looks like a Hugging Face/SiliconFlow
    id rather than a ``jina-reranker-*`` id. Confirm that api.jina.ai
    accepts it.

    Args:
        query: The search query.
        documents: Documents to rerank.
        model: Rerank model name sent to the API.
        top_k: Number of top results to return.
        base_url: Jina API endpoint.
        api_key: Jina API key (``None`` means read it from the environment).
        **kwargs: Extra request parameters passed through unchanged.

    Returns:
        The reranked documents, as produced by ``generic_rerank_api``.
    """
    resolved_key = (
        api_key
        if api_key is not None
        else (os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY"))
    )

    return await generic_rerank_api(
        query=query,
        documents=documents,
        top_k=top_k,
        model=model,
        api_key=resolved_key,
        base_url=base_url,
        **kwargs,
    )
|
|
|
|
|
|
async def cohere_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "rerank-english-v2.0",
    top_k: Optional[int] = None,
    base_url: str = "https://api.cohere.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents through the Cohere rerank endpoint.

    This is a thin wrapper around ``generic_rerank_api`` with Cohere defaults.
    When ``api_key`` is ``None``, the key is read from the ``COHERE_API_KEY``
    environment variable, falling back to ``RERANK_API_KEY``.

    Args:
        query: The search query.
        documents: Documents to rerank.
        model: Cohere rerank model name.
        top_k: Number of top results to return.
        base_url: Cohere API endpoint.
        api_key: Cohere API key (``None`` means read it from the environment).
        **kwargs: Extra request parameters passed through unchanged.

    Returns:
        The reranked documents, as produced by ``generic_rerank_api``.
    """
    resolved_key = (
        api_key
        if api_key is not None
        else (os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY"))
    )

    return await generic_rerank_api(
        query=query,
        documents=documents,
        top_k=top_k,
        model=model,
        api_key=resolved_key,
        base_url=base_url,
        **kwargs,
    )
|
|
|
|
|
|
# Convenience function for custom API endpoints
|
|
async def custom_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_k: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents against an arbitrary Jina/Cohere-compatible endpoint.

    Useful for self-hosted or custom rerank services. Every argument is
    forwarded unchanged to ``generic_rerank_api``. No environment-variable
    fallback is applied to ``api_key``.
    """
    forwarded = dict(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_k=top_k,
    )
    return await generic_rerank_api(**forwarded, **kwargs)
|
|
|
|
|
|
if __name__ == "__main__":
    import asyncio

    async def main():
        """Demo: rerank three capital-city sentences against a France query."""
        sentences = (
            "The capital of France is Paris.",
            "Tokyo is the capital of Japan.",
            "London is the capital of England.",
        )
        docs = [{"content": sentence} for sentence in sentences]

        reranked = await jina_rerank(
            query="What is the capital of France?",
            documents=docs,
            top_k=2,
            api_key="your-api-key-here",
        )
        print(reranked)

    asyncio.run(main())
|