| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  | import datetime | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import time | 
					
						
							|  |  |  | import uuid | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import click | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  | from celery import shared_task  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from sqlalchemy import func | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.model_manager import ModelManager | 
					
						
							|  |  |  | from core.model_runtime.entities.model_entities import ModelType | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							|  |  |  | from extensions.ext_redis import redis_client | 
					
						
							|  |  |  | from libs import helper | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from models.dataset import Dataset, Document, DocumentSegment | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  | from services.vector_service import VectorService | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  | @shared_task(queue="dataset") | 
					
						
							|  |  |  | def batch_create_segment_to_index_task( | 
					
						
							|  |  |  |     job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Async batch create segment to index | 
					
						
							|  |  |  |     :param job_id: | 
					
						
							|  |  |  |     :param content: | 
					
						
							|  |  |  |     :param dataset_id: | 
					
						
							|  |  |  |     :param document_id: | 
					
						
							|  |  |  |     :param tenant_id: | 
					
						
							|  |  |  |     :param user_id: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Usage: batch_create_segment_to_index_task.delay(segment_id) | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |     logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green")) | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |     start_at = time.perf_counter() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |     indexing_cache_key = "segment_batch_import_{}".format(job_id) | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | 
					
						
							|  |  |  |         if not dataset: | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |             raise ValueError("Dataset not exist.") | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         dataset_document = db.session.query(Document).filter(Document.id == document_id).first() | 
					
						
							|  |  |  |         if not dataset_document: | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |             raise ValueError("Document not exist.") | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": | 
					
						
							|  |  |  |             raise ValueError("Document is not available.") | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |         document_segments = [] | 
					
						
							| 
									
										
										
										
											2023-08-29 03:37:45 +08:00
										 |  |  |         embedding_model = None | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |         if dataset.indexing_technique == "high_quality": | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             model_manager = ModelManager() | 
					
						
							|  |  |  |             embedding_model = model_manager.get_model_instance( | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |                 tenant_id=dataset.tenant_id, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 provider=dataset.embedding_model_provider, | 
					
						
							|  |  |  |                 model_type=ModelType.TEXT_EMBEDDING, | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |                 model=dataset.embedding_model, | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-11-08 17:32:27 +08:00
										 |  |  |         word_count_change = 0 | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         segments_to_insert: list[str] = []  # Explicitly type hint the list as List[str] | 
					
						
							| 
									
										
										
										
											2023-08-29 03:37:45 +08:00
										 |  |  |         for segment in content: | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |             content_str = segment["content"] | 
					
						
							| 
									
										
										
										
											2023-08-29 03:37:45 +08:00
										 |  |  |             doc_id = str(uuid.uuid4()) | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |             segment_hash = helper.generate_text_hash(content_str) | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |             # calc embedding use tokens | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |             tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0 | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |             max_position = ( | 
					
						
							|  |  |  |                 db.session.query(func.max(DocumentSegment.position)) | 
					
						
							|  |  |  |                 .filter(DocumentSegment.document_id == dataset_document.id) | 
					
						
							|  |  |  |                 .scalar() | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |             segment_document = DocumentSegment( | 
					
						
							|  |  |  |                 tenant_id=tenant_id, | 
					
						
							|  |  |  |                 dataset_id=dataset_id, | 
					
						
							|  |  |  |                 document_id=document_id, | 
					
						
							|  |  |  |                 index_node_id=doc_id, | 
					
						
							|  |  |  |                 index_node_hash=segment_hash, | 
					
						
							|  |  |  |                 position=max_position + 1 if max_position else 1, | 
					
						
							| 
									
										
										
										
											2025-01-13 09:06:59 +08:00
										 |  |  |                 content=content_str, | 
					
						
							|  |  |  |                 word_count=len(content_str), | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |                 tokens=tokens, | 
					
						
							|  |  |  |                 created_by=user_id, | 
					
						
							| 
									
										
										
										
											2024-11-24 13:28:46 +08:00
										 |  |  |                 indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |                 status="completed", | 
					
						
							| 
									
										
										
										
											2024-11-24 13:28:46 +08:00
										 |  |  |                 completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |             if dataset_document.doc_form == "qa_model": | 
					
						
							|  |  |  |                 segment_document.answer = segment["answer"] | 
					
						
							| 
									
										
										
										
											2024-11-08 17:32:27 +08:00
										 |  |  |                 segment_document.word_count += len(segment["answer"]) | 
					
						
							|  |  |  |             word_count_change += segment_document.word_count | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |             db.session.add(segment_document) | 
					
						
							|  |  |  |             document_segments.append(segment_document) | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |             segments_to_insert.append(str(segment))  # Cast to string if needed | 
					
						
							| 
									
										
										
										
											2024-11-08 17:32:27 +08:00
										 |  |  |         # update document word count | 
					
						
							|  |  |  |         dataset_document.word_count += word_count_change | 
					
						
							|  |  |  |         db.session.add(dataset_document) | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |         # add index to db | 
					
						
							| 
									
										
										
										
											2024-12-25 19:49:07 +08:00
										 |  |  |         VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |         db.session.commit() | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |         redis_client.setex(indexing_cache_key, 600, "completed") | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |         end_at = time.perf_counter() | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |         logging.info( | 
					
						
							|  |  |  |             click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green") | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-11-15 15:41:40 +08:00
										 |  |  |         logging.exception("Segments batch created index failed") | 
					
						
							| 
									
										
										
										
											2024-08-26 13:38:37 +08:00
										 |  |  |         redis_client.setex(indexing_cache_key, 600, "error") |