import datetime
import logging
import time
import uuid

import click
from celery import shared_task  # type: ignore
from sqlalchemy import func
from sqlalchemy.orm import Session

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Dataset, Document, DocumentSegment
from services.vector_service import VectorService


@shared_task(queue="dataset")
def batch_create_segment_to_index_task(
    job_id: str,
    content: list,
    dataset_id: str,
    document_id: str,
    tenant_id: str,
    user_id: str,
):
    """
    Async batch create segments and add them to the index.

    :param job_id: batch import job id, used to build the status cache key
    :param content: list of segment payloads, each a dict with a "content" key (and an "answer" key for QA documents)
    :param dataset_id: id of the dataset the segments belong to
    :param document_id: id of the document the segments belong to
    :param tenant_id: id of the tenant that owns the dataset
    :param user_id: id of the user creating the segments

    Usage: batch_create_segment_to_index_task.delay(job_id, content, dataset_id, document_id, tenant_id, user_id)
    """
    logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green"))
    start_at = time.perf_counter()

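    # Cache key that reports this batch import job's outcome ("completed" or "error", 10-minute TTL).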
    indexing_cache_key = "segment_batch_import_{}".format(job_id)

    try:
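        # Load the dataset and document and make sure the document can accept new segments.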
        with Session(db.engine) as session:
            dataset = session.get(Dataset, dataset_id)
            if not dataset:
                raise ValueError("Dataset does not exist.")

            dataset_document = session.get(Document, document_id)
            if not dataset_document:
                raise ValueError("Document does not exist.")

            if (
                not dataset_document.enabled
                or dataset_document.archived
                or dataset_document.indexing_status != "completed"
            ):
                raise ValueError("Document is not available.")
            document_segments = []
            embedding_model = None
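            # For high-quality datasets, resolve the tenant's configured embedding model so
            # per-segment token counts can be computed below.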
            if dataset.indexing_technique == "high_quality":
                model_manager = ModelManager()
                embedding_model = model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
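        # Pre-compute token counts for all segments in one call; without an embedding model they default to 0.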
        word_count_change = 0
        if embedding_model:
            tokens_list = embedding_model.get_text_embedding_num_tokens(
                texts=[segment["content"] for segment in content]
            )
        else:
            tokens_list = [0] * len(content)
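        # Build a DocumentSegment row for each payload, positioned after the document's current last segment.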
        for segment, tokens in zip(content, tokens_list):
            segment_content = segment["content"]
            doc_id = str(uuid.uuid4())
            segment_hash = helper.generate_text_hash(segment_content)  # type: ignore
            max_position = (
                db.session.query(func.max(DocumentSegment.position))
                .filter(DocumentSegment.document_id == dataset_document.id)
                .scalar()
            )
            segment_document = DocumentSegment(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                document_id=document_id,
                index_node_id=doc_id,
                index_node_hash=segment_hash,
                position=max_position + 1 if max_position else 1,
                content=segment_content,
                word_count=len(segment_content),
                tokens=tokens,
                created_by=user_id,
                indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                status="completed",
                completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            )
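            # QA documents also store an answer, which counts toward the segment's word count.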
            if dataset_document.doc_form == "qa_model":
                segment_document.answer = segment["answer"]
                segment_document.word_count += len(segment["answer"])
            word_count_change += segment_document.word_count
            db.session.add(segment_document)
            document_segments.append(segment_document)
        # update document word count
        dataset_document.word_count += word_count_change
        db.session.add(dataset_document)
        # create the vector index entries for the new segments
        VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
        db.session.commit()
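        # Mark the batch import job as completed; the status key expires after 10 minutes.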
        redis_client.setex(indexing_cache_key, 600, "completed")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                "Segment batch created job: {} latency: {}".format(job_id, end_at - start_at),
                fg="green",
            )
        )
    except Exception:
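        # On any failure, log the traceback and record the error status for the job.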
        logging.exception("Segments batch created index failed")
        redis_client.setex(indexing_cache_key, 600, "error")
    finally:
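        # Always release the database session.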
        db.session.close()