mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-07-26 02:10:50 +00:00
307 lines
9.1 KiB
Python
307 lines
9.1 KiB
Python
![]() |
from __future__ import annotations
|
||
|
|
||
|
import os
|
||
|
import json
|
||
|
import aiohttp
|
||
|
import numpy as np
|
||
|
from typing import Callable, Any, List, Dict, Optional
|
||
|
from pydantic import BaseModel, Field
|
||
|
from dataclasses import asdict
|
||
|
|
||
|
from .utils import logger
|
||
|
|
||
|
|
||
|
class RerankModel(BaseModel):
|
||
|
"""
|
||
|
Pydantic model class for defining a custom rerank model.
|
||
|
|
||
|
Attributes:
|
||
|
rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
|
||
|
The function should take query and documents as input and return reranked results.
|
||
|
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
|
||
|
This could include parameters such as the model name, API key, etc.
|
||
|
|
||
|
Example usage:
|
||
|
Rerank model example from jina:
|
||
|
```python
|
||
|
rerank_model = RerankModel(
|
||
|
rerank_func=jina_rerank,
|
||
|
kwargs={
|
||
|
"model": "BAAI/bge-reranker-v2-m3",
|
||
|
"api_key": "your_api_key_here",
|
||
|
"base_url": "https://api.jina.ai/v1/rerank"
|
||
|
}
|
||
|
)
|
||
|
```
|
||
|
"""
|
||
|
|
||
|
rerank_func: Callable[[Any], List[Dict]]
|
||
|
kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||
|
|
||
|
async def rerank(
|
||
|
self,
|
||
|
query: str,
|
||
|
documents: List[Dict[str, Any]],
|
||
|
top_k: Optional[int] = None,
|
||
|
**extra_kwargs
|
||
|
) -> List[Dict[str, Any]]:
|
||
|
"""Rerank documents using the configured model function."""
|
||
|
# Merge extra kwargs with model kwargs
|
||
|
kwargs = {**self.kwargs, **extra_kwargs}
|
||
|
return await self.rerank_func(
|
||
|
query=query,
|
||
|
documents=documents,
|
||
|
top_k=top_k,
|
||
|
**kwargs
|
||
|
)
|
||
|
|
||
|
|
||
|
class MultiRerankModel(BaseModel):
|
||
|
"""Multiple rerank models for different modes/scenarios."""
|
||
|
|
||
|
# Primary rerank model (used if mode-specific models are not defined)
|
||
|
rerank_model: Optional[RerankModel] = None
|
||
|
|
||
|
# Mode-specific rerank models
|
||
|
entity_rerank_model: Optional[RerankModel] = None
|
||
|
relation_rerank_model: Optional[RerankModel] = None
|
||
|
chunk_rerank_model: Optional[RerankModel] = None
|
||
|
|
||
|
async def rerank(
|
||
|
self,
|
||
|
query: str,
|
||
|
documents: List[Dict[str, Any]],
|
||
|
mode: str = "default",
|
||
|
top_k: Optional[int] = None,
|
||
|
**kwargs
|
||
|
) -> List[Dict[str, Any]]:
|
||
|
"""Rerank using the appropriate model based on mode."""
|
||
|
|
||
|
# Select model based on mode
|
||
|
if mode == "entity" and self.entity_rerank_model:
|
||
|
model = self.entity_rerank_model
|
||
|
elif mode == "relation" and self.relation_rerank_model:
|
||
|
model = self.relation_rerank_model
|
||
|
elif mode == "chunk" and self.chunk_rerank_model:
|
||
|
model = self.chunk_rerank_model
|
||
|
elif self.rerank_model:
|
||
|
model = self.rerank_model
|
||
|
else:
|
||
|
logger.warning(f"No rerank model available for mode: {mode}")
|
||
|
return documents
|
||
|
|
||
|
return await model.rerank(query, documents, top_k, **kwargs)
|
||
|
|
||
|
|
||
|
async def generic_rerank_api(
|
||
|
query: str,
|
||
|
documents: List[Dict[str, Any]],
|
||
|
model: str,
|
||
|
base_url: str,
|
||
|
api_key: str,
|
||
|
top_k: Optional[int] = None,
|
||
|
**kwargs
|
||
|
) -> List[Dict[str, Any]]:
|
||
|
"""
|
||
|
Generic rerank function that works with Jina/Cohere compatible APIs.
|
||
|
|
||
|
Args:
|
||
|
query: The search query
|
||
|
documents: List of documents to rerank
|
||
|
model: Model identifier
|
||
|
base_url: API endpoint URL
|
||
|
api_key: API authentication key
|
||
|
top_k: Number of top results to return
|
||
|
**kwargs: Additional API-specific parameters
|
||
|
|
||
|
Returns:
|
||
|
List of reranked documents with relevance scores
|
||
|
"""
|
||
|
if not api_key:
|
||
|
logger.warning("No API key provided for rerank service")
|
||
|
return documents
|
||
|
|
||
|
if not documents:
|
||
|
return documents
|
||
|
|
||
|
# Prepare documents for reranking - handle both text and dict formats
|
||
|
prepared_docs = []
|
||
|
for doc in documents:
|
||
|
if isinstance(doc, dict):
|
||
|
# Use 'content' field if available, otherwise use 'text' or convert to string
|
||
|
text = doc.get('content') or doc.get('text') or str(doc)
|
||
|
else:
|
||
|
text = str(doc)
|
||
|
prepared_docs.append(text)
|
||
|
|
||
|
# Prepare request
|
||
|
headers = {
|
||
|
"Content-Type": "application/json",
|
||
|
"Authorization": f"Bearer {api_key}"
|
||
|
}
|
||
|
|
||
|
data = {
|
||
|
"model": model,
|
||
|
"query": query,
|
||
|
"documents": prepared_docs,
|
||
|
**kwargs
|
||
|
}
|
||
|
|
||
|
if top_k is not None:
|
||
|
data["top_k"] = min(top_k, len(prepared_docs))
|
||
|
|
||
|
try:
|
||
|
async with aiohttp.ClientSession() as session:
|
||
|
async with session.post(base_url, headers=headers, json=data) as response:
|
||
|
if response.status != 200:
|
||
|
error_text = await response.text()
|
||
|
logger.error(f"Rerank API error {response.status}: {error_text}")
|
||
|
return documents
|
||
|
|
||
|
result = await response.json()
|
||
|
|
||
|
# Extract reranked results
|
||
|
if "results" in result:
|
||
|
# Standard format: results contain index and relevance_score
|
||
|
reranked_docs = []
|
||
|
for item in result["results"]:
|
||
|
if "index" in item:
|
||
|
doc_idx = item["index"]
|
||
|
if 0 <= doc_idx < len(documents):
|
||
|
reranked_doc = documents[doc_idx].copy()
|
||
|
if "relevance_score" in item:
|
||
|
reranked_doc["rerank_score"] = item["relevance_score"]
|
||
|
reranked_docs.append(reranked_doc)
|
||
|
return reranked_docs
|
||
|
else:
|
||
|
logger.warning("Unexpected rerank API response format")
|
||
|
return documents
|
||
|
|
||
|
except Exception as e:
|
||
|
logger.error(f"Error during reranking: {e}")
|
||
|
return documents
|
||
|
|
||
|
|
||
|
async def jina_rerank(
|
||
|
query: str,
|
||
|
documents: List[Dict[str, Any]],
|
||
|
model: str = "BAAI/bge-reranker-v2-m3",
|
||
|
top_k: Optional[int] = None,
|
||
|
base_url: str = "https://api.jina.ai/v1/rerank",
|
||
|
api_key: Optional[str] = None,
|
||
|
**kwargs
|
||
|
) -> List[Dict[str, Any]]:
|
||
|
"""
|
||
|
Rerank documents using Jina AI API.
|
||
|
|
||
|
Args:
|
||
|
query: The search query
|
||
|
documents: List of documents to rerank
|
||
|
model: Jina rerank model name
|
||
|
top_k: Number of top results to return
|
||
|
base_url: Jina API endpoint
|
||
|
api_key: Jina API key
|
||
|
**kwargs: Additional parameters
|
||
|
|
||
|
Returns:
|
||
|
List of reranked documents with relevance scores
|
||
|
"""
|
||
|
if api_key is None:
|
||
|
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
|
||
|
|
||
|
return await generic_rerank_api(
|
||
|
query=query,
|
||
|
documents=documents,
|
||
|
model=model,
|
||
|
base_url=base_url,
|
||
|
api_key=api_key,
|
||
|
top_k=top_k,
|
||
|
**kwargs
|
||
|
)
|
||
|
|
||
|
|
||
|
async def cohere_rerank(
|
||
|
query: str,
|
||
|
documents: List[Dict[str, Any]],
|
||
|
model: str = "rerank-english-v2.0",
|
||
|
top_k: Optional[int] = None,
|
||
|
base_url: str = "https://api.cohere.ai/v1/rerank",
|
||
|
api_key: Optional[str] = None,
|
||
|
**kwargs
|
||
|
) -> List[Dict[str, Any]]:
|
||
|
"""
|
||
|
Rerank documents using Cohere API.
|
||
|
|
||
|
Args:
|
||
|
query: The search query
|
||
|
documents: List of documents to rerank
|
||
|
model: Cohere rerank model name
|
||
|
top_k: Number of top results to return
|
||
|
base_url: Cohere API endpoint
|
||
|
api_key: Cohere API key
|
||
|
**kwargs: Additional parameters
|
||
|
|
||
|
Returns:
|
||
|
List of reranked documents with relevance scores
|
||
|
"""
|
||
|
if api_key is None:
|
||
|
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
|
||
|
|
||
|
return await generic_rerank_api(
|
||
|
query=query,
|
||
|
documents=documents,
|
||
|
model=model,
|
||
|
base_url=base_url,
|
||
|
api_key=api_key,
|
||
|
top_k=top_k,
|
||
|
**kwargs
|
||
|
)
|
||
|
|
||
|
|
||
|
# Convenience function for custom API endpoints
|
||
|
async def custom_rerank(
|
||
|
query: str,
|
||
|
documents: List[Dict[str, Any]],
|
||
|
model: str,
|
||
|
base_url: str,
|
||
|
api_key: str,
|
||
|
top_k: Optional[int] = None,
|
||
|
**kwargs
|
||
|
) -> List[Dict[str, Any]]:
|
||
|
"""
|
||
|
Rerank documents using a custom API endpoint.
|
||
|
This is useful for self-hosted or custom rerank services.
|
||
|
"""
|
||
|
return await generic_rerank_api(
|
||
|
query=query,
|
||
|
documents=documents,
|
||
|
model=model,
|
||
|
base_url=base_url,
|
||
|
api_key=api_key,
|
||
|
top_k=top_k,
|
||
|
**kwargs
|
||
|
)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
import asyncio
|
||
|
|
||
|
async def main():
|
||
|
# Example usage
|
||
|
docs = [
|
||
|
{"content": "The capital of France is Paris."},
|
||
|
{"content": "Tokyo is the capital of Japan."},
|
||
|
{"content": "London is the capital of England."},
|
||
|
]
|
||
|
|
||
|
query = "What is the capital of France?"
|
||
|
|
||
|
result = await jina_rerank(
|
||
|
query=query,
|
||
|
documents=docs,
|
||
|
top_k=2,
|
||
|
api_key="your-api-key-here"
|
||
|
)
|
||
|
print(result)
|
||
|
|
||
|
asyncio.run(main())
|