| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import logging | 
					
						
							|  |  |  | import time | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | from core.rag.datasource.retrieval_service import RetrievalService | 
					
						
							|  |  |  | from core.rag.models.document import Document | 
					
						
							| 
									
										
										
										
											2024-06-19 16:05:27 +08:00
										 |  |  | from core.rag.retrieval.retrival_methods import RetrievalMethod | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							|  |  |  | from models.account import Account | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from models.dataset import Dataset, DatasetQuery, DocumentSegment | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | default_retrieval_model = { | 
					
						
							| 
									
										
										
										
											2024-06-19 16:05:27 +08:00
										 |  |  |     'search_method': RetrievalMethod.SEMANTIC_SEARCH, | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |     'reranking_enable': False, | 
					
						
							|  |  |  |     'reranking_model': { | 
					
						
							|  |  |  |         'reranking_provider_name': '', | 
					
						
							|  |  |  |         'reranking_model_name': '' | 
					
						
							|  |  |  |     }, | 
					
						
							|  |  |  |     'top_k': 2, | 
					
						
							| 
									
										
										
										
											2023-11-27 15:34:45 +08:00
										 |  |  |     'score_threshold_enabled': False | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | class HitTestingService: | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |     def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: | 
					
						
							| 
									
										
										
										
											2023-08-23 19:57:27 +08:00
										 |  |  |         if dataset.available_document_count == 0 or dataset.available_segment_count == 0: | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |             return { | 
					
						
							|  |  |  |                 "query": { | 
					
						
							|  |  |  |                     "content": query, | 
					
						
							|  |  |  |                     "tsne_position": {'x': 0, 'y': 0}, | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |                 "records": [] | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |         start = time.perf_counter() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # get retrieval model , if the model is not setting , using default | 
					
						
							|  |  |  |         if not retrieval_model: | 
					
						
							|  |  |  |             retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-22 23:31:57 +08:00
										 |  |  |         all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | 
					
						
							|  |  |  |                                                   dataset_id=dataset.id, | 
					
						
							|  |  |  |                                                   query=query, | 
					
						
							|  |  |  |                                                   top_k=retrieval_model['top_k'], | 
					
						
							|  |  |  |                                                   score_threshold=retrieval_model['score_threshold'] | 
					
						
							|  |  |  |                                                   if retrieval_model['score_threshold_enabled'] else None, | 
					
						
							|  |  |  |                                                   reranking_model=retrieval_model['reranking_model'] | 
					
						
							|  |  |  |                                                   if retrieval_model['reranking_enable'] else None | 
					
						
							|  |  |  |                                                   ) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         end = time.perf_counter() | 
					
						
							|  |  |  |         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         dataset_query = DatasetQuery( | 
					
						
							|  |  |  |             dataset_id=dataset.id, | 
					
						
							|  |  |  |             content=query, | 
					
						
							|  |  |  |             source='hit_testing', | 
					
						
							|  |  |  |             created_by_role='account', | 
					
						
							|  |  |  |             created_by=account.id | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         db.session.add(dataset_query) | 
					
						
							|  |  |  |         db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-02 17:57:42 +08:00
										 |  |  |         return cls.compact_retrieve_response(dataset, query, all_documents) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-07-02 17:57:42 +08:00
										 |  |  |     def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         i = 0 | 
					
						
							|  |  |  |         records = [] | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |         for document in documents: | 
					
						
							|  |  |  |             index_node_id = document.metadata['doc_id'] | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             segment = db.session.query(DocumentSegment).filter( | 
					
						
							|  |  |  |                 DocumentSegment.dataset_id == dataset.id, | 
					
						
							|  |  |  |                 DocumentSegment.enabled == True, | 
					
						
							|  |  |  |                 DocumentSegment.status == 'completed', | 
					
						
							|  |  |  |                 DocumentSegment.index_node_id == index_node_id | 
					
						
							|  |  |  |             ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not segment: | 
					
						
							|  |  |  |                 i += 1 | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             record = { | 
					
						
							|  |  |  |                 "segment": segment, | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |                 "score": document.metadata.get('score', None), | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             records.append(record) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             i += 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return { | 
					
						
							|  |  |  |             "query": { | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |                 "content": query, | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             }, | 
					
						
							|  |  |  |             "records": records | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |     @classmethod | 
					
						
							|  |  |  |     def hit_testing_args_check(cls, args): | 
					
						
							|  |  |  |         query = args['query'] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if not query or len(query) > 250: | 
					
						
							|  |  |  |             raise ValueError('Query is required and cannot exceed 250 characters') |