| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | import threading | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from flask import Flask, current_app | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from core.rag.data_post_processor.data_post_processor import DataPostProcessor | 
					
						
							|  |  |  | from core.rag.datasource.keyword.keyword_factory import Keyword | 
					
						
							|  |  |  | from core.rag.datasource.vdb.vector_factory import Vector | 
					
						
							|  |  |  | from extensions.ext_database import db | 
					
						
							|  |  |  | from models.dataset import Dataset | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | default_retrieval_model = { | 
					
						
							|  |  |  |     'search_method': 'semantic_search', | 
					
						
							|  |  |  |     'reranking_enable': False, | 
					
						
							|  |  |  |     'reranking_model': { | 
					
						
							|  |  |  |         'reranking_provider_name': '', | 
					
						
							|  |  |  |         'reranking_model_name': '' | 
					
						
							|  |  |  |     }, | 
					
						
							|  |  |  |     'top_k': 2, | 
					
						
							|  |  |  |     'score_threshold_enabled': False | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RetrievalService: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def retrieve(cls, retrival_method: str, dataset_id: str, query: str, | 
					
						
							|  |  |  |                  top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): | 
					
						
							| 
									
										
										
										
											2024-02-23 16:54:15 +08:00
										 |  |  |         dataset = db.session.query(Dataset).filter( | 
					
						
							|  |  |  |             Dataset.id == dataset_id | 
					
						
							|  |  |  |         ).first() | 
					
						
							|  |  |  |         if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: | 
					
						
							|  |  |  |             return [] | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         all_documents = [] | 
					
						
							|  |  |  |         threads = [] | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |         exceptions = [] | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         # retrieval_model source with keyword | 
					
						
							|  |  |  |         if retrival_method == 'keyword_search': | 
					
						
							|  |  |  |             keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ | 
					
						
							|  |  |  |                 'flask_app': current_app._get_current_object(), | 
					
						
							|  |  |  |                 'dataset_id': dataset_id, | 
					
						
							|  |  |  |                 'query': query, | 
					
						
							| 
									
										
										
										
											2024-02-27 11:39:05 +08:00
										 |  |  |                 'top_k': top_k, | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 'all_documents': all_documents, | 
					
						
							|  |  |  |                 'exceptions': exceptions, | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             }) | 
					
						
							|  |  |  |             threads.append(keyword_thread) | 
					
						
							|  |  |  |             keyword_thread.start() | 
					
						
							|  |  |  |         # retrieval_model source with semantic | 
					
						
							|  |  |  |         if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search': | 
					
						
							|  |  |  |             embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | 
					
						
							|  |  |  |                 'flask_app': current_app._get_current_object(), | 
					
						
							|  |  |  |                 'dataset_id': dataset_id, | 
					
						
							|  |  |  |                 'query': query, | 
					
						
							|  |  |  |                 'top_k': top_k, | 
					
						
							|  |  |  |                 'score_threshold': score_threshold, | 
					
						
							|  |  |  |                 'reranking_model': reranking_model, | 
					
						
							|  |  |  |                 'all_documents': all_documents, | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 'retrival_method': retrival_method, | 
					
						
							|  |  |  |                 'exceptions': exceptions, | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             }) | 
					
						
							|  |  |  |             threads.append(embedding_thread) | 
					
						
							|  |  |  |             embedding_thread.start() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # retrieval source with full text | 
					
						
							|  |  |  |         if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search': | 
					
						
							|  |  |  |             full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | 
					
						
							|  |  |  |                 'flask_app': current_app._get_current_object(), | 
					
						
							|  |  |  |                 'dataset_id': dataset_id, | 
					
						
							|  |  |  |                 'query': query, | 
					
						
							|  |  |  |                 'retrival_method': retrival_method, | 
					
						
							|  |  |  |                 'score_threshold': score_threshold, | 
					
						
							|  |  |  |                 'top_k': top_k, | 
					
						
							|  |  |  |                 'reranking_model': reranking_model, | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                 'all_documents': all_documents, | 
					
						
							|  |  |  |                 'exceptions': exceptions, | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             }) | 
					
						
							|  |  |  |             threads.append(full_text_index_thread) | 
					
						
							|  |  |  |             full_text_index_thread.start() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for thread in threads: | 
					
						
							|  |  |  |             thread.join() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |         if exceptions: | 
					
						
							|  |  |  |             exception_message = ';\n'.join(exceptions) | 
					
						
							|  |  |  |             raise Exception(exception_message) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         if retrival_method == 'hybrid_search': | 
					
						
							| 
									
										
										
										
											2024-02-23 16:54:15 +08:00
										 |  |  |             data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |             all_documents = data_post_processor.invoke( | 
					
						
							|  |  |  |                 query=query, | 
					
						
							|  |  |  |                 documents=all_documents, | 
					
						
							|  |  |  |                 score_threshold=score_threshold, | 
					
						
							|  |  |  |                 top_n=top_k | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return all_documents | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                        top_k: int, all_documents: list, exceptions: list): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							|  |  |  |                 dataset = db.session.query(Dataset).filter( | 
					
						
							|  |  |  |                     Dataset.id == dataset_id | 
					
						
							|  |  |  |                 ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 keyword = Keyword( | 
					
						
							|  |  |  |                     dataset=dataset | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 documents = keyword.search( | 
					
						
							|  |  |  |                     query, | 
					
						
							|  |  |  |                     top_k=top_k | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, | 
					
						
							|  |  |  |                          top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                          all_documents: list, retrival_method: str, exceptions: list): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							|  |  |  |                 dataset = db.session.query(Dataset).filter( | 
					
						
							|  |  |  |                     Dataset.id == dataset_id | 
					
						
							|  |  |  |                 ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 vector = Vector( | 
					
						
							|  |  |  |                     dataset=dataset | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 documents = vector.search_by_vector( | 
					
						
							|  |  |  |                     query, | 
					
						
							|  |  |  |                     search_type='similarity_score_threshold', | 
					
						
							|  |  |  |                     top_k=top_k, | 
					
						
							|  |  |  |                     score_threshold=score_threshold, | 
					
						
							|  |  |  |                     filter={ | 
					
						
							|  |  |  |                         'group_id': [dataset.id] | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if documents: | 
					
						
							|  |  |  |                     if reranking_model and retrival_method == 'semantic_search': | 
					
						
							|  |  |  |                         data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | 
					
						
							|  |  |  |                         all_documents.extend(data_post_processor.invoke( | 
					
						
							|  |  |  |                             query=query, | 
					
						
							|  |  |  |                             documents=documents, | 
					
						
							|  |  |  |                             score_threshold=score_threshold, | 
					
						
							|  |  |  |                             top_n=len(documents) | 
					
						
							|  |  |  |                         )) | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, | 
					
						
							|  |  |  |                                top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |                                all_documents: list, retrival_method: str, exceptions: list): | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         with flask_app.app_context(): | 
					
						
							| 
									
										
										
										
											2024-05-28 14:54:53 +08:00
										 |  |  |             try: | 
					
						
							|  |  |  |                 dataset = db.session.query(Dataset).filter( | 
					
						
							|  |  |  |                     Dataset.id == dataset_id | 
					
						
							|  |  |  |                 ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 vector_processor = Vector( | 
					
						
							|  |  |  |                     dataset=dataset, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 documents = vector_processor.search_by_full_text( | 
					
						
							|  |  |  |                     query, | 
					
						
							|  |  |  |                     top_k=top_k | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 if documents: | 
					
						
							|  |  |  |                     if reranking_model and retrival_method == 'full_text_search': | 
					
						
							|  |  |  |                         data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | 
					
						
							|  |  |  |                         all_documents.extend(data_post_processor.invoke( | 
					
						
							|  |  |  |                             query=query, | 
					
						
							|  |  |  |                             documents=documents, | 
					
						
							|  |  |  |                             score_threshold=score_threshold, | 
					
						
							|  |  |  |                             top_n=len(documents) | 
					
						
							|  |  |  |                         )) | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         all_documents.extend(documents) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 exceptions.append(str(e)) |