mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-31 02:42:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			647 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			647 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import json
 | |
| import logging
 | |
| import uuid
 | |
| from datetime import datetime
 | |
| from mimetypes import guess_extension
 | |
| from typing import Optional, Union, cast
 | |
| 
 | |
| from core.app_runner.app_runner import AppRunner
 | |
| from core.application_queue_manager import ApplicationQueueManager
 | |
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 | |
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 | |
| from core.entities.application_entities import (
 | |
|     AgentEntity,
 | |
|     AgentToolEntity,
 | |
|     ApplicationGenerateEntity,
 | |
|     AppOrchestrationConfigEntity,
 | |
|     InvokeFrom,
 | |
|     ModelConfigEntity,
 | |
| )
 | |
| from core.file.message_file_parser import FileTransferMethod
 | |
| from core.memory.token_buffer_memory import TokenBufferMemory
 | |
| from core.model_manager import ModelInstance
 | |
| from core.model_runtime.entities.llm_entities import LLMUsage
 | |
| from core.model_runtime.entities.message_entities import (
 | |
|     AssistantPromptMessage,
 | |
|     PromptMessage,
 | |
|     PromptMessageTool,
 | |
|     SystemPromptMessage,
 | |
|     ToolPromptMessage,
 | |
|     UserPromptMessage,
 | |
| )
 | |
| from core.model_runtime.entities.model_entities import ModelFeature
 | |
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 | |
| from core.model_runtime.utils.encoders import jsonable_encoder
 | |
| from core.tools.entities.tool_entities import (
 | |
|     ToolInvokeMessage,
 | |
|     ToolInvokeMessageBinary,
 | |
|     ToolParameter,
 | |
|     ToolRuntimeVariablePool,
 | |
| )
 | |
| from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 | |
| from core.tools.tool.tool import Tool
 | |
| from core.tools.tool_file_manager import ToolFileManager
 | |
| from core.tools.tool_manager import ToolManager
 | |
| from extensions.ext_database import db
 | |
| from models.model import Message, MessageAgentThought, MessageFile
 | |
| from models.tools import ToolConversationVariables
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| class BaseAssistantApplicationRunner(AppRunner):
 | |
|     def __init__(self, tenant_id: str,
 | |
|                  application_generate_entity: ApplicationGenerateEntity,
 | |
|                  app_orchestration_config: AppOrchestrationConfigEntity,
 | |
|                  model_config: ModelConfigEntity,
 | |
|                  config: AgentEntity,
 | |
|                  queue_manager: ApplicationQueueManager,
 | |
|                  message: Message,
 | |
|                  user_id: str,
 | |
|                  memory: Optional[TokenBufferMemory] = None,
 | |
|                  prompt_messages: Optional[list[PromptMessage]] = None,
 | |
|                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
 | |
|                  db_variables: Optional[ToolConversationVariables] = None,
 | |
|                  model_instance: ModelInstance = None
 | |
|                  ) -> None:
 | |
|         """
 | |
|         Agent runner
 | |
|         :param tenant_id: tenant id
 | |
|         :param app_orchestration_config: app orchestration config
 | |
|         :param model_config: model config
 | |
|         :param config: dataset config
 | |
|         :param queue_manager: queue manager
 | |
|         :param message: message
 | |
|         :param user_id: user id
 | |
|         :param agent_llm_callback: agent llm callback
 | |
|         :param callback: callback
 | |
|         :param memory: memory
 | |
|         """
 | |
|         self.tenant_id = tenant_id
 | |
|         self.application_generate_entity = application_generate_entity
 | |
|         self.app_orchestration_config = app_orchestration_config
 | |
|         self.model_config = model_config
 | |
|         self.config = config
 | |
|         self.queue_manager = queue_manager
 | |
|         self.message = message
 | |
|         self.user_id = user_id
 | |
|         self.memory = memory
 | |
|         self.history_prompt_messages = self.organize_agent_history(
 | |
|             prompt_messages=prompt_messages or []
 | |
|         )
 | |
|         self.variables_pool = variables_pool
 | |
|         self.db_variables_pool = db_variables
 | |
|         self.model_instance = model_instance
 | |
| 
 | |
|         # init callback
 | |
|         self.agent_callback = DifyAgentCallbackHandler()
 | |
|         # init dataset tools
 | |
|         hit_callback = DatasetIndexToolCallbackHandler(
 | |
|             queue_manager=queue_manager,
 | |
|             app_id=self.application_generate_entity.app_id,
 | |
|             message_id=message.id,
 | |
|             user_id=user_id,
 | |
|             invoke_from=self.application_generate_entity.invoke_from,
 | |
|         )
 | |
|         self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
 | |
|             tenant_id=tenant_id,
 | |
|             dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [],
 | |
|             retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None,
 | |
|             return_resource=app_orchestration_config.show_retrieve_source,
 | |
|             invoke_from=application_generate_entity.invoke_from,
 | |
|             hit_callback=hit_callback
 | |
|         )
 | |
|         # get how many agent thoughts have been created
 | |
|         self.agent_thought_count = db.session.query(MessageAgentThought).filter(
 | |
|             MessageAgentThought.message_id == self.message.id,
 | |
|         ).count()
 | |
| 
 | |
|         # check if model supports stream tool call
 | |
|         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
 | |
|         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
 | |
|         if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
 | |
|             self.stream_tool_call = True
 | |
|         else:
 | |
|             self.stream_tool_call = False
 | |
| 
 | |
|     def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
 | |
|         """
 | |
|         Repack app orchestration config
 | |
|         """
 | |
|         if app_orchestration_config.prompt_template.simple_prompt_template is None:
 | |
|             app_orchestration_config.prompt_template.simple_prompt_template = ''
 | |
| 
 | |
|         return app_orchestration_config
 | |
| 
 | |
|     def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
 | |
|         """
 | |
|         Handle tool response
 | |
|         """
 | |
|         result = ''
 | |
|         for response in tool_response:
 | |
|             if response.type == ToolInvokeMessage.MessageType.TEXT:
 | |
|                 result += response.message
 | |
|             elif response.type == ToolInvokeMessage.MessageType.LINK:
 | |
|                 result += f"result link: {response.message}. please tell user to check it."
 | |
|             elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
 | |
|                  response.type == ToolInvokeMessage.MessageType.IMAGE:
 | |
|                 result += "image has been created and sent to user already, you should tell user to check it now."
 | |
|             else:
 | |
|                 result += f"tool response: {response.message}."
 | |
| 
 | |
|         return result
 | |
|     
 | |
|     def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
 | |
|         """
 | |
|             convert tool to prompt message tool
 | |
|         """
 | |
|         tool_entity = ToolManager.get_tool_runtime(
 | |
|             provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name, 
 | |
|             tenant_id=self.application_generate_entity.tenant_id,
 | |
|             agent_callback=self.agent_callback
 | |
|         )
 | |
|         tool_entity.load_variables(self.variables_pool)
 | |
| 
 | |
|         message_tool = PromptMessageTool(
 | |
|             name=tool.tool_name,
 | |
|             description=tool_entity.description.llm,
 | |
|             parameters={
 | |
|                 "type": "object",
 | |
|                 "properties": {},
 | |
|                 "required": [],
 | |
|             }
 | |
|         )
 | |
| 
 | |
|         runtime_parameters = {}
 | |
| 
 | |
|         parameters = tool_entity.parameters or []
 | |
|         user_parameters = tool_entity.get_runtime_parameters() or []
 | |
| 
 | |
|         # override parameters
 | |
|         for parameter in user_parameters:
 | |
|             # check if parameter in tool parameters
 | |
|             found = False
 | |
|             for tool_parameter in parameters:
 | |
|                 if tool_parameter.name == parameter.name:
 | |
|                     found = True
 | |
|                     break
 | |
| 
 | |
|             if found:
 | |
|                 # override parameter
 | |
|                 tool_parameter.type = parameter.type
 | |
|                 tool_parameter.form = parameter.form
 | |
|                 tool_parameter.required = parameter.required
 | |
|                 tool_parameter.default = parameter.default
 | |
|                 tool_parameter.options = parameter.options
 | |
|                 tool_parameter.llm_description = parameter.llm_description
 | |
|             else:
 | |
|                 # add new parameter
 | |
|                 parameters.append(parameter)
 | |
| 
 | |
|         for parameter in parameters:
 | |
|             parameter_type = 'string'
 | |
|             enum = []
 | |
|             if parameter.type == ToolParameter.ToolParameterType.STRING:
 | |
|                 parameter_type = 'string'
 | |
|             elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
 | |
|                 parameter_type = 'boolean'
 | |
|             elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
 | |
|                 parameter_type = 'number'
 | |
|             elif parameter.type == ToolParameter.ToolParameterType.SELECT:
 | |
|                 for option in parameter.options:
 | |
|                     enum.append(option.value)
 | |
|                 parameter_type = 'string'
 | |
|             else:
 | |
|                 raise ValueError(f"parameter type {parameter.type} is not supported")
 | |
|             
 | |
|             if parameter.form == ToolParameter.ToolParameterForm.FORM:
 | |
|                 # get tool parameter from form
 | |
|                 tool_parameter_config = tool.tool_parameters.get(parameter.name)
 | |
|                 if not tool_parameter_config:
 | |
|                     # get default value
 | |
|                     tool_parameter_config = parameter.default
 | |
|                     if not tool_parameter_config and parameter.required:
 | |
|                         raise ValueError(f"tool parameter {parameter.name} not found in tool config")
 | |
|                     
 | |
|                 if parameter.type == ToolParameter.ToolParameterType.SELECT:
 | |
|                     # check if tool_parameter_config in options
 | |
|                     options = list(map(lambda x: x.value, parameter.options))
 | |
|                     if tool_parameter_config not in options:
 | |
|                         raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
 | |
|                     
 | |
|                 # convert tool parameter config to correct type
 | |
|                 try:
 | |
|                     if parameter.type == ToolParameter.ToolParameterType.NUMBER:
 | |
|                         # check if tool parameter is integer
 | |
|                         if isinstance(tool_parameter_config, int):
 | |
|                             tool_parameter_config = tool_parameter_config
 | |
|                         elif isinstance(tool_parameter_config, float):
 | |
|                             tool_parameter_config = tool_parameter_config
 | |
|                         elif isinstance(tool_parameter_config, str):
 | |
|                             if '.' in tool_parameter_config:
 | |
|                                 tool_parameter_config = float(tool_parameter_config)
 | |
|                             else:
 | |
|                                 tool_parameter_config = int(tool_parameter_config)
 | |
|                     elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
 | |
|                         tool_parameter_config = bool(tool_parameter_config)
 | |
|                     elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
 | |
|                         tool_parameter_config = str(tool_parameter_config)
 | |
|                     elif parameter.type == ToolParameter.ToolParameterType:
 | |
|                         tool_parameter_config = str(tool_parameter_config)
 | |
|                 except Exception as e:
 | |
|                     raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
 | |
|                 
 | |
|                 # save tool parameter to tool entity memory
 | |
|                 runtime_parameters[parameter.name] = tool_parameter_config
 | |
|             
 | |
|             elif parameter.form == ToolParameter.ToolParameterForm.LLM:
 | |
|                 message_tool.parameters['properties'][parameter.name] = {
 | |
|                     "type": parameter_type,
 | |
|                     "description": parameter.llm_description or '',
 | |
|                 }
 | |
| 
 | |
|                 if len(enum) > 0:
 | |
|                     message_tool.parameters['properties'][parameter.name]['enum'] = enum
 | |
| 
 | |
|                 if parameter.required:
 | |
|                     message_tool.parameters['required'].append(parameter.name)
 | |
| 
 | |
|         tool_entity.runtime.runtime_parameters.update(runtime_parameters)
 | |
| 
 | |
|         return message_tool, tool_entity
 | |
|     
 | |
|     def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
 | |
|         """
 | |
|         convert dataset retriever tool to prompt message tool
 | |
|         """
 | |
|         prompt_tool = PromptMessageTool(
 | |
|             name=tool.identity.name,
 | |
|             description=tool.description.llm,
 | |
|             parameters={
 | |
|                 "type": "object",
 | |
|                 "properties": {},
 | |
|                 "required": [],
 | |
|             }
 | |
|         )
 | |
| 
 | |
|         for parameter in tool.get_runtime_parameters():
 | |
|             parameter_type = 'string'
 | |
|         
 | |
|             prompt_tool.parameters['properties'][parameter.name] = {
 | |
|                 "type": parameter_type,
 | |
|                 "description": parameter.llm_description or '',
 | |
|             }
 | |
| 
 | |
|             if parameter.required:
 | |
|                 if parameter.name not in prompt_tool.parameters['required']:
 | |
|                     prompt_tool.parameters['required'].append(parameter.name)
 | |
| 
 | |
|         return prompt_tool
 | |
|     
 | |
|     def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
 | |
|         """
 | |
|         update prompt message tool
 | |
|         """
 | |
|         # try to get tool runtime parameters
 | |
|         tool_runtime_parameters = tool.get_runtime_parameters() or []
 | |
| 
 | |
|         for parameter in tool_runtime_parameters:
 | |
|             parameter_type = 'string'
 | |
|             enum = []
 | |
|             if parameter.type == ToolParameter.ToolParameterType.STRING:
 | |
|                 parameter_type = 'string'
 | |
|             elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
 | |
|                 parameter_type = 'boolean'
 | |
|             elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
 | |
|                 parameter_type = 'number'
 | |
|             elif parameter.type == ToolParameter.ToolParameterType.SELECT:
 | |
|                 for option in parameter.options:
 | |
|                     enum.append(option.value)
 | |
|                 parameter_type = 'string'
 | |
|             else:
 | |
|                 raise ValueError(f"parameter type {parameter.type} is not supported")
 | |
|         
 | |
|             if parameter.form == ToolParameter.ToolParameterForm.LLM:
 | |
|                 prompt_tool.parameters['properties'][parameter.name] = {
 | |
|                     "type": parameter_type,
 | |
|                     "description": parameter.llm_description or '',
 | |
|                 }
 | |
| 
 | |
|                 if len(enum) > 0:
 | |
|                     prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
 | |
| 
 | |
|                 if parameter.required:
 | |
|                     if parameter.name not in prompt_tool.parameters['required']:
 | |
|                         prompt_tool.parameters['required'].append(parameter.name)
 | |
| 
 | |
|         return prompt_tool
 | |
|     
 | |
|     def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
 | |
|         """
 | |
|         Extract tool response binary
 | |
|         """
 | |
|         result = []
 | |
| 
 | |
|         for response in tool_response:
 | |
|             if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
 | |
|                 response.type == ToolInvokeMessage.MessageType.IMAGE:
 | |
|                 result.append(ToolInvokeMessageBinary(
 | |
|                     mimetype=response.meta.get('mime_type', 'octet/stream'),
 | |
|                     url=response.message,
 | |
|                     save_as=response.save_as,
 | |
|                 ))
 | |
|             elif response.type == ToolInvokeMessage.MessageType.BLOB:
 | |
|                 result.append(ToolInvokeMessageBinary(
 | |
|                     mimetype=response.meta.get('mime_type', 'octet/stream'),
 | |
|                     url=response.message,
 | |
|                     save_as=response.save_as,
 | |
|                 ))
 | |
|             elif response.type == ToolInvokeMessage.MessageType.LINK:
 | |
|                 # check if there is a mime type in meta
 | |
|                 if response.meta and 'mime_type' in response.meta:
 | |
|                     result.append(ToolInvokeMessageBinary(
 | |
|                         mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
 | |
|                         url=response.message,
 | |
|                         save_as=response.save_as,
 | |
|                     ))
 | |
| 
 | |
|         return result
 | |
|     
 | |
|     def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
 | |
|         """
 | |
|         Create message file
 | |
| 
 | |
|         :param messages: messages
 | |
|         :return: message files, should save as variable
 | |
|         """
 | |
|         result = []
 | |
| 
 | |
|         for message in messages:
 | |
|             file_type = 'bin'
 | |
|             if 'image' in message.mimetype:
 | |
|                 file_type = 'image'
 | |
|             elif 'video' in message.mimetype:
 | |
|                 file_type = 'video'
 | |
|             elif 'audio' in message.mimetype:
 | |
|                 file_type = 'audio'
 | |
|             elif 'text' in message.mimetype:
 | |
|                 file_type = 'text'
 | |
|             elif 'pdf' in message.mimetype:
 | |
|                 file_type = 'pdf'
 | |
|             elif 'zip' in message.mimetype:
 | |
|                 file_type = 'archive'
 | |
|             # ...
 | |
| 
 | |
|             invoke_from = self.application_generate_entity.invoke_from
 | |
| 
 | |
|             message_file = MessageFile(
 | |
|                 message_id=self.message.id,
 | |
|                 type=file_type,
 | |
|                 transfer_method=FileTransferMethod.TOOL_FILE.value,
 | |
|                 belongs_to='assistant',
 | |
|                 url=message.url,
 | |
|                 upload_file_id=None,
 | |
|                 created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
 | |
|                 created_by=self.user_id,
 | |
|             )
 | |
|             db.session.add(message_file)
 | |
|             result.append((
 | |
|                 message_file,
 | |
|                 message.save_as
 | |
|             ))
 | |
|             
 | |
|         db.session.commit()
 | |
| 
 | |
|         return result
 | |
|         
 | |
|     def create_agent_thought(self, message_id: str, message: str, 
 | |
|                              tool_name: str, tool_input: str, messages_ids: list[str]
 | |
|                              ) -> MessageAgentThought:
 | |
|         """
 | |
|         Create agent thought
 | |
|         """
 | |
|         thought = MessageAgentThought(
 | |
|             message_id=message_id,
 | |
|             message_chain_id=None,
 | |
|             thought='',
 | |
|             tool=tool_name,
 | |
|             tool_labels_str='{}',
 | |
|             tool_input=tool_input,
 | |
|             message=message,
 | |
|             message_token=0,
 | |
|             message_unit_price=0,
 | |
|             message_price_unit=0,
 | |
|             message_files=json.dumps(messages_ids) if messages_ids else '',
 | |
|             answer='',
 | |
|             observation='',
 | |
|             answer_token=0,
 | |
|             answer_unit_price=0,
 | |
|             answer_price_unit=0,
 | |
|             tokens=0,
 | |
|             total_price=0,
 | |
|             position=self.agent_thought_count + 1,
 | |
|             currency='USD',
 | |
|             latency=0,
 | |
|             created_by_role='account',
 | |
|             created_by=self.user_id,
 | |
|         )
 | |
| 
 | |
|         db.session.add(thought)
 | |
|         db.session.commit()
 | |
| 
 | |
|         self.agent_thought_count += 1
 | |
| 
 | |
|         return thought
 | |
| 
 | |
|     def save_agent_thought(self, 
 | |
|                            agent_thought: MessageAgentThought, 
 | |
|                            tool_name: str,
 | |
|                            tool_input: Union[str, dict],
 | |
|                            thought: str, 
 | |
|                            observation: str, 
 | |
|                            answer: str,
 | |
|                            messages_ids: list[str],
 | |
|                            llm_usage: LLMUsage = None) -> MessageAgentThought:
 | |
|         """
 | |
|         Save agent thought
 | |
|         """
 | |
|         if thought is not None:
 | |
|             agent_thought.thought = thought
 | |
| 
 | |
|         if tool_name is not None:
 | |
|             agent_thought.tool = tool_name
 | |
| 
 | |
|         if tool_input is not None:
 | |
|             if isinstance(tool_input, dict):
 | |
|                 try:
 | |
|                     tool_input = json.dumps(tool_input, ensure_ascii=False)
 | |
|                 except Exception as e:
 | |
|                     tool_input = json.dumps(tool_input)
 | |
| 
 | |
|             agent_thought.tool_input = tool_input
 | |
| 
 | |
|         if observation is not None:
 | |
|             agent_thought.observation = observation
 | |
| 
 | |
|         if answer is not None:
 | |
|             agent_thought.answer = answer
 | |
| 
 | |
|         if messages_ids is not None and len(messages_ids) > 0:
 | |
|             agent_thought.message_files = json.dumps(messages_ids)
 | |
|         
 | |
|         if llm_usage:
 | |
|             agent_thought.message_token = llm_usage.prompt_tokens
 | |
|             agent_thought.message_price_unit = llm_usage.prompt_price_unit
 | |
|             agent_thought.message_unit_price = llm_usage.prompt_unit_price
 | |
|             agent_thought.answer_token = llm_usage.completion_tokens
 | |
|             agent_thought.answer_price_unit = llm_usage.completion_price_unit
 | |
|             agent_thought.answer_unit_price = llm_usage.completion_unit_price
 | |
|             agent_thought.tokens = llm_usage.total_tokens
 | |
|             agent_thought.total_price = llm_usage.total_price
 | |
| 
 | |
|         # check if tool labels is not empty
 | |
|         labels = agent_thought.tool_labels or {}
 | |
|         tools = agent_thought.tool.split(';') if agent_thought.tool else []
 | |
|         for tool in tools:
 | |
|             if not tool:
 | |
|                 continue
 | |
|             if tool not in labels:
 | |
|                 tool_label = ToolManager.get_tool_label(tool)
 | |
|                 if tool_label:
 | |
|                     labels[tool] = tool_label.to_dict()
 | |
|                 else:
 | |
|                     labels[tool] = {'en_US': tool, 'zh_Hans': tool}
 | |
| 
 | |
|         agent_thought.tool_labels_str = json.dumps(labels)
 | |
| 
 | |
|         db.session.commit()
 | |
|     
 | |
|     def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
 | |
|         """
 | |
|         Transform tool message into agent thought
 | |
|         """
 | |
|         result = []
 | |
| 
 | |
|         for message in messages:
 | |
|             if message.type == ToolInvokeMessage.MessageType.TEXT:
 | |
|                 result.append(message)
 | |
|             elif message.type == ToolInvokeMessage.MessageType.LINK:
 | |
|                 result.append(message)
 | |
|             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
 | |
|                 # try to download image
 | |
|                 try:
 | |
|                     file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
 | |
|                                                                conversation_id=self.message.conversation_id,
 | |
|                                                                file_url=message.message)
 | |
|                     
 | |
|                     url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
 | |
| 
 | |
|                     result.append(ToolInvokeMessage(
 | |
|                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
 | |
|                         message=url,
 | |
|                         save_as=message.save_as,
 | |
|                         meta=message.meta.copy() if message.meta is not None else {},
 | |
|                     ))
 | |
|                 except Exception as e:
 | |
|                     logger.exception(e)
 | |
|                     result.append(ToolInvokeMessage(
 | |
|                         type=ToolInvokeMessage.MessageType.TEXT,
 | |
|                         message=f"Failed to download image: {message.message}, you can try to download it yourself.",
 | |
|                         meta=message.meta.copy() if message.meta is not None else {},
 | |
|                         save_as=message.save_as,
 | |
|                     ))
 | |
|             elif message.type == ToolInvokeMessage.MessageType.BLOB:
 | |
|                 # get mime type and save blob to storage
 | |
|                 mimetype = message.meta.get('mime_type', 'octet/stream')
 | |
|                 # if message is str, encode it to bytes
 | |
|                 if isinstance(message.message, str):
 | |
|                     message.message = message.message.encode('utf-8')
 | |
|                 file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
 | |
|                                                             conversation_id=self.message.conversation_id,
 | |
|                                                             file_binary=message.message,
 | |
|                                                             mimetype=mimetype)
 | |
|                                                             
 | |
|                 url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
 | |
| 
 | |
|                 # check if file is image
 | |
|                 if 'image' in mimetype:
 | |
|                     result.append(ToolInvokeMessage(
 | |
|                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
 | |
|                         message=url,
 | |
|                         save_as=message.save_as,
 | |
|                         meta=message.meta.copy() if message.meta is not None else {},
 | |
|                     ))
 | |
|                 else:
 | |
|                     result.append(ToolInvokeMessage(
 | |
|                         type=ToolInvokeMessage.MessageType.LINK,
 | |
|                         message=url,
 | |
|                         save_as=message.save_as,
 | |
|                         meta=message.meta.copy() if message.meta is not None else {},
 | |
|                     ))
 | |
|             else:
 | |
|                 result.append(message)
 | |
| 
 | |
|         return result
 | |
|     
 | |
|     def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
 | |
|         """
 | |
|         convert tool variables to db variables
 | |
|         """
 | |
|         db_variables.updated_at = datetime.utcnow()
 | |
|         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
 | |
|         db.session.commit()
 | |
| 
 | |
|     def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
 | |
|         """
 | |
|         Organize agent history
 | |
|         """
 | |
|         result = []
 | |
|         # check if there is a system message in the beginning of the conversation
 | |
|         if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
 | |
|             result.append(prompt_messages[0])
 | |
| 
 | |
|         messages: list[Message] = db.session.query(Message).filter(
 | |
|             Message.conversation_id == self.message.conversation_id,
 | |
|         ).order_by(Message.created_at.asc()).all()
 | |
| 
 | |
|         for message in messages:
 | |
|             result.append(UserPromptMessage(content=message.query))
 | |
|             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
 | |
|             if agent_thoughts:
 | |
|                 for agent_thought in agent_thoughts:
 | |
|                     tools = agent_thought.tool
 | |
|                     if tools:
 | |
|                         tools = tools.split(';')
 | |
|                         tool_calls: list[AssistantPromptMessage.ToolCall] = []
 | |
|                         tool_call_response: list[ToolPromptMessage] = []
 | |
|                         tool_inputs = json.loads(agent_thought.tool_input)
 | |
|                         for tool in tools:
 | |
|                             # generate a uuid for tool call
 | |
|                             tool_call_id = str(uuid.uuid4())
 | |
|                             tool_calls.append(AssistantPromptMessage.ToolCall(
 | |
|                                 id=tool_call_id,
 | |
|                                 type='function',
 | |
|                                 function=AssistantPromptMessage.ToolCall.ToolCallFunction(
 | |
|                                     name=tool,
 | |
|                                     arguments=json.dumps(tool_inputs.get(tool, {})),
 | |
|                                 )
 | |
|                             ))
 | |
|                             tool_call_response.append(ToolPromptMessage(
 | |
|                                 content=agent_thought.observation,
 | |
|                                 name=tool,
 | |
|                                 tool_call_id=tool_call_id,
 | |
|                             ))
 | |
| 
 | |
|                         result.extend([
 | |
|                             AssistantPromptMessage(
 | |
|                                 content=agent_thought.thought,
 | |
|                                 tool_calls=tool_calls,
 | |
|                             ),
 | |
|                             *tool_call_response
 | |
|                         ])
 | |
|                     if not tools:
 | |
|                         result.append(AssistantPromptMessage(content=agent_thought.thought))
 | |
|             else:
 | |
|                 if message.answer:
 | |
|                     result.append(AssistantPromptMessage(content=message.answer))
 | |
| 
 | |
|         return result | 
