| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | import concurrent.futures | 
					
						
							| 
									
										
										
										
											2025-02-19 09:13:36 +08:00
										 |  |  | from concurrent.futures import ThreadPoolExecutor | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from flask import Flask, current_app | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | from sqlalchemy.orm import load_only | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | from configs import dify_config | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | from core.rag.data_post_processor.data_post_processor import DataPostProcessor | 
					
						
							|  |  |  | from core.rag.datasource.keyword.keyword_factory import Keyword | 
					
						
							|  |  |  | from core.rag.datasource.vdb.vector_factory import Vector | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  | from core.rag.embedding.retrieval import RetrievalSegments | 
					
						
							|  |  |  | from core.rag.index_processor.constant.index_type import IndexType | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  | from core.rag.models.document import Document | 
					
						
							| 
									
										
										
										
											2024-10-17 19:12:42 +08:00
										 |  |  | from core.rag.rerank.rerank_type import RerankMode | 
					
						
							| 
									
										
										
										
											2024-09-08 12:14:11 +07:00
										 |  |  | from core.rag.retrieval.retrieval_methods import RetrievalMethod | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  | from models.dataset import ChildChunk, Dataset, DocumentSegment | 
					
						
							|  |  |  | from models.dataset import Document as DatasetDocument | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  | from services.external_knowledge_service import ExternalDatasetService | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | default_retrieval_model = { | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | 
					
						
							|  |  |  |     "reranking_enable": False, | 
					
						
							|  |  |  |     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | 
					
						
							|  |  |  |     "top_k": 2, | 
					
						
							|  |  |  |     "score_threshold_enabled": False, | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RetrievalService: | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |     # Cache precompiled regular expressions to avoid repeated compilation | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def retrieve( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         retrieval_method: str, | 
					
						
							|  |  |  |         dataset_id: str, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         top_k: int, | 
					
						
							|  |  |  |         score_threshold: Optional[float] = 0.0, | 
					
						
							|  |  |  |         reranking_model: Optional[dict] = None, | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         reranking_mode: str = "reranking_model", | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         weights: Optional[dict] = None, | 
					
						
							| 
									
										
										
										
											2025-03-18 16:42:19 +08:00
										 |  |  |         document_ids_filter: Optional[list[str]] = None, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-10-31 21:25:00 +08:00
										 |  |  |         if not query: | 
					
						
							|  |  |  |             return [] | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |         dataset = cls._get_dataset(dataset_id) | 
					
						
							| 
									
										
										
										
											2025-04-09 11:22:53 +08:00
										 |  |  |         if not dataset: | 
					
						
							| 
									
										
										
										
											2024-02-23 16:54:15 +08:00
										 |  |  |             return [] | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         all_documents: list[Document] = [] | 
					
						
							|  |  |  |         exceptions: list[str] = [] | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |         # Optimize multithreading with thread pools | 
					
						
							| 
									
										
										
										
											2025-02-19 09:13:36 +08:00
										 |  |  |         with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor:  # type: ignore | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |             futures = [] | 
					
						
							|  |  |  |             if retrieval_method == "keyword_search": | 
					
						
							|  |  |  |                 futures.append( | 
					
						
							|  |  |  |                     executor.submit( | 
					
						
							|  |  |  |                         cls.keyword_search, | 
					
						
							|  |  |  |                         flask_app=current_app._get_current_object(),  # type: ignore | 
					
						
							|  |  |  |                         dataset_id=dataset_id, | 
					
						
							|  |  |  |                         query=query, | 
					
						
							|  |  |  |                         top_k=top_k, | 
					
						
							|  |  |  |                         all_documents=all_documents, | 
					
						
							|  |  |  |                         exceptions=exceptions, | 
					
						
							| 
									
										
										
										
											2025-03-18 16:42:19 +08:00
										 |  |  |                         document_ids_filter=document_ids_filter, | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                     ) | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             if RetrievalMethod.is_support_semantic_search(retrieval_method): | 
					
						
							|  |  |  |                 futures.append( | 
					
						
							|  |  |  |                     executor.submit( | 
					
						
							|  |  |  |                         cls.embedding_search, | 
					
						
							|  |  |  |                         flask_app=current_app._get_current_object(),  # type: ignore | 
					
						
							|  |  |  |                         dataset_id=dataset_id, | 
					
						
							|  |  |  |                         query=query, | 
					
						
							|  |  |  |                         top_k=top_k, | 
					
						
							|  |  |  |                         score_threshold=score_threshold, | 
					
						
							|  |  |  |                         reranking_model=reranking_model, | 
					
						
							|  |  |  |                         all_documents=all_documents, | 
					
						
							|  |  |  |                         retrieval_method=retrieval_method, | 
					
						
							|  |  |  |                         exceptions=exceptions, | 
					
						
							| 
									
										
										
										
											2025-03-18 16:42:19 +08:00
										 |  |  |                         document_ids_filter=document_ids_filter, | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                     ) | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             if RetrievalMethod.is_support_fulltext_search(retrieval_method): | 
					
						
							|  |  |  |                 futures.append( | 
					
						
							|  |  |  |                     executor.submit( | 
					
						
							|  |  |  |                         cls.full_text_index_search, | 
					
						
							|  |  |  |                         flask_app=current_app._get_current_object(),  # type: ignore | 
					
						
							|  |  |  |                         dataset_id=dataset_id, | 
					
						
							|  |  |  |                         query=query, | 
					
						
							|  |  |  |                         top_k=top_k, | 
					
						
							|  |  |  |                         score_threshold=score_threshold, | 
					
						
							|  |  |  |                         reranking_model=reranking_model, | 
					
						
							|  |  |  |                         all_documents=all_documents, | 
					
						
							|  |  |  |                         retrieval_method=retrieval_method, | 
					
						
							|  |  |  |                         exceptions=exceptions, | 
					
						
							| 
									
										
										
										
											2025-03-24 18:35:16 +08:00
										 |  |  |                         document_ids_filter=document_ids_filter, | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                     ) | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |         if exceptions: | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |             raise ValueError(";\n".join(exceptions)) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-08 12:14:11 +07:00
										 |  |  |         if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             data_post_processor = DataPostProcessor( | 
					
						
							|  |  |  |                 str(dataset.tenant_id), reranking_mode, reranking_model, weights, False | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             all_documents = data_post_processor.invoke( | 
					
						
							| 
									
										
										
										
											2024-11-30 11:14:45 +08:00
										 |  |  |                 query=query, | 
					
						
							|  |  |  |                 documents=all_documents, | 
					
						
							|  |  |  |                 score_threshold=score_threshold, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:34:56 +08:00
										 |  |  |                 top_n=top_k, | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-11-30 11:14:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         return all_documents | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  |     @classmethod | 
					
						
							|  |  |  |     def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None): | 
					
						
							|  |  |  |         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | 
					
						
							|  |  |  |         if not dataset: | 
					
						
							|  |  |  |             return [] | 
					
						
							|  |  |  |         all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |             dataset.tenant_id, dataset_id, query, external_retrieval_model or {} | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         return all_documents | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |     @classmethod | 
					
						
							|  |  |  |     def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: | 
					
						
							|  |  |  |         return db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def keyword_search( | 
					
						
							| 
									
										
										
										
											2025-03-18 16:42:19 +08:00
										 |  |  |         cls, | 
					
						
							|  |  |  |         flask_app: Flask, | 
					
						
							|  |  |  |         dataset_id: str, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         top_k: int, | 
					
						
							|  |  |  |         all_documents: list, | 
					
						
							|  |  |  |         exceptions: list, | 
					
						
							|  |  |  |         document_ids_filter: Optional[list[str]] = None, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                 dataset = cls._get_dataset(dataset_id) | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |                 if not dataset: | 
					
						
							|  |  |  |                     raise ValueError("dataset not found") | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 keyword = Keyword(dataset=dataset) | 
					
						
							| 
									
										
										
										
											2025-03-18 16:42:19 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 documents = keyword.search( | 
					
						
							|  |  |  |                     cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def embedding_search( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         flask_app: Flask, | 
					
						
							|  |  |  |         dataset_id: str, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         top_k: int, | 
					
						
							|  |  |  |         score_threshold: Optional[float], | 
					
						
							|  |  |  |         reranking_model: Optional[dict], | 
					
						
							|  |  |  |         all_documents: list, | 
					
						
							|  |  |  |         retrieval_method: str, | 
					
						
							|  |  |  |         exceptions: list, | 
					
						
							| 
									
										
										
										
											2025-03-18 16:42:19 +08:00
										 |  |  |         document_ids_filter: Optional[list[str]] = None, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                 dataset = cls._get_dataset(dataset_id) | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |                 if not dataset: | 
					
						
							|  |  |  |                     raise ValueError("dataset not found") | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 vector = Vector(dataset=dataset) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 documents = vector.search_by_vector( | 
					
						
							| 
									
										
										
										
											2025-02-11 09:58:31 +08:00
										 |  |  |                     query, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     search_type="similarity_score_threshold", | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                     top_k=top_k, | 
					
						
							|  |  |  |                     score_threshold=score_threshold, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     filter={"group_id": [dataset.id]}, | 
					
						
							| 
									
										
										
										
											2025-03-18 16:42:19 +08:00
										 |  |  |                     document_ids_filter=document_ids_filter, | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if documents: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     if ( | 
					
						
							|  |  |  |                         reranking_model | 
					
						
							|  |  |  |                         and reranking_model.get("reranking_model_name") | 
					
						
							|  |  |  |                         and reranking_model.get("reranking_provider_name") | 
					
						
							|  |  |  |                         and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value | 
					
						
							|  |  |  |                     ): | 
					
						
							|  |  |  |                         data_post_processor = DataPostProcessor( | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                             str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         ) | 
					
						
							|  |  |  |                         all_documents.extend( | 
					
						
							|  |  |  |                             data_post_processor.invoke( | 
					
						
							| 
									
										
										
										
											2024-11-30 11:14:45 +08:00
										 |  |  |                                 query=query, | 
					
						
							|  |  |  |                                 documents=documents, | 
					
						
							|  |  |  |                                 score_threshold=score_threshold, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:34:56 +08:00
										 |  |  |                                 top_n=len(documents), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             ) | 
					
						
							|  |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                     else: | 
					
						
							|  |  |  |                         all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def full_text_index_search( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         flask_app: Flask, | 
					
						
							|  |  |  |         dataset_id: str, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         top_k: int, | 
					
						
							|  |  |  |         score_threshold: Optional[float], | 
					
						
							|  |  |  |         reranking_model: Optional[dict], | 
					
						
							|  |  |  |         all_documents: list, | 
					
						
							|  |  |  |         retrieval_method: str, | 
					
						
							|  |  |  |         exceptions: list, | 
					
						
							| 
									
										
										
										
											2025-03-24 18:35:16 +08:00
										 |  |  |         document_ids_filter: Optional[list[str]] = None, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                 dataset = cls._get_dataset(dataset_id) | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |                 if not dataset: | 
					
						
							|  |  |  |                     raise ValueError("dataset not found") | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                 vector_processor = Vector(dataset=dataset) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-24 18:35:16 +08:00
										 |  |  |                 documents = vector_processor.search_by_full_text( | 
					
						
							|  |  |  |                     cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 if documents: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     if ( | 
					
						
							|  |  |  |                         reranking_model | 
					
						
							|  |  |  |                         and reranking_model.get("reranking_model_name") | 
					
						
							|  |  |  |                         and reranking_model.get("reranking_provider_name") | 
					
						
							|  |  |  |                         and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value | 
					
						
							|  |  |  |                     ): | 
					
						
							|  |  |  |                         data_post_processor = DataPostProcessor( | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                             str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         ) | 
					
						
							|  |  |  |                         all_documents.extend( | 
					
						
							|  |  |  |                             data_post_processor.invoke( | 
					
						
							| 
									
										
										
										
											2024-11-30 11:14:45 +08:00
										 |  |  |                                 query=query, | 
					
						
							|  |  |  |                                 documents=documents, | 
					
						
							|  |  |  |                                 score_threshold=score_threshold, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:34:56 +08:00
										 |  |  |                                 top_n=len(documents), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             ) | 
					
						
							|  |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                     else: | 
					
						
							|  |  |  |                         all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) | 
					
						
							| 
									
										
										
										
											2024-07-23 16:02:25 +09:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							|  |  |  |     def escape_query_for_search(query: str) -> str: | 
					
						
							| 
									
										
										
										
											2025-03-12 18:34:42 +08:00
										 |  |  |         return query.replace('"', '\\"') | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |     @classmethod | 
					
						
							|  |  |  |     def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]: | 
					
						
							|  |  |  |         """Format retrieval documents with optimized batch processing""" | 
					
						
							|  |  |  |         if not documents: | 
					
						
							|  |  |  |             return [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |             # Collect document IDs | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |             document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata} | 
					
						
							|  |  |  |             if not document_ids: | 
					
						
							|  |  |  |                 return [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Batch query dataset documents | 
					
						
							|  |  |  |             dataset_documents = { | 
					
						
							|  |  |  |                 doc.id: doc | 
					
						
							|  |  |  |                 for doc in db.session.query(DatasetDocument) | 
					
						
							|  |  |  |                 .filter(DatasetDocument.id.in_(document_ids)) | 
					
						
							|  |  |  |                 .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) | 
					
						
							|  |  |  |                 .all() | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             records = [] | 
					
						
							|  |  |  |             include_segment_ids = set() | 
					
						
							|  |  |  |             segment_child_map = {} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |             # Process documents | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |             for document in documents: | 
					
						
							|  |  |  |                 document_id = document.metadata.get("document_id") | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                 if document_id not in dataset_documents: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 dataset_document = dataset_documents[document_id] | 
					
						
							| 
									
										
										
										
											2025-03-14 16:40:01 +08:00
										 |  |  |                 if not dataset_document: | 
					
						
							|  |  |  |                     continue | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | 
					
						
							|  |  |  |                     # Handle parent-child documents | 
					
						
							| 
									
										
										
										
											2024-12-26 00:16:35 +08:00
										 |  |  |                     child_index_node_id = document.metadata.get("doc_id") | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                     child_chunk = ( | 
					
						
							|  |  |  |                         db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first() | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                     if not child_chunk: | 
					
						
							|  |  |  |                         continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                     segment = ( | 
					
						
							|  |  |  |                         db.session.query(DocumentSegment) | 
					
						
							|  |  |  |                         .filter( | 
					
						
							|  |  |  |                             DocumentSegment.dataset_id == dataset_document.dataset_id, | 
					
						
							|  |  |  |                             DocumentSegment.enabled == True, | 
					
						
							|  |  |  |                             DocumentSegment.status == "completed", | 
					
						
							|  |  |  |                             DocumentSegment.id == child_chunk.segment_id, | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         .options( | 
					
						
							|  |  |  |                             load_only( | 
					
						
							|  |  |  |                                 DocumentSegment.id, | 
					
						
							|  |  |  |                                 DocumentSegment.content, | 
					
						
							|  |  |  |                                 DocumentSegment.answer, | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         .first() | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                     if not segment: | 
					
						
							| 
									
										
										
										
											2024-12-26 00:16:35 +08:00
										 |  |  |                         continue | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     if segment.id not in include_segment_ids: | 
					
						
							|  |  |  |                         include_segment_ids.add(segment.id) | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                         child_chunk_detail = { | 
					
						
							|  |  |  |                             "id": child_chunk.id, | 
					
						
							|  |  |  |                             "content": child_chunk.content, | 
					
						
							|  |  |  |                             "position": child_chunk.position, | 
					
						
							|  |  |  |                             "score": document.metadata.get("score", 0.0), | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         map_detail = { | 
					
						
							|  |  |  |                             "max_score": document.metadata.get("score", 0.0), | 
					
						
							|  |  |  |                             "child_chunks": [child_chunk_detail], | 
					
						
							|  |  |  |                         } | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                         segment_child_map[segment.id] = map_detail | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                         record = { | 
					
						
							|  |  |  |                             "segment": segment, | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         records.append(record) | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         child_chunk_detail = { | 
					
						
							|  |  |  |                             "id": child_chunk.id, | 
					
						
							|  |  |  |                             "content": child_chunk.content, | 
					
						
							|  |  |  |                             "position": child_chunk.position, | 
					
						
							|  |  |  |                             "score": document.metadata.get("score", 0.0), | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) | 
					
						
							|  |  |  |                         segment_child_map[segment.id]["max_score"] = max( | 
					
						
							|  |  |  |                             segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) | 
					
						
							|  |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  |                 else: | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |                     # Handle normal documents | 
					
						
							|  |  |  |                     index_node_id = document.metadata.get("doc_id") | 
					
						
							|  |  |  |                     if not index_node_id: | 
					
						
							|  |  |  |                         continue | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                     segment = ( | 
					
						
							|  |  |  |                         db.session.query(DocumentSegment) | 
					
						
							|  |  |  |                         .filter( | 
					
						
							|  |  |  |                             DocumentSegment.dataset_id == dataset_document.dataset_id, | 
					
						
							|  |  |  |                             DocumentSegment.enabled == True, | 
					
						
							|  |  |  |                             DocumentSegment.status == "completed", | 
					
						
							|  |  |  |                             DocumentSegment.index_node_id == index_node_id, | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         .first() | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-26 00:16:35 +08:00
										 |  |  |                     if not segment: | 
					
						
							|  |  |  |                         continue | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                     include_segment_ids.add(segment.id) | 
					
						
							|  |  |  |                     record = { | 
					
						
							|  |  |  |                         "segment": segment, | 
					
						
							|  |  |  |                         "score": document.metadata.get("score"),  # type: ignore | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     records.append(record) | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |             # Add child chunks information to records | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  |             for record in records: | 
					
						
							| 
									
										
										
										
											2025-04-09 20:25:36 +08:00
										 |  |  |                 if record["segment"].id in segment_child_map: | 
					
						
							|  |  |  |                     record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore | 
					
						
							|  |  |  |                     record["score"] = segment_child_map[record["segment"].id]["max_score"] | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 14:09:57 +08:00
										 |  |  |             return [RetrievalSegments(**record) for record in records] | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             db.session.rollback() | 
					
						
							|  |  |  |             raise e |