| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | import enum | 
					
						
							|  |  |  | from typing import Any, cast | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from pydantic import BaseModel | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from core.model_runtime.entities.message_entities import ( | 
					
						
							|  |  |  |     AssistantPromptMessage, | 
					
						
							|  |  |  |     ImagePromptMessageContent, | 
					
						
							|  |  |  |     PromptMessage, | 
					
						
							|  |  |  |     SystemPromptMessage, | 
					
						
							|  |  |  |     TextPromptMessageContent, | 
					
						
							|  |  |  |     ToolPromptMessage, | 
					
						
							|  |  |  |     UserPromptMessage, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | class PromptMessageFileType(enum.Enum): | 
					
						
							|  |  |  |     IMAGE = 'image' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @staticmethod | 
					
						
							|  |  |  |     def value_of(value): | 
					
						
							|  |  |  |         for member in PromptMessageFileType: | 
					
						
							|  |  |  |             if member.value == value: | 
					
						
							|  |  |  |                 return member | 
					
						
							|  |  |  |         raise ValueError(f"No matching enum found for value '{value}'") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class PromptMessageFile(BaseModel): | 
					
						
							|  |  |  |     type: PromptMessageFileType | 
					
						
							|  |  |  |     data: Any | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ImagePromptMessageFile(PromptMessageFile): | 
					
						
							|  |  |  |     class DETAIL(enum.Enum): | 
					
						
							|  |  |  |         LOW = 'low' | 
					
						
							|  |  |  |         HIGH = 'high' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     type: PromptMessageFileType = PromptMessageFileType.IMAGE | 
					
						
							|  |  |  |     detail: DETAIL = DETAIL.LOW | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class LCHumanMessageWithFiles(HumanMessage): | 
					
						
							| 
									
										
										
										
											2024-02-12 00:56:17 +08:00
										 |  |  |     # content: Union[str, list[Union[str, Dict]]] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |     content: str | 
					
						
							|  |  |  |     files: list[PromptMessageFile] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]: | 
					
						
							|  |  |  |     prompt_messages = [] | 
					
						
							|  |  |  |     for message in messages: | 
					
						
							|  |  |  |         if isinstance(message, HumanMessage): | 
					
						
							|  |  |  |             if isinstance(message, LCHumanMessageWithFiles): | 
					
						
							|  |  |  |                 file_prompt_message_contents = [] | 
					
						
							|  |  |  |                 for file in message.files: | 
					
						
							|  |  |  |                     if file.type == PromptMessageFileType.IMAGE: | 
					
						
							|  |  |  |                         file = cast(ImagePromptMessageFile, file) | 
					
						
							|  |  |  |                         file_prompt_message_contents.append(ImagePromptMessageContent( | 
					
						
							|  |  |  |                             data=file.data, | 
					
						
							|  |  |  |                             detail=ImagePromptMessageContent.DETAIL.HIGH | 
					
						
							|  |  |  |                             if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW | 
					
						
							|  |  |  |                         )) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 prompt_message_contents = [TextPromptMessageContent(data=message.content)] | 
					
						
							|  |  |  |                 prompt_message_contents.extend(file_prompt_message_contents) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 prompt_messages.append(UserPromptMessage(content=message.content)) | 
					
						
							|  |  |  |         elif isinstance(message, AIMessage): | 
					
						
							|  |  |  |             message_kwargs = { | 
					
						
							|  |  |  |                 'content': message.content | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if 'function_call' in message.additional_kwargs: | 
					
						
							|  |  |  |                 message_kwargs['tool_calls'] = [ | 
					
						
							|  |  |  |                     AssistantPromptMessage.ToolCall( | 
					
						
							|  |  |  |                         id=message.additional_kwargs['function_call']['id'], | 
					
						
							|  |  |  |                         type='function', | 
					
						
							|  |  |  |                         function=AssistantPromptMessage.ToolCall.ToolCallFunction( | 
					
						
							|  |  |  |                             name=message.additional_kwargs['function_call']['name'], | 
					
						
							|  |  |  |                             arguments=message.additional_kwargs['function_call']['arguments'] | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             prompt_messages.append(AssistantPromptMessage(**message_kwargs)) | 
					
						
							|  |  |  |         elif isinstance(message, SystemMessage): | 
					
						
							|  |  |  |             prompt_messages.append(SystemPromptMessage(content=message.content)) | 
					
						
							|  |  |  |         elif isinstance(message, FunctionMessage): | 
					
						
							|  |  |  |             prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return prompt_messages | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]: | 
					
						
							|  |  |  |     messages = [] | 
					
						
							|  |  |  |     for prompt_message in prompt_messages: | 
					
						
							|  |  |  |         if isinstance(prompt_message, UserPromptMessage): | 
					
						
							|  |  |  |             if isinstance(prompt_message.content, str): | 
					
						
							|  |  |  |                 messages.append(HumanMessage(content=prompt_message.content)) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 message_contents = [] | 
					
						
							|  |  |  |                 for content in prompt_message.content: | 
					
						
							|  |  |  |                     if isinstance(content, TextPromptMessageContent): | 
					
						
							|  |  |  |                         message_contents.append(content.data) | 
					
						
							|  |  |  |                     elif isinstance(content, ImagePromptMessageContent): | 
					
						
							|  |  |  |                         message_contents.append({ | 
					
						
							|  |  |  |                             'type': 'image', | 
					
						
							|  |  |  |                             'data': content.data, | 
					
						
							|  |  |  |                             'detail': content.detail.value | 
					
						
							|  |  |  |                         }) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 messages.append(HumanMessage(content=message_contents)) | 
					
						
							|  |  |  |         elif isinstance(prompt_message, AssistantPromptMessage): | 
					
						
							|  |  |  |             message_kwargs = { | 
					
						
							|  |  |  |                 'content': prompt_message.content | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if prompt_message.tool_calls: | 
					
						
							|  |  |  |                 message_kwargs['additional_kwargs'] = { | 
					
						
							|  |  |  |                     'function_call': { | 
					
						
							|  |  |  |                         'id': prompt_message.tool_calls[0].id, | 
					
						
							|  |  |  |                         'name': prompt_message.tool_calls[0].function.name, | 
					
						
							|  |  |  |                         'arguments': prompt_message.tool_calls[0].function.arguments | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             messages.append(AIMessage(**message_kwargs)) | 
					
						
							|  |  |  |         elif isinstance(prompt_message, SystemPromptMessage): | 
					
						
							|  |  |  |             messages.append(SystemMessage(content=prompt_message.content)) | 
					
						
							|  |  |  |         elif isinstance(prompt_message, ToolPromptMessage): | 
					
						
							|  |  |  |             messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return messages |