| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  | import asyncio | 
					
						
							| 
									
										
										
										
											2024-06-25 13:23:29 -07:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  | import random | 
					
						
							| 
									
										
										
										
											2024-06-25 13:23:29 -07:00
										 |  |  | import sys | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  | from asyncio import Future | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-25 13:23:29 -07:00
										 |  |  | from agnext.components import Image, TypeRoutedAgent, message_handler | 
					
						
							|  |  |  | from agnext.core import AgentRuntime, CancellationToken | 
					
						
							|  |  |  | from textual.app import App, ComposeResult | 
					
						
							|  |  |  | from textual.containers import ScrollableContainer | 
					
						
							|  |  |  | from textual.widgets import Button, Footer, Header, Input, Markdown, Static | 
					
						
							|  |  |  | from textual_imageview.viewer import ImageViewer | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from common.types import ( | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  |     MultiModalMessage, | 
					
						
							|  |  |  |     PublishNow, | 
					
						
							|  |  |  |     RespondNow, | 
					
						
							|  |  |  |     TextMessage, | 
					
						
							|  |  |  |     ToolApprovalRequest, | 
					
						
							|  |  |  |     ToolApprovalResponse, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  | class ChatAppMessage(Static): | 
					
						
							|  |  |  |     def __init__(self, message: TextMessage | MultiModalMessage) -> None:  # type: ignore | 
					
						
							|  |  |  |         self._message = message | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def on_mount(self) -> None: | 
					
						
							|  |  |  |         self.styles.margin = 1 | 
					
						
							|  |  |  |         self.styles.padding = 1 | 
					
						
							|  |  |  |         self.styles.border = ("solid", "blue") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  |     def compose(self) -> ComposeResult: | 
					
						
							|  |  |  |         if isinstance(self._message, TextMessage): | 
					
						
							|  |  |  |             yield Markdown(f"{self._message.source}:") | 
					
						
							|  |  |  |             yield Markdown(self._message.content) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             yield Markdown(f"{self._message.source}:") | 
					
						
							|  |  |  |             for content in self._message.content: | 
					
						
							|  |  |  |                 if isinstance(content, str): | 
					
						
							|  |  |  |                     yield Markdown(content) | 
					
						
							|  |  |  |                 elif isinstance(content, Image): | 
					
						
							|  |  |  |                     viewer = ImageViewer(content.image) | 
					
						
							|  |  |  |                     viewer.styles.min_width = 50 | 
					
						
							|  |  |  |                     viewer.styles.min_height = 50 | 
					
						
							|  |  |  |                     yield viewer | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | class WelcomeMessage(Static): | 
					
						
							|  |  |  |     def on_mount(self) -> None: | 
					
						
							|  |  |  |         self.styles.margin = 1 | 
					
						
							|  |  |  |         self.styles.padding = 1 | 
					
						
							|  |  |  |         self.styles.border = ("solid", "blue") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ChatInput(Input): | 
					
						
							|  |  |  |     def on_mount(self) -> None: | 
					
						
							|  |  |  |         self.focus() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def on_input_submitted(self, event: Input.Submitted) -> None: | 
					
						
							|  |  |  |         self.clear() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ToolApprovalRequestNotice(Static): | 
					
						
							|  |  |  |     def __init__(self, request: ToolApprovalRequest, response_future: Future[ToolApprovalResponse]) -> None:  # type: ignore | 
					
						
							|  |  |  |         self._tool_call = request.tool_call | 
					
						
							|  |  |  |         self._future = response_future | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def compose(self) -> ComposeResult: | 
					
						
							|  |  |  |         yield Static(f"Tool call: {self._tool_call.name}, arguments: {self._tool_call.arguments[:50]}") | 
					
						
							|  |  |  |         yield Button("Approve", id="approve", variant="warning") | 
					
						
							|  |  |  |         yield Button("Deny", id="deny", variant="default") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def on_mount(self) -> None: | 
					
						
							|  |  |  |         self.styles.margin = 1 | 
					
						
							|  |  |  |         self.styles.padding = 1 | 
					
						
							|  |  |  |         self.styles.border = ("solid", "red") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def on_button_pressed(self, event: Button.Pressed) -> None: | 
					
						
							|  |  |  |         button_id = event.button.id | 
					
						
							|  |  |  |         assert button_id is not None | 
					
						
							|  |  |  |         if button_id == "approve": | 
					
						
							|  |  |  |             self._future.set_result(ToolApprovalResponse(tool_call_id=self._tool_call.id, approved=True, reason="")) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self._future.set_result(ToolApprovalResponse(tool_call_id=self._tool_call.id, approved=False, reason="")) | 
					
						
							|  |  |  |         self.remove() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TextualChatApp(App):  # type: ignore | 
					
						
							|  |  |  |     """A Textual app for a chat interface.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  |     def __init__(self, runtime: AgentRuntime, welcoming_notice: str | None = None, user_name: str = "User") -> None:  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  |         self._runtime = runtime | 
					
						
							|  |  |  |         self._welcoming_notice = welcoming_notice | 
					
						
							|  |  |  |         self._user_name = user_name | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def compose(self) -> ComposeResult: | 
					
						
							|  |  |  |         yield Header() | 
					
						
							|  |  |  |         yield Footer() | 
					
						
							|  |  |  |         yield ScrollableContainer(id="chat-messages") | 
					
						
							|  |  |  |         yield ChatInput() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def on_mount(self) -> None: | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  |         if self._welcoming_notice is not None: | 
					
						
							|  |  |  |             chat_messages = self.query_one("#chat-messages") | 
					
						
							|  |  |  |             notice = WelcomeMessage(self._welcoming_notice, id="welcome") | 
					
						
							|  |  |  |             chat_messages.mount(notice) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @property | 
					
						
							|  |  |  |     def welcoming_notice(self) -> str | None: | 
					
						
							|  |  |  |         return self._welcoming_notice | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @welcoming_notice.setter | 
					
						
							|  |  |  |     def welcoming_notice(self, value: str) -> None: | 
					
						
							|  |  |  |         self._welcoming_notice = value | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     async def on_input_submitted(self, event: Input.Submitted) -> None: | 
					
						
							|  |  |  |         user_input = event.value | 
					
						
							|  |  |  |         await self.publish_user_message(user_input) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def post_request_user_input_notice(self) -> None: | 
					
						
							|  |  |  |         chat_messages = self.query_one("#chat-messages") | 
					
						
							|  |  |  |         notice = Static("Please enter your input.", id="typing") | 
					
						
							|  |  |  |         chat_messages.mount(notice) | 
					
						
							|  |  |  |         notice.scroll_visible() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def publish_user_message(self, user_input: str) -> None: | 
					
						
							|  |  |  |         chat_messages = self.query_one("#chat-messages") | 
					
						
							|  |  |  |         # Remove all typing messages. | 
					
						
							|  |  |  |         chat_messages.query("#typing").remove() | 
					
						
							|  |  |  |         # Publish the user message to the runtime. | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |         await self._runtime.publish_message( | 
					
						
							|  |  |  |             TextMessage(source=self._user_name, content=user_input), namespace="default" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  |     async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None:  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  |         """Post a message from the agent runtime to the message list.""" | 
					
						
							|  |  |  |         chat_messages = self.query_one("#chat-messages") | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  |         msg = ChatAppMessage(message) | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  |         chat_messages.mount(msg) | 
					
						
							|  |  |  |         msg.scroll_visible() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def handle_tool_approval_request(self, message: ToolApprovalRequest) -> ToolApprovalResponse:  # type: ignore | 
					
						
							|  |  |  |         chat_messages = self.query_one("#chat-messages") | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  |         future: Future[ToolApprovalResponse] = asyncio.get_event_loop().create_future()  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  |         tool_call_approval_notice = ToolApprovalRequestNotice(message, future) | 
					
						
							|  |  |  |         chat_messages.mount(tool_call_approval_notice) | 
					
						
							|  |  |  |         tool_call_approval_notice.scroll_visible() | 
					
						
							|  |  |  |         return await future | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TextualUserAgent(TypeRoutedAgent):  # type: ignore | 
					
						
							|  |  |  |     """An agent that is used to receive messages from the runtime.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     def __init__(self, description: str, app: TextualChatApp) -> None:  # type: ignore | 
					
						
							|  |  |  |         super().__init__(description) | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  |         self._app = app | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @message_handler  # type: ignore | 
					
						
							|  |  |  |     async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:  # type: ignore | 
					
						
							|  |  |  |         await self._app.post_runtime_message(message) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-17 17:20:46 -07:00
										 |  |  |     @message_handler  # type: ignore | 
					
						
							|  |  |  |     async def on_multi_modal_message(self, message: MultiModalMessage, cancellation_token: CancellationToken) -> None:  # type: ignore | 
					
						
							|  |  |  |         # Save the message to file. | 
					
						
							|  |  |  |         # Generate a ramdom file name. | 
					
						
							|  |  |  |         for content in message.content: | 
					
						
							|  |  |  |             if isinstance(content, Image): | 
					
						
							|  |  |  |                 filename = f"{self.metadata['name']}_{message.source}_{random.randbytes(16).hex()}.png" | 
					
						
							|  |  |  |                 content.image.save(filename) | 
					
						
							|  |  |  |         await self._app.post_runtime_message(message) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-15 18:04:24 -07:00
										 |  |  |     @message_handler  # type: ignore | 
					
						
							|  |  |  |     async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> None:  # type: ignore | 
					
						
							|  |  |  |         await self._app.post_request_user_input_notice() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @message_handler  # type: ignore | 
					
						
							|  |  |  |     async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:  # type: ignore | 
					
						
							|  |  |  |         await self._app.post_request_user_input_notice() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @message_handler  # type: ignore | 
					
						
							|  |  |  |     async def on_tool_approval_request( | 
					
						
							|  |  |  |         self, message: ToolApprovalRequest, cancellation_token: CancellationToken | 
					
						
							|  |  |  |     ) -> ToolApprovalResponse: | 
					
						
							|  |  |  |         return await self._app.handle_tool_approval_request(message) |