| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | import threading | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from flask import Flask, current_app | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from core.rag.data_post_processor.data_post_processor import DataPostProcessor | 
					
						
							|  |  |  | from core.rag.datasource.keyword.keyword_factory import Keyword | 
					
						
							|  |  |  | from core.rag.datasource.vdb.vector_factory import Vector | 
					
						
							| 
									
										
										
										
											2024-10-17 19:12:42 +08:00
										 |  |  | from core.rag.rerank.rerank_type import RerankMode | 
					
						
							| 
									
										
										
										
											2024-09-08 12:14:11 +07:00
										 |  |  | from core.rag.retrieval.retrieval_methods import RetrievalMethod | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							|  |  |  | from models.dataset import Dataset | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  | from services.external_knowledge_service import ExternalDatasetService | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | default_retrieval_model = { | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | 
					
						
							|  |  |  |     "reranking_enable": False, | 
					
						
							|  |  |  |     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | 
					
						
							|  |  |  |     "top_k": 2, | 
					
						
							|  |  |  |     "score_threshold_enabled": False, | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RetrievalService: | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def retrieve( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         retrieval_method: str, | 
					
						
							|  |  |  |         dataset_id: str, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         top_k: int, | 
					
						
							|  |  |  |         score_threshold: Optional[float] = 0.0, | 
					
						
							|  |  |  |         reranking_model: Optional[dict] = None, | 
					
						
							|  |  |  |         reranking_mode: Optional[str] = "reranking_model", | 
					
						
							|  |  |  |         weights: Optional[dict] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-10-31 21:25:00 +08:00
										 |  |  |         if not query: | 
					
						
							|  |  |  |             return [] | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  |         if not dataset: | 
					
						
							|  |  |  |             return [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-23 16:54:15 +08:00
										 |  |  |         if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: | 
					
						
							|  |  |  |             return [] | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         all_documents = [] | 
					
						
							|  |  |  |         threads = [] | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |         exceptions = [] | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         # retrieval_model source with keyword | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         if retrieval_method == "keyword_search": | 
					
						
							|  |  |  |             keyword_thread = threading.Thread( | 
					
						
							|  |  |  |                 target=RetrievalService.keyword_search, | 
					
						
							|  |  |  |                 kwargs={ | 
					
						
							|  |  |  |                     "flask_app": current_app._get_current_object(), | 
					
						
							|  |  |  |                     "dataset_id": dataset_id, | 
					
						
							|  |  |  |                     "query": query, | 
					
						
							|  |  |  |                     "top_k": top_k, | 
					
						
							|  |  |  |                     "all_documents": all_documents, | 
					
						
							|  |  |  |                     "exceptions": exceptions, | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             threads.append(keyword_thread) | 
					
						
							|  |  |  |             keyword_thread.start() | 
					
						
							|  |  |  |         # retrieval_model source with semantic | 
					
						
							| 
									
										
										
										
											2024-09-08 12:14:11 +07:00
										 |  |  |         if RetrievalMethod.is_support_semantic_search(retrieval_method): | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             embedding_thread = threading.Thread( | 
					
						
							|  |  |  |                 target=RetrievalService.embedding_search, | 
					
						
							|  |  |  |                 kwargs={ | 
					
						
							|  |  |  |                     "flask_app": current_app._get_current_object(), | 
					
						
							|  |  |  |                     "dataset_id": dataset_id, | 
					
						
							|  |  |  |                     "query": query, | 
					
						
							|  |  |  |                     "top_k": top_k, | 
					
						
							|  |  |  |                     "score_threshold": score_threshold, | 
					
						
							|  |  |  |                     "reranking_model": reranking_model, | 
					
						
							|  |  |  |                     "all_documents": all_documents, | 
					
						
							|  |  |  |                     "retrieval_method": retrieval_method, | 
					
						
							|  |  |  |                     "exceptions": exceptions, | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             threads.append(embedding_thread) | 
					
						
							|  |  |  |             embedding_thread.start() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # retrieval source with full text | 
					
						
							| 
									
										
										
										
											2024-09-08 12:14:11 +07:00
										 |  |  |         if RetrievalMethod.is_support_fulltext_search(retrieval_method): | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             full_text_index_thread = threading.Thread( | 
					
						
							|  |  |  |                 target=RetrievalService.full_text_index_search, | 
					
						
							|  |  |  |                 kwargs={ | 
					
						
							|  |  |  |                     "flask_app": current_app._get_current_object(), | 
					
						
							|  |  |  |                     "dataset_id": dataset_id, | 
					
						
							|  |  |  |                     "query": query, | 
					
						
							|  |  |  |                     "retrieval_method": retrieval_method, | 
					
						
							|  |  |  |                     "score_threshold": score_threshold, | 
					
						
							|  |  |  |                     "top_k": top_k, | 
					
						
							|  |  |  |                     "reranking_model": reranking_model, | 
					
						
							|  |  |  |                     "all_documents": all_documents, | 
					
						
							|  |  |  |                     "exceptions": exceptions, | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             threads.append(full_text_index_thread) | 
					
						
							|  |  |  |             full_text_index_thread.start() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for thread in threads: | 
					
						
							|  |  |  |             thread.join() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |         if exceptions: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             exception_message = ";\n".join(exceptions) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             raise Exception(exception_message) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-08 12:14:11 +07:00
										 |  |  |         if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             data_post_processor = DataPostProcessor( | 
					
						
							|  |  |  |                 str(dataset.tenant_id), reranking_mode, reranking_model, weights, False | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             all_documents = data_post_processor.invoke( | 
					
						
							| 
									
										
										
										
											2024-11-30 11:14:45 +08:00
										 |  |  |                 query=query, | 
					
						
							|  |  |  |                 documents=all_documents, | 
					
						
							|  |  |  |                 score_threshold=score_threshold, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:34:56 +08:00
										 |  |  |                 top_n=top_k, | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-11-30 11:14:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         return all_documents | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  |     @classmethod | 
					
						
							|  |  |  |     def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None): | 
					
						
							|  |  |  |         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | 
					
						
							|  |  |  |         if not dataset: | 
					
						
							|  |  |  |             return [] | 
					
						
							|  |  |  |         all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | 
					
						
							|  |  |  |             dataset.tenant_id, dataset_id, query, external_retrieval_model | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         return all_documents | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def keyword_search( | 
					
						
							|  |  |  |         cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 keyword = Keyword(dataset=dataset) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def embedding_search( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         flask_app: Flask, | 
					
						
							|  |  |  |         dataset_id: str, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         top_k: int, | 
					
						
							|  |  |  |         score_threshold: Optional[float], | 
					
						
							|  |  |  |         reranking_model: Optional[dict], | 
					
						
							|  |  |  |         all_documents: list, | 
					
						
							|  |  |  |         retrieval_method: str, | 
					
						
							|  |  |  |         exceptions: list, | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 vector = Vector(dataset=dataset) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 documents = vector.search_by_vector( | 
					
						
							| 
									
										
										
										
											2024-07-23 16:02:25 +09:00
										 |  |  |                     cls.escape_query_for_search(query), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     search_type="similarity_score_threshold", | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                     top_k=top_k, | 
					
						
							|  |  |  |                     score_threshold=score_threshold, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     filter={"group_id": [dataset.id]}, | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if documents: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     if ( | 
					
						
							|  |  |  |                         reranking_model | 
					
						
							|  |  |  |                         and reranking_model.get("reranking_model_name") | 
					
						
							|  |  |  |                         and reranking_model.get("reranking_provider_name") | 
					
						
							|  |  |  |                         and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value | 
					
						
							|  |  |  |                     ): | 
					
						
							|  |  |  |                         data_post_processor = DataPostProcessor( | 
					
						
							|  |  |  |                             str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         all_documents.extend( | 
					
						
							|  |  |  |                             data_post_processor.invoke( | 
					
						
							| 
									
										
										
										
											2024-11-30 11:14:45 +08:00
										 |  |  |                                 query=query, | 
					
						
							|  |  |  |                                 documents=documents, | 
					
						
							|  |  |  |                                 score_threshold=score_threshold, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:34:56 +08:00
										 |  |  |                                 top_n=len(documents), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             ) | 
					
						
							|  |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                     else: | 
					
						
							|  |  |  |                         all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def full_text_index_search( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         flask_app: Flask, | 
					
						
							|  |  |  |         dataset_id: str, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         top_k: int, | 
					
						
							|  |  |  |         score_threshold: Optional[float], | 
					
						
							|  |  |  |         reranking_model: Optional[dict], | 
					
						
							|  |  |  |         all_documents: list, | 
					
						
							|  |  |  |         retrieval_method: str, | 
					
						
							|  |  |  |         exceptions: list, | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 vector_processor = Vector( | 
					
						
							|  |  |  |                     dataset=dataset, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 if documents: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     if ( | 
					
						
							|  |  |  |                         reranking_model | 
					
						
							|  |  |  |                         and reranking_model.get("reranking_model_name") | 
					
						
							|  |  |  |                         and reranking_model.get("reranking_provider_name") | 
					
						
							|  |  |  |                         and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value | 
					
						
							|  |  |  |                     ): | 
					
						
							|  |  |  |                         data_post_processor = DataPostProcessor( | 
					
						
							|  |  |  |                             str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         all_documents.extend( | 
					
						
							|  |  |  |                             data_post_processor.invoke( | 
					
						
							| 
									
										
										
										
											2024-11-30 11:14:45 +08:00
										 |  |  |                                 query=query, | 
					
						
							|  |  |  |                                 documents=documents, | 
					
						
							|  |  |  |                                 score_threshold=score_threshold, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:34:56 +08:00
										 |  |  |                                 top_n=len(documents), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             ) | 
					
						
							|  |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                     else: | 
					
						
							|  |  |  |                         all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) | 
					
						
							| 
									
										
										
										
											2024-07-23 16:02:25 +09:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							|  |  |  |     def escape_query_for_search(query: str) -> str: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         return query.replace('"', '\\"') |