| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from typing import Optional, Union | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-13 22:05:46 +08:00
										 |  |  | from core.generator.llm_generator import LLMGenerator | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from libs.infinite_scroll_pagination import InfiniteScrollPagination | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from models.account import Account | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from models.model import App, Conversation, EndUser, Message | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError | 
					
						
							| 
									
										
										
										
											2023-11-13 22:05:46 +08:00
										 |  |  | from services.errors.message import MessageNotExistsError | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ConversationService: | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2023-12-03 20:59:13 +08:00
										 |  |  |     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |                               last_id: Optional[str], limit: int, | 
					
						
							| 
									
										
										
										
											2023-11-13 22:05:46 +08:00
										 |  |  |                               include_ids: Optional[list] = None, exclude_ids: Optional[list] = None, | 
					
						
							|  |  |  |                               exclude_debug_conversation: bool = False) -> InfiniteScrollPagination: | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         if not user: | 
					
						
							|  |  |  |             return InfiniteScrollPagination(data=[], limit=limit, has_more=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         base_query = db.session.query(Conversation).filter( | 
					
						
							| 
									
										
										
										
											2023-06-28 13:31:51 +08:00
										 |  |  |             Conversation.is_deleted == False, | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             Conversation.app_id == app_model.id, | 
					
						
							|  |  |  |             Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), | 
					
						
							|  |  |  |             Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | 
					
						
							|  |  |  |             Conversation.from_account_id == (user.id if isinstance(user, Account) else None), | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if include_ids is not None: | 
					
						
							|  |  |  |             base_query = base_query.filter(Conversation.id.in_(include_ids)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if exclude_ids is not None: | 
					
						
							|  |  |  |             base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-13 22:05:46 +08:00
										 |  |  |         if exclude_debug_conversation: | 
					
						
							|  |  |  |             base_query = base_query.filter(Conversation.override_model_configs == None) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         if last_id: | 
					
						
							|  |  |  |             last_conversation = base_query.filter( | 
					
						
							|  |  |  |                 Conversation.id == last_id, | 
					
						
							|  |  |  |             ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not last_conversation: | 
					
						
							|  |  |  |                 raise LastConversationNotExistsError() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             conversations = base_query.filter( | 
					
						
							|  |  |  |                 Conversation.created_at < last_conversation.created_at, | 
					
						
							|  |  |  |                 Conversation.id != last_conversation.id | 
					
						
							|  |  |  |             ).order_by(Conversation.created_at.desc()).limit(limit).all() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         has_more = False | 
					
						
							|  |  |  |         if len(conversations) == limit: | 
					
						
							|  |  |  |             current_page_first_conversation = conversations[-1] | 
					
						
							|  |  |  |             rest_count = base_query.filter( | 
					
						
							|  |  |  |                 Conversation.created_at < current_page_first_conversation.created_at, | 
					
						
							|  |  |  |                 Conversation.id != current_page_first_conversation.id | 
					
						
							|  |  |  |             ).count() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if rest_count > 0: | 
					
						
							|  |  |  |                 has_more = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return InfiniteScrollPagination( | 
					
						
							|  |  |  |             data=conversations, | 
					
						
							|  |  |  |             limit=limit, | 
					
						
							|  |  |  |             has_more=has_more | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def rename(cls, app_model: App, conversation_id: str, | 
					
						
							| 
									
										
										
										
											2023-12-03 20:59:13 +08:00
										 |  |  |                user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         conversation = cls.get_conversation(app_model, conversation_id, user) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-13 22:05:46 +08:00
										 |  |  |         if auto_generate: | 
					
						
							|  |  |  |             return cls.auto_generate_name(app_model, conversation) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             conversation.name = name | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return conversation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def auto_generate_name(cls, app_model: App, conversation: Conversation): | 
					
						
							|  |  |  |         # get conversation first message | 
					
						
							|  |  |  |         message = db.session.query(Message) \ | 
					
						
							|  |  |  |             .filter( | 
					
						
							|  |  |  |                 Message.app_id == app_model.id, | 
					
						
							|  |  |  |                 Message.conversation_id == conversation.id | 
					
						
							|  |  |  |             ).order_by(Message.created_at.asc()).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if not message: | 
					
						
							|  |  |  |             raise MessageNotExistsError() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # generate conversation name | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query) | 
					
						
							|  |  |  |             conversation.name = name | 
					
						
							|  |  |  |         except: | 
					
						
							|  |  |  |             pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return conversation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2023-12-03 20:59:13 +08:00
										 |  |  |     def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         conversation = db.session.query(Conversation) \ | 
					
						
							|  |  |  |             .filter( | 
					
						
							|  |  |  |             Conversation.id == conversation_id, | 
					
						
							|  |  |  |             Conversation.app_id == app_model.id, | 
					
						
							|  |  |  |             Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), | 
					
						
							|  |  |  |             Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | 
					
						
							|  |  |  |             Conversation.from_account_id == (user.id if isinstance(user, Account) else None), | 
					
						
							| 
									
										
										
										
											2023-06-28 13:31:51 +08:00
										 |  |  |             Conversation.is_deleted == False | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if not conversation: | 
					
						
							|  |  |  |             raise ConversationNotExistsError() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return conversation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2023-12-03 20:59:13 +08:00
										 |  |  |     def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         conversation = cls.get_conversation(app_model, conversation_id, user) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-28 13:31:51 +08:00
										 |  |  |         conversation.is_deleted = True | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         db.session.commit() |