import datetime
import logging
import time
import uuid
from typing import cast

import click
from celery import shared_task
from sqlalchemy import func

from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Dataset, Document, DocumentSegment


@shared_task(queue='dataset')
def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: str, document_id: str,
                                       tenant_id: str, user_id: str):
    """
    Async batch create segments and add them to the index.
    :param job_id: id of this batch import job
    :param content: list of segment dicts, each with a 'content' key
                    (and an 'answer' key for qa_model documents)
    :param dataset_id: id of the target dataset
    :param document_id: id of the target document
    :param tenant_id: id of the owning tenant
    :param user_id: id of the user who triggered the import

    Usage: batch_create_segment_to_index_task.delay(job_id, content, dataset_id,
                                                    document_id, tenant_id, user_id)
    """
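    # A minimal sketch of the expected `content` payload, inferred from the loop
    # below (only 'content' and, for qa_model documents, 'answer' are read):
    #   [
    #       {'content': 'segment text', 'answer': 'answer text (qa_model only)'},
    #   ]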
					
						
    logging.info(click.style('Start batch create segment jobId: {}'.format(job_id), fg='green'))
    start_at = time.perf_counter()

    indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
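    # this Redis key is presumably how callers track job status; it is set to
    # 'completed' or 'error' below with a 600-second TTL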
					
						
    try:
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        if not dataset:
            raise ValueError('Dataset does not exist.')

        dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
        if not dataset_document:
            raise ValueError('Document does not exist.')

        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
            raise ValueError('Document is not available.')
        document_segments = []
        embedding_model = None
        model_type_instance = None
        if dataset.indexing_technique == 'high_quality':
            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model
            )
            model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
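        # datasets not using high_quality indexing have no embedding model, so
        # model_type_instance may stay None; token counting falls back to 0 below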
					
						
        for segment in content:
            segment_content = segment['content']
            doc_id = str(uuid.uuid4())
            segment_hash = helper.generate_text_hash(segment_content)
            # count the tokens the embedding model will use for this segment
            tokens = model_type_instance.get_num_tokens(
                model=embedding_model.model,
                credentials=embedding_model.credentials,
                texts=[segment_content]
            ) if embedding_model else 0
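            # the count is stored on the segment below, presumably for usage accounting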
					
						
            max_position = db.session.query(func.max(DocumentSegment.position)).filter(
                DocumentSegment.document_id == dataset_document.id
            ).scalar()
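            # max_position is None for a document with no segments yet, so new
            # segments start at position 1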
					
						
            segment_document = DocumentSegment(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                document_id=document_id,
                index_node_id=doc_id,
                index_node_hash=segment_hash,
                position=max_position + 1 if max_position else 1,
                content=segment_content,
                word_count=len(segment_content),
                tokens=tokens,
                created_by=user_id,
                indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                status='completed',
                completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            )
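            # timestamps are stored as naive UTC (tzinfo stripped), presumably to
            # match timezone-naive datetime columns on the model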
					
						
            if dataset_document.doc_form == 'qa_model':
                segment_document.answer = segment['answer']
            db.session.add(segment_document)
            document_segments.append(segment_document)
        # write the new segments to the vector index, then persist them to the db
        indexing_runner = IndexingRunner()
        indexing_runner.batch_add_segments(document_segments, dataset)
        db.session.commit()
        redis_client.setex(indexing_cache_key, 600, 'completed')
        end_at = time.perf_counter()
        logging.info(click.style('Segment batch created job: {} latency: {}'.format(job_id, end_at - start_at), fg='green'))
    except Exception as e:
        logging.exception("Segment batch create index failed: {}".format(str(e)))
        redis_client.setex(indexing_cache_key, 600, 'error')
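

# A minimal sketch of how a caller might enqueue this task (values illustrative;
# `segments` is a hypothetical list following the payload shape documented above):
#
#   job_id = str(uuid.uuid4())
#   batch_create_segment_to_index_task.delay(job_id, segments, dataset_id,
#                                            document_id, tenant_id, user_id)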