From ec500a31b8b8d9f6382a06966f12d01dbb5f3918 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 15 Jul 2024 13:43:04 -0400 Subject: [PATCH] Marketing sample agents (#210) * add writer impl * add graphic designer * add worker and auditor, remove writer * add worker, add simple test main --- python/examples/marketing-agents/app.py | 31 +++++++++++ python/examples/marketing-agents/auditor.py | 33 ++++++++++++ .../marketing-agents/graphic_designer.py | 31 +++++++++++ python/examples/marketing-agents/messages.py | 21 ++++++++ .../examples/marketing-agents/test_usage.py | 51 +++++++++++++++++++ python/examples/marketing-agents/worker.py | 24 +++++++++ python/pyproject.toml | 1 + .../src/agnext/components/tools/__init__.py | 2 +- 8 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 python/examples/marketing-agents/app.py create mode 100644 python/examples/marketing-agents/auditor.py create mode 100644 python/examples/marketing-agents/graphic_designer.py create mode 100644 python/examples/marketing-agents/messages.py create mode 100644 python/examples/marketing-agents/test_usage.py create mode 100644 python/examples/marketing-agents/worker.py diff --git a/python/examples/marketing-agents/app.py b/python/examples/marketing-agents/app.py new file mode 100644 index 000000000..25b40c39e --- /dev/null +++ b/python/examples/marketing-agents/app.py @@ -0,0 +1,31 @@ +import os + +from agnext.components.models import AzureOpenAIChatCompletionClient +from agnext.core import AgentRuntime +from auditor import AuditAgent +from graphic_designer import GraphicDesignerAgent +from openai import AsyncAzureOpenAI + + +async def build_app(runtime: AgentRuntime) -> None: + chat_client = AzureOpenAIChatCompletionClient( + model="gpt-4-32", + azure_endpoint=os.environ["CHAT_ENDPOINT"], + api_version="2024-02-01", + model_capabilities={ + "vision": True, + "function_calling": True, + "json_output": True, + }, + api_key=os.environ["CHAT_ENDPOINT_KEY"], + ) + + image_client = AsyncAzureOpenAI( + azure_endpoint=os.environ["IMAGE_ENDPOINT"], + azure_deployment="dall-e-3", + api_key=os.environ["IMAGE_ENDPOINT_KEY"], + api_version="2024-02-01", + ) + + runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client)) + runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client)) diff --git a/python/examples/marketing-agents/auditor.py b/python/examples/marketing-agents/auditor.py new file mode 100644 index 000000000..8cf7b9c5a --- /dev/null +++ b/python/examples/marketing-agents/auditor.py @@ -0,0 +1,33 @@ +from agnext.components import TypeRoutedAgent, message_handler +from agnext.components.models import ChatCompletionClient +from agnext.components.models._types import SystemMessage +from agnext.core import CancellationToken +from messages import AuditorAlert, AuditText + +auditor_prompt = """You are an Auditor in a Marketing team +Audit the text bello and make sure we do not give discounts larger than 10% +If the text talks about a larger than 10% discount, reply with a message to the user saying that the discount is too large, and by company policy we are not allowed. +If the message says who wrote it, add that information in the response as well +In any other case, reply with NOTFORME +--- +Input: {input} +--- +""" + + +class AuditAgent(TypeRoutedAgent): + def __init__( + self, + model_client: ChatCompletionClient, + ) -> None: + super().__init__("") + self._model_client = model_client + + @message_handler + async def handle_user_chat_input(self, message: AuditText, cancellation_token: CancellationToken) -> None: + sys_prompt = auditor_prompt.format(input=message.text) + completion = await self._model_client.create(messages=[SystemMessage(content=sys_prompt)]) + assert isinstance(completion.content, str) + if "NOTFORME" in completion.content: + return + await self.publish_message(AuditorAlert(user_id=message.user_id, auditor_alert_message=completion.content)) diff --git a/python/examples/marketing-agents/graphic_designer.py b/python/examples/marketing-agents/graphic_designer.py new file mode 100644 index 000000000..d6785919b --- /dev/null +++ b/python/examples/marketing-agents/graphic_designer.py @@ -0,0 +1,31 @@ +from typing import Literal + +import openai +from agnext.components import ( + TypeRoutedAgent, + message_handler, +) +from agnext.core import CancellationToken +from messages import ArticleCreated, GraphicDesignCreated + + +class GraphicDesignerAgent(TypeRoutedAgent): + def __init__( + self, + client: openai.AsyncClient, + model: Literal["dall-e-2", "dall-e-3"] = "dall-e-3", + ): + super().__init__("") + self._client = client + self._model = model + + @message_handler + async def handle_user_chat_input(self, message: ArticleCreated, cancellation_token: CancellationToken) -> None: + response = await self._client.images.generate( + model=self._model, prompt=message.article, response_format="b64_json" + ) + assert len(response.data) > 0 and response.data[0].b64_json is not None + image_base64 = response.data[0].b64_json + image_uri = f"data:image/png;base64,{image_base64}" + + await self.publish_message(GraphicDesignCreated(user_id=message.user_id, image_uri=image_uri)) diff --git a/python/examples/marketing-agents/messages.py b/python/examples/marketing-agents/messages.py new file mode 100644 index 000000000..0726f814e --- /dev/null +++ b/python/examples/marketing-agents/messages.py @@ -0,0 +1,21 @@ +from pydantic import BaseModel + + +class ArticleCreated(BaseModel): + user_id: str + article: str + + +class GraphicDesignCreated(BaseModel): + user_id: str + image_uri: str + + +class AuditText(BaseModel): + user_id: str + text: str + + +class AuditorAlert(BaseModel): + user_id: str + auditor_alert_message: str diff --git a/python/examples/marketing-agents/test_usage.py b/python/examples/marketing-agents/test_usage.py new file mode 100644 index 000000000..fd4c226d3 --- /dev/null +++ b/python/examples/marketing-agents/test_usage.py @@ -0,0 +1,51 @@ +import asyncio +import os + +from agnext.application import SingleThreadedAgentRuntime +from agnext.components import Image, TypeRoutedAgent, message_handler +from agnext.core import CancellationToken +from app import build_app +from dotenv import load_dotenv +from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated + + +class Printer(TypeRoutedAgent): + def __init__( + self, + ) -> None: + super().__init__("") + + @message_handler + async def handle_graphic_design(self, message: GraphicDesignCreated, cancellation_token: CancellationToken) -> None: + image = Image.from_uri(message.image_uri) + # Save image to random name in current directory + image.image.save(os.path.join(os.getcwd(), f"{message.user_id}.png")) + print(f"Received GraphicDesignCreated: user {message.user_id}, saved to {message.user_id}.png") + + @message_handler + async def handle_auditor_alert(self, message: AuditorAlert, cancellation_token: CancellationToken) -> None: + print(f"Received AuditorAlert: {message.auditor_alert_message} for user {message.user_id}") + + +async def main() -> None: + runtime = SingleThreadedAgentRuntime() + await build_app(runtime) + runtime.register("Printer", lambda: Printer()) + + ctx = runtime.start() + + await runtime.publish_message( + AuditText(text="Buy my product for a MASSIVE 50% discount.", user_id="user-1"), namespace="default" + ) + + await runtime.publish_message( + ArticleCreated(article="The best article ever written about trees and rocks", user_id="user-2"), + namespace="default", + ) + + await ctx.stop_when_idle() + + +if __name__ == "__main__": + load_dotenv() + asyncio.run(main()) diff --git a/python/examples/marketing-agents/worker.py b/python/examples/marketing-agents/worker.py new file mode 100644 index 000000000..4ab805bd5 --- /dev/null +++ b/python/examples/marketing-agents/worker.py @@ -0,0 +1,24 @@ +import asyncio +import logging +import os + +from agnext.worker.worker_runtime import WorkerAgentRuntime +from app import build_app + + +async def main() -> None: + runtime = WorkerAgentRuntime() + await runtime.setup_channel(os.environ["AGENT_HOST"]) + + await build_app(runtime) + + # just to keep the runtime running + try: + await asyncio.sleep(1000000) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + asyncio.run(main()) diff --git a/python/pyproject.toml b/python/pyproject.toml index ef14f4141..dfea206e1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "grpcio-tools", "markdownify", "types-protobuf", + "python-dotenv" ] [tool.hatch.envs.default.extra-scripts] diff --git a/python/src/agnext/components/tools/__init__.py b/python/src/agnext/components/tools/__init__.py index f791167dd..dcfb1759a 100644 --- a/python/src/agnext/components/tools/__init__.py +++ b/python/src/agnext/components/tools/__init__.py @@ -1,4 +1,4 @@ -from ._base import BaseTool, BaseToolWithState, Tool, ToolSchema, ParametersSchema +from ._base import BaseTool, BaseToolWithState, ParametersSchema, Tool, ToolSchema from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool from ._function_tool import FunctionTool