mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-30 18:33:30 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			107 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			107 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging
 | |
| import time
 | |
| 
 | |
| from core.rag.datasource.retrieval_service import RetrievalService
 | |
| from core.rag.models.document import Document
 | |
| from core.rag.retrieval.retrival_methods import RetrievalMethod
 | |
| from extensions.ext_database import db
 | |
| from models.account import Account
 | |
| from models.dataset import Dataset, DatasetQuery, DocumentSegment
 | |
| 
 | |
# Fallback retrieval settings, used when neither the caller nor the dataset
# supplies a retrieval model: semantic search, no reranking, top two hits,
# and no score-threshold filtering.
default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
    # Reranking is off, so the provider/model names below are left blank.
    "reranking_enable": False,
    "reranking_model": {
        "reranking_provider_name": "",
        "reranking_model_name": "",
    },
    "top_k": 2,
    # No 'score_threshold' entry is defined here; it is only read when this
    # flag is truthy.
    "score_threshold_enabled": False,
}
 | |
| 
 | |
| 
 | |
class HitTestingService:
    """Run ad-hoc retrieval ("hit testing") queries against a dataset."""

    @classmethod
    def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
        """Retrieve the segments of ``dataset`` that match ``query``.

        The query is also persisted as a ``DatasetQuery`` row with source
        ``'hit_testing'`` attributed to ``account``.

        Args:
            dataset: Dataset to search.
            query: Text to search for.
            account: Account recorded as the creator of the logged query.
            retrieval_model: Retrieval settings. When falsy, the dataset's
                own ``retrieval_model`` is used, falling back to
                ``default_retrieval_model``.
            limit: Accepted for interface compatibility; not used by this
                method.

        Returns:
            A dict with ``"query"`` and ``"records"`` keys. When the dataset
            has no available documents or segments, ``records`` is empty,
            the query entry additionally carries a zeroed ``tsne_position``,
            and nothing is retrieved or logged.
        """
        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            return {
                "query": {
                    "content": query,
                    "tsne_position": {'x': 0, 'y': 0},
                },
                "records": []
            }

        start = time.perf_counter()

        # Resolve the retrieval model: explicit argument, then the dataset's
        # own configuration, then the module-level default.
        if not retrieval_model:
            retrieval_model = dataset.retrieval_model or default_retrieval_model

        # The threshold and reranking model are only looked up (and passed
        # on) when their corresponding "enabled" flag is set.
        score_threshold = (
            retrieval_model['score_threshold']
            if retrieval_model['score_threshold_enabled'] else None
        )
        reranking_model = (
            retrieval_model['reranking_model']
            if retrieval_model['reranking_enable'] else None
        )

        all_documents = RetrievalService.retrieve(
            retrival_method=retrieval_model['search_method'],
            dataset_id=dataset.id,
            query=query,
            top_k=retrieval_model['top_k'],
            score_threshold=score_threshold,
            reranking_model=reranking_model,
        )

        end = time.perf_counter()
        # Lazy %-style arguments: the message is only formatted when debug
        # logging is actually enabled.
        logging.debug("Hit testing retrieve in %0.4f seconds", end - start)

        dataset_query = DatasetQuery(
            dataset_id=dataset.id,
            content=query,
            source='hit_testing',
            created_by_role='account',
            created_by=account.id
        )

        db.session.add(dataset_query)
        db.session.commit()

        return cls.compact_retrieve_response(dataset, query, all_documents)

    @classmethod
    def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
        """Build the hit-testing response from retrieved documents.

        Each document's ``metadata['doc_id']`` is looked up as the
        ``index_node_id`` of an enabled, completed ``DocumentSegment`` in
        ``dataset``. Documents without such a segment are skipped.

        Args:
            dataset: Dataset the segments must belong to.
            query: The original query text, echoed back in the response.
            documents: Retrieved documents, in the order to report them.

        Returns:
            A dict with ``"query"`` (containing ``"content"``) and
            ``"records"``, a list of ``{"segment": ..., "score": ...}``
            entries where ``score`` is ``None`` if the document metadata
            has no ``'score'``.
        """
        records = []
        for document in documents:
            index_node_id = document.metadata['doc_id']

            segment = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == dataset.id,
                # '== True' is intentional: it builds a SQL expression.
                DocumentSegment.enabled == True,
                DocumentSegment.status == 'completed',
                DocumentSegment.index_node_id == index_node_id
            ).first()

            if not segment:
                continue

            records.append({
                "segment": segment,
                "score": document.metadata.get('score', None),
            })

        return {
            "query": {
                "content": query,
            },
            "records": records
        }

    @classmethod
    def hit_testing_args_check(cls, args):
        """Validate the request arguments for a hit-testing call.

        Args:
            args: Mapping of request arguments; ``'query'`` is inspected.

        Raises:
            ValueError: If ``'query'`` is missing, empty, or longer than
                250 characters.
        """
        # .get() so that a missing key is reported as the same validation
        # error as an empty query, rather than leaking a KeyError.
        query = args.get('query')

        if not query or len(query) > 250:
            raise ValueError('Query is required and cannot exceed 250 characters')
 | 
