mirror of
https://github.com/langgenius/dify.git
synced 2025-12-21 07:02:18 +00:00
192 lines
8.0 KiB
Python
192 lines
8.0 KiB
Python
import base64
|
|
|
|
from core.model_manager import ModelInstance, ModelManager
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
|
from core.rag.index_processor.constant.doc_type import DocType
|
|
from core.rag.index_processor.constant.query_type import QueryType
|
|
from core.rag.models.document import Document
|
|
from core.rag.rerank.rerank_base import BaseRerankRunner
|
|
from extensions.ext_database import db
|
|
from extensions.ext_storage import storage
|
|
from models.model import UploadFile
|
|
|
|
|
|
class RerankModelRunner(BaseRerankRunner):
    """Rerank runner that delegates scoring to a configured rerank ``ModelInstance``.

    Text-only rerank models score text documents against a text query. Models reported as
    vision-capable go through the multimodal path, which can also score documents against
    an image query.
    """

    def __init__(self, rerank_model_instance: ModelInstance):
        # Model instance used for both text and multimodal rerank invocations.
        self.rerank_model_instance = rerank_model_instance

    def run(
        self,
        query: str,
        documents: list[Document],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
        query_type: QueryType = QueryType.TEXT_QUERY,
    ) -> list[Document]:
        """
        Run rerank model

        :param query: search query; for ``QueryType.IMAGE_QUERY`` this is the id of an ``UploadFile``
        :param documents: documents for reranking
        :param score_threshold: minimum score a result must reach to be kept (``None`` keeps all)
        :param top_n: maximum number of documents to return (a falsy value returns all)
        :param user: unique user id if needed
        :param query_type: type of the query (text or image)
        :return: reranked documents sorted by descending score, with ``metadata["score"]`` set.
            If the model has no vision support and the query is not a text query, the input
            ``documents`` are returned unchanged.
        """
        model_manager = ModelManager()
        is_support_vision = model_manager.check_model_support_vision(
            tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
            provider=self.rerank_model_instance.provider,
            model=self.rerank_model_instance.model,
            model_type=ModelType.RERANK,
        )
        if not is_support_vision:
            if query_type != QueryType.TEXT_QUERY:
                # A text-only rerank model cannot score a non-text query; keep the input as-is.
                return documents
            rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
        else:
            rerank_result, unique_documents = self.fetch_multimodal_rerank(
                query, documents, score_threshold, top_n, user, query_type
            )

        rerank_documents = []
        for result in rerank_result.docs:
            if score_threshold is None or result.score >= score_threshold:
                # result.index points into the list of documents that was sent to the model,
                # which the fetch_* methods return index-aligned as unique_documents.
                source_document = unique_documents[result.index]
                rerank_document = Document(
                    page_content=result.text,
                    metadata=source_document.metadata,
                    provider=source_document.provider,
                )
                if rerank_document.metadata is not None:
                    rerank_document.metadata["score"] = result.score
                rerank_documents.append(rerank_document)

        rerank_documents.sort(key=lambda x: (x.metadata or {}).get("score", 0.0), reverse=True)
        return rerank_documents[:top_n] if top_n else rerank_documents

    def fetch_text_rerank(
        self,
        query: str,
        documents: list[Document],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
    ) -> tuple[RerankResult, list[Document]]:
        """
        Fetch text rerank

        Only text documents are sent to the model. "dify" documents are deduplicated by
        ``metadata["doc_id"]`` and kept only when their ``doc_type`` is unset or ``DocType.TEXT``.
        "external" documents are deduplicated by equality. Documents from any other provider
        are ignored.

        :param query: search query
        :param documents: documents for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id if needed
        :return: the rerank result and the documents that were sent, in the same order as the
            texts passed to the model (so ``result.index`` addresses this list)
        """
        docs = []
        doc_ids = set()
        unique_documents = []
        for document in documents:
            if (
                document.provider == "dify"
                and document.metadata is not None
                and document.metadata["doc_id"] not in doc_ids
            ):
                doc_type = document.metadata.get("doc_type")
                if not doc_type or doc_type == DocType.TEXT:
                    doc_ids.add(document.metadata["doc_id"])
                    docs.append(document.page_content)
                    unique_documents.append(document)
            elif document.provider == "external":
                if document not in unique_documents:
                    docs.append(document.page_content)
                    unique_documents.append(document)

        rerank_result = self.rerank_model_instance.invoke_rerank(
            query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
        )
        return rerank_result, unique_documents

    def fetch_multimodal_rerank(
        self,
        query: str,
        documents: list[Document],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
        query_type: QueryType = QueryType.TEXT_QUERY,
    ) -> tuple[RerankResult, list[Document]]:
        """
        Fetch multimodal rerank

        :param query: search query; for ``QueryType.IMAGE_QUERY`` this is the id of an ``UploadFile``
        :param documents: documents for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id if needed
        :param query_type: query type
        :return: the rerank result and the documents that were sent, index-aligned with the
            payload passed to the model
        :raises ValueError: if ``query_type`` is unsupported, or the image query's upload file
            cannot be found
        """
        if query_type not in (QueryType.TEXT_QUERY, QueryType.IMAGE_QUERY):
            raise ValueError(f"Query type {query_type} is not supported")

        unique_documents = self._deduplicate_multimodal_documents(documents)

        if query_type == QueryType.TEXT_QUERY:
            # fetch_text_rerank only scores text documents, so no image blobs need to be loaded.
            return self.fetch_text_rerank(query, unique_documents, score_threshold, top_n, user)

        # Image query: the query string is the id of an uploaded image file.
        # Query file info within db.session context to ensure thread-safe access
        upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
        if not upload_file:
            raise ValueError(f"Upload file not found for query: {query}")
        file_query_dict = {
            "content": self._load_file_base64(upload_file),
            "content_type": DocType.IMAGE,
        }

        docs, sent_documents = self._build_multimodal_payload(unique_documents)
        rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
            query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
        )
        return rerank_result, sent_documents

    @staticmethod
    def _deduplicate_multimodal_documents(documents: list[Document]) -> list[Document]:
        """Deduplicate documents for multimodal rerank, keeping "dify" documents of any ``doc_type``.

        "dify" documents are deduplicated by ``metadata["doc_id"]``, and "external" documents by
        equality. Documents from any other provider are dropped.
        """
        doc_ids = set()
        unique_documents = []
        for document in documents:
            if (
                document.provider == "dify"
                and document.metadata is not None
                and document.metadata["doc_id"] not in doc_ids
            ):
                doc_ids.add(document.metadata["doc_id"])
                unique_documents.append(document)
            elif document.provider == "external":
                if document not in unique_documents:
                    unique_documents.append(document)
        return unique_documents

    def _build_multimodal_payload(self, documents: list[Document]) -> tuple[list[dict], list[Document]]:
        """Build multimodal rerank payload entries and the documents they were built from.

        "dify" image documents are sent as base64-encoded file content. All other documents
        are sent as their ``page_content``. An image document whose ``UploadFile`` cannot be
        found is skipped entirely, so the returned payload and documents stay index-aligned.
        """
        docs = []
        sent_documents = []
        for document in documents:
            metadata = document.metadata or {}
            doc_type = metadata.get("doc_type")
            if document.provider == "dify" and doc_type == DocType.IMAGE:
                # Query file info within db.session context to ensure thread-safe access
                upload_file = db.session.query(UploadFile).where(UploadFile.id == metadata["doc_id"]).first()
                if not upload_file:
                    # Skip rather than leave a hole: an unsent document must not shift indices.
                    continue
                docs.append({"content": self._load_file_base64(upload_file), "content_type": doc_type})
            else:
                docs.append({"content": document.page_content, "content_type": doc_type or DocType.TEXT})
            sent_documents.append(document)
        return docs, sent_documents

    @staticmethod
    def _load_file_base64(upload_file: UploadFile) -> str:
        """Load an uploaded file's bytes from storage and return them base64-encoded as text."""
        return base64.b64encode(storage.load_once(upload_file.key)).decode()
|