ragflow/rag/llm/rerank_model.py
Stephen Hu 0ecccd27eb
Refactor: improve the logic rerank models use to calculate the total token count (#10882)
### What problem does this PR solve?

Improve the logic rerank models use to calculate the total token count.

### Type of change

- [x] Refactoring
2025-10-31 09:46:16 +08:00

492 lines
16 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from abc import ABC
from urllib.parse import urljoin
import httpx
import numpy as np
import requests
from yarl import URL
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC):
    """Common interface shared by the rerank model wrappers in this module."""

    def __init__(self, key, model_name, **kwargs):
        """
        Abstract base class constructor.

        Parameters are not stored; initialization is left to subclasses.
        """
        pass

    def similarity(self, query: str, texts: list):
        """Score every entry of ``texts`` against ``query``.

        Subclasses must override this. The implementations in this module
        return a ``(scores, token_count)`` tuple where ``scores`` holds one
        value per input text.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("Please implement similarity method!")
class JinaRerank(Base):
    """Reranker for the Jina rerank HTTP API.

    Subclasses in this module (e.g. NovitaRerank, GiteeRerank) reuse this
    request/response handling and pass their own endpoint as ``base_url``.
    """

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
        # Honour the caller-supplied endpoint so subclasses reach their own
        # provider; fall back to Jina's public endpoint when it is empty/None.
        self.base_url = base_url or "https://api.jina.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Rerank ``texts`` against ``query``.

        Each text is passed through ``truncate(t, 8196)`` before sending.

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the
            ``relevance_score`` the API reported for ``texts[i]``, or 0.0 if it
            was not reported. ``token_count`` is extracted from the response by
            ``total_token_count_from_response``.
        """
        # Nothing to rank: skip the network round-trip.
        if not texts:
            return np.array([]), 0
        texts = [truncate(t, 8196) for t in texts]
        data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, total_token_count_from_response(res)
class XInferenceRerank(Base):
    """Reranker for an Xinference server's rerank endpoint."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key="x", model_name="", base_url=""):
        # Make sure the URL targets a "/v1/rerank" endpoint.
        if base_url.find("/v1") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
        if base_url.find("/rerank") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "accept": "application/json"}
        # Only send an Authorization header for a real key (not empty, not the default "x").
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def similarity(self, query: str, texts: list):
        """Rerank ``texts`` against ``query``.

        Each text is truncated with ``truncate(t, 4096)``. The truncated text
        is both what is sent and what is counted, so the reported token count
        matches the request.

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the
            ``relevance_score`` for ``texts[i]`` (0.0 if not reported).
            ``token_count`` sums ``num_tokens_from_string`` over the truncated
            documents.
        """
        if not texts:
            return np.array([]), 0
        documents = [truncate(t, 4096) for t in texts]
        token_count = sum(num_tokens_from_string(t) for t in documents)
        data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": documents}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(documents), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
class LocalAIRerank(Base):
    """Reranker for a LocalAI server's rerank endpoint."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        # Append "/rerank" unless the caller already supplied a rerank URL.
        # NOTE(review): urljoin with an absolute path replaces any path already
        # present on base_url (e.g. "http://h/v1" -> "http://h/rerank").
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        # Keep only the part of the model name before any "___" separator.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """Rerank ``texts`` against ``query`` and min-max normalize the scores.

        Returns:
            tuple: ``(scores, token_count)``. ``scores`` are scaled into [0, 1].
            When all scores are within 1e-3 of each other they all become 0.
            ``token_count`` sums ``num_tokens_from_string`` over the truncated
            texts.
        """
        # Nothing to rank. This also avoids np.min/np.max on a zero-size array
        # below, which raises ValueError.
        if not texts:
            return np.array([]), 0
        # The truncation length cannot be configured from RAGFlow, so a fixed limit is used.
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        # Normalize the rank values to the range 0 to 1
        min_rank = np.min(rank)
        max_rank = np.max(rank)
        # Avoid division by zero if all ranks are identical
        if not np.isclose(min_rank, max_rank, atol=1e-3):
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)
        return rank, token_count
class NvidiaRerank(Base):
    """Reranker for NVIDIA's hosted retrieval reranking API."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name
        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
        elif self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = urljoin(base_url, "reranking")
            self.model_name = "nv-rerank-qa-mistral-4b:1"
        else:
            # Without this branch self.base_url stayed unset for any other
            # model, and similarity() failed with AttributeError.
            # NOTE(review): derives "<model-slug>/reranking" like the v3 branch
            # above (vendor prefix dropped). It also assumes NVIDIA's URL slug
            # replaces "." with "_" (e.g. llama-3.2 -> llama-3_2). Confirm the
            # slug for each newly supported model.
            slug = model_name.split("/", 1)[-1].replace(".", "_")
            self.base_url = urljoin(base_url, f"{slug}/reranking")
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """Rerank ``texts`` against ``query``.

        The request asks the server to truncate over-long passages ("END").

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the raw
            ``logit`` reported for ``texts[i]`` (0.0 if not reported; not
            normalized). ``token_count`` counts the query plus all texts.
        """
        token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
        data = {
            "model": self.model_name,
            "query": {"text": query},
            "passages": [{"text": text} for text in texts],
            "truncate": "END",
            "top_n": len(texts),
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["rankings"]:
                rank[d["index"]] = d["logit"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
class LmStudioRerank(Base):
    """Placeholder for LM-Studio reranking, which is not supported yet."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        """Accept the standard constructor arguments and store nothing."""

    def similarity(self, query: str, texts: list):
        """Always raise NotImplementedError: LM-Studio reranking is unavailable."""
        raise NotImplementedError("The LmStudioRerank has not been implement")
class OpenAI_APIRerank(Base):
    """Reranker for generic OpenAI-API-compatible servers exposing a rerank endpoint."""

    _FACTORY_NAME = "OpenAI-API-Compatible"

    def __init__(self, key, model_name, base_url):
        # Append "/rerank" unless the caller already supplied a rerank URL.
        # NOTE(review): urljoin with an absolute path replaces any path already
        # present on base_url (e.g. "http://h/v1" -> "http://h/rerank").
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        # Keep only the part of the model name before any "___" separator.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """Rerank ``texts`` against ``query`` and min-max normalize the scores.

        Returns:
            tuple: ``(scores, token_count)``. ``scores`` are scaled into [0, 1].
            When all scores are within 1e-3 of each other they all become 0.
            ``token_count`` sums ``num_tokens_from_string`` over the truncated
            texts.
        """
        # Nothing to rank. This also avoids np.min/np.max on a zero-size array
        # below, which raises ValueError.
        if not texts:
            return np.array([]), 0
        # The truncation length cannot be configured from RAGFlow, so a fixed limit is used.
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        # Normalize the rank values to the range 0 to 1
        min_rank = np.min(rank)
        max_rank = np.max(rank)
        # Avoid division by zero if all ranks are identical
        if not np.isclose(min_rank, max_rank, atol=1e-3):
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)
        return rank, token_count
class CoHereRerank(Base):
    """Reranker driven by the Cohere SDK client (registered for Cohere and VLLM)."""

    _FACTORY_NAME = ["Cohere", "VLLM"]

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

        self.client = Client(api_key=key, base_url=base_url)
        # Everything after a "___" separator in the model name is dropped.
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """Ask the Cohere client to rank ``texts`` against ``query``.

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the relevance
            score for ``texts[i]`` (0.0 if not returned). ``token_count`` is the
            locally counted tokens of the query plus every text.
        """
        used_tokens = num_tokens_from_string(query)
        used_tokens += sum(num_tokens_from_string(doc) for doc in texts)
        response = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False,
        )
        scores = np.zeros(len(texts), dtype=float)
        try:
            for item in response.results:
                scores[item.index] = item.relevance_score
        except Exception as err:
            log_exception(err, response)
        return scores, used_tokens
class TogetherAIRerank(Base):
    """Placeholder for TogetherAI reranking, which is not supported yet."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url, **kwargs):
        """Accept the standard constructor arguments and store nothing."""

    def similarity(self, query: str, texts: list):
        """Always raise NotImplementedError: TogetherAI reranking is unavailable."""
        raise NotImplementedError("The api has not been implement")
class SILICONFLOWRerank(Base):
    """Reranker for the SiliconFlow rerank HTTP API."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        self.model_name = model_name
        # An empty/None base_url falls back to the public SiliconFlow endpoint.
        self.base_url = base_url or "https://api.siliconflow.cn/v1/rerank"
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """Rank ``texts`` against ``query`` via SiliconFlow.

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the
            ``relevance_score`` for ``texts[i]`` (0.0 if absent).
            ``token_count`` is extracted by ``total_token_count_from_response``.
        """
        request_body = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "return_documents": False,
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        reply = requests.post(self.base_url, json=request_body, headers=self.headers).json()
        scores = np.zeros(len(texts), dtype=float)
        try:
            for hit in reply["results"]:
                scores[hit["index"]] = hit["relevance_score"]
        except Exception as err:
            log_exception(err, reply)
        return scores, total_token_count_from_response(reply)
class BaiduYiyanRerank(Base):
    """Reranker backed by the qianfan SDK ``Reranker`` resource."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        from qianfan.resources import Reranker

        # ``key`` is a JSON string holding "yiyan_ak" and "yiyan_sk" (each defaults to "").
        credentials = json.loads(key)
        self.client = Reranker(ak=credentials.get("yiyan_ak", ""), sk=credentials.get("yiyan_sk", ""))
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Rank ``texts`` against ``query`` with the qianfan reranker.

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the
            ``relevance_score`` for ``texts[i]`` (0.0 if absent).
            ``token_count`` is extracted by ``total_token_count_from_response``.
        """
        body = self.client.do(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
        ).body
        scores = np.zeros(len(texts), dtype=float)
        try:
            for hit in body["results"]:
                scores[hit["index"]] = hit["relevance_score"]
        except Exception as err:
            log_exception(err, body)
        return scores, total_token_count_from_response(body)
class VoyageRerank(Base):
    """Reranker backed by the ``voyageai`` SDK client."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Rank ``texts`` against ``query`` with Voyage AI.

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the relevance
            score for ``texts[i]`` (0.0 if absent). ``token_count`` is the
            ``total_tokens`` reported by the SDK result.
        """
        if not texts:
            return np.array([]), 0
        outcome = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
        scores = np.zeros(len(texts), dtype=float)
        try:
            for entry in outcome.results:
                scores[entry.index] = entry.relevance_score
        except Exception as err:
            log_exception(err, outcome)
        return scores, outcome.total_tokens
class QWenRerank(Base):
    """Reranker backed by DashScope's ``TextReRank`` API (Tongyi-Qianwen)."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
        import dashscope

        self.api_key = key
        # An explicit None falls back to the SDK's gte_rerank model constant.
        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

    def similarity(self, query: str, texts: list):
        """Rank ``texts`` against ``query`` via DashScope.

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the relevance
            score for ``texts[i]`` (0.0 if absent). ``token_count`` is
            extracted by ``total_token_count_from_response``.

        Raises:
            ValueError: if DashScope answers with a non-OK status code.
        """
        from http import HTTPStatus

        import dashscope

        resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
        if resp.status_code != HTTPStatus.OK:
            # DashScope responses document their error details as `code` and
            # `message`; `text` is not one of the documented response fields.
            raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.code}: {resp.message}")
        rank = np.zeros(len(texts), dtype=float)
        try:
            for r in resp.output.results:
                rank[r.index] = r.relevance_score
        except Exception as _e:
            log_exception(_e, resp)
        return rank, total_token_count_from_response(resp)
class HuggingfaceRerank(Base):
    """Reranker for a self-hosted HTTP server exposing ``POST <base_url>/rerank``."""

    _FACTORY_NAME = "HuggingFace"

    @staticmethod
    def post(query: str, texts: list, url="127.0.0.1"):
        """Score ``texts`` against ``query`` in batches of 8.

        Args:
            query: The query string.
            texts: Documents to score.
            url: Server address, with or without a scheme (e.g. "127.0.0.1:8080"
                or "http://127.0.0.1:8080"). "http://" is prepended only when no
                scheme is present. An empty value falls back to "127.0.0.1".

        Returns:
            np.ndarray: one score per text, 0 where the server returned none.

        Raises:
            Exception: the last exception raised by any batch. All batches are
            attempted before re-raising.
        """
        base = (url or "127.0.0.1").rstrip("/")
        # __init__ defaults base_url to "http://127.0.0.1". Blindly prefixing
        # "http://" would produce "http://http://...".
        if "://" not in base:
            base = f"http://{base}"
        endpoint = f"{base}/rerank"
        exc = None
        scores = [0] * len(texts)
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    endpoint, headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
                )
                # Returned indices are relative to the batch; offset by its start.
                for o in res.json():
                    scores[o["index"] + i] = o["score"]
            except Exception as e:
                exc = e
        if exc:
            raise exc
        return np.array(scores)

    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
        # Keep only the part of the model name before any "___" separator.
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url

    def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
        """Rank ``texts`` against ``query``.

        Returns:
            tuple: ``(scores, token_count)``. ``token_count`` sums
            ``num_tokens_from_string`` over ``texts``; the query is not counted.
        """
        if not texts:
            return np.array([]), 0
        token_count = sum(num_tokens_from_string(t) for t in texts)
        return HuggingfaceRerank.post(query, texts, self.base_url), token_count
class GPUStackRerank(Base):
    """Reranker for a GPUStack server's ``/v1/rerank`` endpoint."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        self.model_name = model_name
        self.base_url = str(URL(base_url) / "v1" / "rerank")
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """Rank ``texts`` against ``query``.

        Returns:
            tuple: ``(scores, token_count)``. ``scores[i]`` is the
            ``relevance_score`` for ``texts[i]`` (0.0 if absent).
            ``token_count`` sums ``num_tokens_from_string`` over ``texts``.

        Raises:
            ValueError: if the server responds with an HTTP error status.
        """
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        try:
            response = requests.post(self.base_url, json=payload, headers=self.headers)
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            # `requests` signals error statuses with its own HTTPError, not
            # httpx.HTTPStatusError, so this is the exception to translate.
            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") from e
        response_json = response.json()
        rank = np.zeros(len(texts), dtype=float)
        token_count = sum(num_tokens_from_string(t) for t in texts)
        try:
            for result in response_json["results"]:
                rank[result["index"]] = result["relevance_score"]
        except Exception as _e:
            log_exception(_e, response)
        return rank, token_count
class NovitaRerank(JinaRerank):
    """Novita AI reranker that reuses JinaRerank's request/response handling."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
        # An empty/None base_url is replaced by Novita's public rerank endpoint
        # before being handed to JinaRerank.__init__.
        endpoint = base_url or "https://api.novita.ai/v3/openai/rerank"
        super().__init__(key, model_name, endpoint)
class GiteeRerank(JinaRerank):
    """Gitee AI reranker that reuses JinaRerank's request/response handling."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        # An empty/None base_url is replaced by Gitee AI's public rerank
        # endpoint before being handed to JinaRerank.__init__.
        endpoint = base_url or "https://ai.gitee.com/v1/rerank"
        super().__init__(key, model_name, endpoint)
class Ai302Rerank(JinaRerank):
    """302.AI reranker that reuses JinaRerank's request/response handling.

    It previously subclassed Base, whose constructor stores nothing and whose
    similarity() always raises, so the class could not rerank at all.
    NOTE(review): assumes 302.AI's /v1/rerank accepts the Jina-style payload
    (model, query, documents, top_n) and returns results[].index /
    relevance_score. Confirm against 302.AI's API docs.
    """

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
        if not base_url:
            base_url = "https://api.302.ai/v1/rerank"
        super().__init__(key, model_name, base_url)
        # Pin the 302.AI endpoint explicitly instead of relying on the parent
        # constructor to store base_url.
        self.base_url = base_url