Cohere Client (#3004)

* initial setup for cohere client

* client update

* changes: ClientType added to the utils

* Revert "changes: ClientType added to the utils"

This reverts commit 80d61552287f2d2eff50b4c0f1a4adfd97233aa3.

* Message conversion to Cohere, Parameter handling, cost calculation, streaming, tool calling

* Changed Groq references.

* minor fix

* tests added

* ref fix

* added in the workflows

* Fixed bug on non-streaming text generation

* fix: formatting

* Support Cohere rule for last message not USER when tool_results exist

* Added Cohere to documentation

* fixed client.py merge, removed unnecessary comments in groq.py, updated Cohere documentation, added Groq documentation

* log: ignored params

* update: custom exception added

---------

Co-authored-by: Mark Sze <mark@sze.family>
Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
Authored by HRUSHIKESH DOKALA on 2024-07-03 20:03:03 +05:30, committed by GitHub
parent 8133b7de22
commit b4a3f263b0
11 changed files with 1661 additions and 10 deletions
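For reference, a minimal sketch of how the new client is selected (mirroring the config example in the docstring added in autogen/oai/cohere.py below; assumes COHERE_API_KEY is set in your environment):

import os
import autogen

llm_config = {
    "config_list": [
        {
            "api_type": "cohere",  # routes OpenAIWrapper to the new CohereClient
            "model": "command-r-plus",
            "api_key": os.environ.get("COHERE_API_KEY"),
        }
    ]
}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)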


@@ -638,3 +638,39 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
CohereTest:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install "pytest-cov>=5"
- name: Install packages and dependencies for Cohere
run: |
pip install -e .[cohere,test]
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Coverage
run: |
pytest test/oai/test_cohere.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests


@@ -18,6 +18,7 @@ from .base_logger import LLMConfig
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
@@ -205,7 +206,16 @@ class FileLogger(BaseLogger):
def log_new_client(
self,
client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient,
client: (
AzureOpenAI
| OpenAI
| GeminiClient
| AnthropicClient
| MistralAIClient
| TogetherClient
| GroqClient
| CohereClient
),
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:


@@ -19,6 +19,7 @@ from .base_logger import LLMConfig
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
@@ -392,7 +393,16 @@ class SqliteLogger(BaseLogger):
def log_new_client(
self,
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
client: Union[
AzureOpenAI,
OpenAI,
GeminiClient,
AnthropicClient,
MistralAIClient,
TogetherClient,
GroqClient,
CohereClient,
],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:


@@ -77,6 +77,13 @@ try:
except ImportError as e:
groq_import_exception = e
try:
from autogen.oai.cohere import CohereClient
cohere_import_exception: Optional[ImportError] = None
except ImportError as e:
cohere_import_exception = e
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -497,6 +504,11 @@ class OpenAIWrapper:
raise ImportError("Please install `groq` to use the Groq API.")
client = GroqClient(**openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("cohere"):
if cohere_import_exception:
raise ImportError("Please install `cohere` to use the Groq API.")
client = CohereClient(**openai_config)
self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))

autogen/oai/cohere.py (new file)

@@ -0,0 +1,459 @@
"""Create an OpenAI-compatible client using Cohere's API.
Example:
llm_config={
"config_list": [{
"api_type": "cohere",
"model": "command-r-plus",
"api_key": os.environ.get("COHERE_API_KEY")
}
]}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
Install Cohere's Python library using: pip install --upgrade cohere
Resources:
- https://docs.cohere.com/reference/chat
"""
from __future__ import annotations
import json
import logging
import os
import random
import sys
import time
import warnings
from typing import Any, Dict, List
from cohere import Client as Cohere
from cohere.types import ToolParameterDefinitionsValue, ToolResult
from flaml.automl.logger import logger_formatter
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from autogen.oai.client_utils import validate_parameter
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
_ch = logging.StreamHandler(stream=sys.stdout)
_ch.setFormatter(logger_formatter)
logger.addHandler(_ch)
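# Prices per 1,000 tokens as (input, output); consumed by calculate_cohere_cost below.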
COHERE_PRICING_1K = {
"command-r-plus": (0.003, 0.015),
"command-r": (0.0005, 0.0015),
"command-nightly": (0.00025, 0.00125),
"command": (0.015, 0.075),
"command-light": (0.008, 0.024),
"command-light-nightly": (0.008, 0.024),
}
class CohereClient:
"""Client for Cohere's API."""
def __init__(self, **kwargs):
"""Requires api_key or environment variable to be set
Args:
api_key (str): The API key for using Cohere (or environment variable COHERE_API_KEY needs to be set)
"""
# Ensure we have the api_key upon instantiation
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
self.api_key = os.getenv("COHERE_API_KEY")
assert (
self.api_key
), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."
def message_retrieval(self, response) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
"""
return [choice.message for choice in response.choices]
def cost(self, response) -> float:
return response.cost
@staticmethod
def get_usage(response) -> Dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
# ... # pragma: no cover
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost,
"model": response.model,
}
def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Loads the parameters for Cohere API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
cohere_params = {}
# Check that we have what we need to use Cohere's API
# We won't enforce the available models as they are likely to change
cohere_params["model"] = params.get("model", None)
assert cohere_params[
"model"
], "Please specify the 'model' in your config list entry to nominate the Cohere model to use."
# Validate allowed Cohere parameters
# https://docs.cohere.com/reference/chat
cohere_params["temperature"] = validate_parameter(
params, "temperature", (int, float), False, 0.3, (0, None), None
)
cohere_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)
cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)
cohere_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
cohere_params["frequency_penalty"] = validate_parameter(
params, "frequency_penalty", (int, float), True, 0, (0, 1), None
)
cohere_params["presence_penalty"] = validate_parameter(
params, "presence_penalty", (int, float), True, 0, (0, 1), None
)
# Cohere parameters we are ignoring:
# preamble - we will put the system prompt in here.
# parallel_tool_calls (defaults to True), perfect as is.
# conversation_id - allows resuming a previous conversation, we don't support this.
logging.info("Conversation ID: %s", params.get("conversation_id", "None"))
# connectors - allows web search or other custom connectors, not implementing for now but could be useful in the future.
logging.info("Connectors: %s", params.get("connectors", "None"))
# search_queries_only - to control whether only search queries are used, we're not using connectors so ignoring.
# documents - a list of documents that can be used to support the chat. Perhaps useful in the future for RAG.
# citation_quality - used for RAG flows and dependent on other parameters we're ignoring.
# max_input_tokens - limits input tokens, not needed.
logging.info("Max Input Tokens: %s", params.get("max_input_tokens", "None"))
# stop_sequences - used to stop generation, not needed.
logging.info("Stop Sequences: %s", params.get("stop_sequences", "None"))
return cohere_params
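    # Example of the defaults in action (mirrors test/oai/test_cohere.py below):
    # parse_params({"model": "command-r-plus", "stream": False, "temperature": 1, "p": 0.8, "max_tokens": 100})
    # returns those values plus k=0, seed=None, frequency_penalty=0, and presence_penalty=0;
    # "stream" is not a Cohere chat parameter and is handled separately in create().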
def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
# Parse parameters to the Cohere API's parameters
cohere_params = self.parse_params(params)
# Convert AutoGen messages to Cohere messages
cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params)
cohere_params["chat_history"] = cohere_messages
cohere_params["message"] = final_message
cohere_params["preamble"] = preamble
# We use chat model by default
client = Cohere(api_key=self.api_key)
# Token counts will be returned
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
# Stream if in parameters
streaming = bool(params.get("stream", False))
cohere_finish = ""
max_retries = 5
for attempt in range(max_retries):
ans = None
try:
if streaming:
response = client.chat_stream(**cohere_params)
else:
response = client.chat(**cohere_params)
except CohereRateLimitError as e:
raise RuntimeError(f"Cohere exception occurred: {e}")
else:
if streaming:
# Streaming...
ans = ""
for event in response:
if event.event_type == "text-generation":
ans = ans + event.text
elif event.event_type == "tool-calls-generation":
# When streaming, tool calls are compiled at the end into a single event_type
ans = event.text
cohere_finish = "tool_calls"
tool_calls = []
for tool_call in event.tool_calls:
tool_calls.append(
ChatCompletionMessageToolCall(
id=str(random.randint(0, 100000)),
function={
"name": tool_call.name,
"arguments": (
"" if tool_call.parameters is None else json.dumps(tool_call.parameters)
),
},
type="function",
)
)
# Not using billed_units, but that may be better for cost purposes
prompt_tokens = event.response.meta.tokens.input_tokens
completion_tokens = event.response.meta.tokens.output_tokens
total_tokens = prompt_tokens + completion_tokens
response_id = event.response.response_id
else:
# Non-streaming finished
ans: str = response.text
# Not using billed_units, but that may be better for cost purposes
prompt_tokens = response.meta.tokens.input_tokens
completion_tokens = response.meta.tokens.output_tokens
total_tokens = prompt_tokens + completion_tokens
response_id = response.response_id
break
if response is not None:
response_content = ans
if streaming:
# Streaming response
if cohere_finish == "":
cohere_finish = "stop"
tool_calls = None
else:
# Non-streaming response
# If we have tool calls as the response, populate completed tool calls for our return OAI response
if response.tool_calls is not None:
cohere_finish = "tool_calls"
tool_calls = []
for tool_call in response.tool_calls:
# if parameters are null, clear them out (Cohere can return a string "null" if no parameter values)
tool_calls.append(
ChatCompletionMessageToolCall(
id=str(random.randint(0, 100000)),
function={
"name": tool_call.name,
"arguments": (
"" if tool_call.parameters is None else json.dumps(tool_call.parameters)
),
},
type="function",
)
)
else:
cohere_finish = "stop"
tool_calls = None
else:
raise RuntimeError(f"Failed to get response from Cohere after retrying {attempt + 1} times.")
# 3. convert output
message = ChatCompletionMessage(
role="assistant",
content=response_content,
function_call=None,
tool_calls=tool_calls,
)
choices = [Choice(finish_reason=cohere_finish, index=0, message=message)]
response_oai = ChatCompletion(
id=response_id,
model=cohere_params["model"],
created=int(time.time()),
object="chat.completion",
choices=choices,
usage=CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
),
cost=calculate_cohere_cost(prompt_tokens, completion_tokens, cohere_params["model"]),
)
return response_oai
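    # Illustrative usage (a sketch; OpenAIWrapper normally instantiates and calls this for you):
    #   client = CohereClient(api_key=os.environ["COHERE_API_KEY"])
    #   response = client.create({"model": "command-r", "messages": [{"role": "user", "content": "Hi"}]})
    #   print(client.message_retrieval(response)[0].content, client.cost(response))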
def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
) -> tuple[list[dict[str, Any]], str, str]:
"""Convert messages from OAI format to Cohere's format.
We correct for any specific role orders and types.
Parameters:
messages: list[Dict[str, Any]]: AutoGen messages
params: Dict[str, Any]: AutoGen parameters dictionary
cohere_params: Dict[str, Any]: Cohere parameters dictionary
Returns:
List[Dict[str, Any]]: Chat History messages
str: Preamble (system message)
str: Message (the final user message)
"""
cohere_messages = []
preamble = ""
# Tools
if "tools" in params:
cohere_tools = []
for tool in params["tools"]:
# build list of properties
parameters = {}
for key, value in tool["function"]["parameters"]["properties"].items():
type_str = value["type"]
required = True  # Cohere defaults this to False; we could consider leaving it as the default.
description = value["description"]
# If we have an 'enum' key, add its values to the description (enum isn't allowed as a field)
if "enum" in value:
# Access the enum list
enum_values = value["enum"]
enum_strings = [str(value) for value in enum_values]
enum_string = ", ".join(enum_strings)
description = description + ". Possible values are " + enum_string + "."
parameters[key] = ToolParameterDefinitionsValue(
description=description, type=type_str, required=required
)
cohere_tool = {
"name": tool["function"]["name"],
"description": tool["function"]["description"],
"parameter_definitions": parameters,
}
cohere_tools.append(cohere_tool)
if len(cohere_tools) > 0:
cohere_params["tools"] = cohere_tools
tool_calls = []
tool_results = []
# Rules for cohere messages:
# no 'name' field
# 'system' messages go into the preamble parameter
# user role = 'USER'
# assistant role = 'CHATBOT'
# 'content' field renamed to 'message'
# tools go into tools parameter
# tool_results go into tool_results parameter
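    # Worked example (illustrative): the OAI messages
    #   [{"role": "system", "content": "You are helpful."},
    #    {"role": "user", "content": "What's 2+2?"}]
    # become preamble="You are helpful." with an empty chat history, and
    # "What's 2+2?" is returned as the final user message (see the return logic below).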
for message in messages:
if "role" in message and message["role"] == "system":
# System message
if preamble == "":
preamble = message["content"]
else:
preamble = preamble + "\n" + message["content"]
elif "tool_calls" in message:
# Suggested tool calls, build up the list before we put it into the tool_results
for tool_call in message["tool_calls"]:
tool_calls.append(tool_call)
# We also add the suggested tool call as a message
new_message = {
"role": "CHATBOT",
"message": message["content"],
# Not including tools in this message; we may need to. Testing required.
}
cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool":
if "tool_call_id" in message:
# Convert the tool call to a result
tool_call_id = message["tool_call_id"]
content_output = message["content"]
# Find the original tool
for tool_call in tool_calls:
if tool_call["id"] == tool_call_id:
call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"]
if not tool_call["function"]["arguments"] == ""
else "{}"
),
}
output = [{"value": content_output}]
tool_results.append(ToolResult(call=call, outputs=output))
break
elif "content" in message and isinstance(message["content"], str):
# Standard text message
new_message = {
"role": "USER" if message["role"] == "user" else "CHATBOT",
"message": message["content"],
}
cohere_messages.append(new_message)
# Append any Tool Results
if len(tool_results) != 0:
cohere_params["tool_results"] = tool_results
# Enable multi-step tool use: https://docs.cohere.com/docs/multi-step-tool-use
cohere_params["force_single_step"] = False
# When tool_results are provided, Cohere requires that the last message is not a USER message,
# so we append a CHATBOT 'continue' message if needed.
if cohere_messages[-1]["role"] == "USER":
    cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})
# We return a blank message when we have tool results
# TODO: Check what happens if tool_results aren't the latest message
return cohere_messages, preamble, ""
else:
# We need to get the last message to assign to the message field for Cohere;
# if the last message is a user message, use that, otherwise put in 'continue'.
if cohere_messages[-1]["role"] == "USER":
return cohere_messages[0:-1], preamble, cohere_messages[-1]["message"]
else:
return cohere_messages, preamble, "Please continue."
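    # Tool flow (illustrative): when the history ends with an assistant tool_calls
    # message followed by a matching {"role": "tool", ...} result, the result is converted
    # into a ToolResult in cohere_params["tool_results"] and an empty message is returned,
    # per the rule above that the final turn can't be a USER message when tool results exist.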
def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
"""Calculate the cost of the completion using the Cohere pricing."""
total = 0.0
if model in COHERE_PRICING_1K:
input_cost_per_k, output_cost_per_k = COHERE_PRICING_1K[model]
input_cost = (input_tokens / 1000) * input_cost_per_k
output_cost = (output_tokens / 1000) * output_cost_per_k
total = input_cost + output_cost
else:
warnings.warn(f"Cost calculation not available for {model} model", UserWarning)
return total
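# Worked example (matches test/oai/test_cohere.py below): calculate_cohere_cost(100, 200, "command-r-plus")
# = (100 / 1000) * 0.003 + (200 / 1000) * 0.015 = 0.0003 + 0.003 = 0.0033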
class CohereError(Exception):
"""Base class for other Cohere exceptions"""
pass
class CohereRateLimitError(CohereError):
"""Raised when rate limit is exceeded"""
pass


@@ -259,13 +259,6 @@ def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[s
groq_messages = copy.deepcopy(messages)
# If we have a message with role='tool', which occurs when a function is executed, change it to 'user'
"""
for msg in together_messages:
if "role" in msg and msg["role"] == "tool":
msg["role"] = "user"
"""
# Remove the name field
for message in groq_messages:
if "name" in message:


@@ -14,6 +14,7 @@ from autogen.logger.logger_factory import LoggerFactory
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
@@ -111,7 +112,9 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig
def log_new_client(
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
client: Union[
AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient, CohereClient
],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:


@@ -89,6 +89,7 @@ extra_require = {
"anthropic": ["anthropic>=0.23.1"],
"mistral": ["mistralai>=0.2.0"],
"groq": ["groq>=0.9.0"],
"cohere": ["cohere>=5.5.8"],
}
setuptools.setup(

test/oai/test_cohere.py (new file)

@@ -0,0 +1,69 @@
#!/usr/bin/env python3 -m pytest
import os
import pytest
reason = "Cohere dependency not installed!"

try:
    from autogen.oai.cohere import CohereClient, calculate_cohere_cost

    skip = False
except ImportError:
    CohereClient = object
    skip = True
@pytest.fixture()
def cohere_client():
return CohereClient(api_key="dummy_api_key")
@pytest.mark.skipif(skip, reason=reason)
def test_initialization_missing_api_key():
os.environ.pop("COHERE_API_KEY", None)
with pytest.raises(
AssertionError,
match="Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable.",
):
CohereClient()
CohereClient(api_key="dummy_api_key")
@pytest.mark.skipif(skip, reason=reason)
def test_initialization(cohere_client):
assert cohere_client.api_key == "dummy_api_key", "`api_key` should be correctly set in the config"
@pytest.mark.skipif(skip, reason=reason)
def test_calculate_cohere_cost():
assert (
calculate_cohere_cost(0, 0, model="command-r") == 0.0
), "Cost should be 0 for 0 input_tokens and 0 output_tokens"
assert calculate_cohere_cost(100, 200, model="command-r-plus") == 0.0033
@pytest.mark.skipif(skip, reason=reason)
def test_load_config(cohere_client):
params = {
"model": "command-r-plus",
"stream": False,
"temperature": 1,
"p": 0.8,
"max_tokens": 100,
}
expected_params = {
"model": "command-r-plus",
"temperature": 1,
"p": 0.8,
"seed": None,
"max_tokens": 100,
"frequency_penalty": 0,
"presence_penalty": 0,
"k": 0,
}
result = cohere_client.parse_params(params)
assert result == expected_params, "Config should be correctly loaded"
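# A sketch of an additional test one might add for the message-conversion helper;
# the expected behavior is assumed from the conversion rules documented in
# autogen/oai/cohere.py above (system -> preamble, trailing user message -> message).
@pytest.mark.skipif(skip, reason=reason)
def test_oai_messages_to_cohere_messages():
    from autogen.oai.cohere import oai_messages_to_cohere_messages

    messages = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello"},
    ]
    chat_history, preamble, final_message = oai_messages_to_cohere_messages(messages, {}, {})
    assert preamble == "You are helpful.", "System message should become the preamble"
    assert final_message == "Hello", "Trailing user message should be returned as the message"
    assert chat_history == [], "No other messages should remain in the chat history"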


@@ -0,0 +1,534 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cohere\n",
"\n",
"[Cohere](https://cohere.com/) is a cloud based platform serving their own LLMs, in particular the Command family of models.\n",
"\n",
"Cohere's API differs from OpenAI's, which is the native API used by AutoGen, so to use Cohere's LLMs you need to use this library.\n",
"\n",
"You will need a Cohere account and create an API key. [See their website for further details](https://cohere.com/)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Features\n",
"\n",
"When using this client class, AutoGen's messages are automatically tailored to accommodate the specific requirements of Cohere's API.\n",
"\n",
"Additionally, this client class provides support for function/tool calling and will track token usage and cost correctly as per Cohere's API costs (as of July 2024).\n",
"\n",
"## Getting started\n",
"\n",
"First you need to install the `pyautogen` package to use AutoGen with the Cohere API library.\n",
"\n",
"``` bash\n",
"pip install pyautogen[cohere]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cohere provides a number of models to use, included below. See the list of [models here](https://docs.cohere.com/docs/models).\n",
"\n",
"See the sample `OAI_CONFIG_LIST` below showing how the Cohere client class is used by specifying the `api_type` as `cohere`.\n",
"\n",
"```python\n",
"[\n",
" {\n",
" \"model\": \"gpt-35-turbo\",\n",
" \"api_key\": \"your OpenAI Key goes here\",\n",
" },\n",
" {\n",
" \"model\": \"gpt-4-vision-preview\",\n",
" \"api_key\": \"your OpenAI Key goes here\",\n",
" },\n",
" {\n",
" \"model\": \"dalle\",\n",
" \"api_key\": \"your OpenAI Key goes here\",\n",
" },\n",
" {\n",
" \"model\": \"command-r-plus\",\n",
" \"api_key\": \"your Cohere API Key goes here\",\n",
" \"api_type\": \"cohere\"\n",
" },\n",
" {\n",
" \"model\": \"command-r\",\n",
" \"api_key\": \"your Cohere API Key goes here\",\n",
" \"api_type\": \"cohere\"\n",
" },\n",
" {\n",
" \"model\": \"command\",\n",
" \"api_key\": \"your Cohere API Key goes here\",\n",
" \"api_type\": \"cohere\"\n",
" }\n",
"]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to the `api_key` key and value in the config, you can set the environment variable `COHERE_API_KEY` to your Cohere key.\n",
"\n",
"Linux/Mac:\n",
"``` bash\n",
"export COHERE_API_KEY=\"your_cohere_api_key_here\"\n",
"```\n",
"\n",
"Windows:\n",
"``` bash\n",
"set COHERE_API_KEY=your_cohere_api_key_here\n",
"```\n",
"\n",
"## API parameters\n",
"\n",
"The following parameters can be added to your config for the Cohere API. See [this link](https://docs.cohere.com/reference/chat) for further information on them and their default values.\n",
"\n",
"- temperature (number > 0)\n",
"- p (number 0.01..0.99)\n",
"- k (number 0..500)\n",
"- max_tokens (null, integer >= 0)\n",
"- seed (null, integer)\n",
"- frequency_penalty (number 0..1)\n",
"- presence_penalty (number 0..1)\n",
"\n",
"Example:\n",
"```python\n",
"[\n",
" {\n",
" \"model\": \"command-r\",\n",
" \"api_key\": \"your Cohere API Key goes here\",\n",
" \"api_type\": \"cohere\",\n",
" \"temperature\": 0.5,\n",
" \"p\": 0.2,\n",
" \"k\": 100,\n",
" \"max_tokens\": 2048,\n",
" \"seed\": 42,\n",
" \"frequency_penalty\": 0.5,\n",
" \"presence_penalty\": 0.2\n",
" }\n",
"]\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Two-Agent Coding Example\n",
"\n",
"In this example, we run a two-agent chat with an AssistantAgent (primarily a coding agent) to generate code to count the number of prime numbers between 1 and 10,000 and then it will be executed.\n",
"\n",
"We'll use Cohere's Command R model which is suitable for coding."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"config_list = [\n",
" {\n",
" # Let's choose the Command-R model\n",
" \"model\": \"command-r\",\n",
" # Provide your Cohere's API key here or put it into the COHERE_API_KEY environment variable.\n",
" \"api_key\": os.environ.get(\"COHERE_API_KEY\"),\n",
" # We specify the API Type as 'cohere' so it uses the Cohere client class\n",
" \"api_type\": \"cohere\",\n",
" }\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Importantly, we have tweaked the system message so that the model doesn't return the termination keyword, which we've changed to FINISH, with the code block."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.coding import LocalCommandLineCodeExecutor\n",
"\n",
"# Setting up the code executor\n",
"workdir = Path(\"coding\")\n",
"workdir.mkdir(exist_ok=True)\n",
"code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n",
"\n",
"# Setting up the agents\n",
"\n",
"# The UserProxyAgent will execute the code that the AssistantAgent provides\n",
"user_proxy_agent = UserProxyAgent(\n",
" name=\"User\",\n",
" code_execution_config={\"executor\": code_executor},\n",
" is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n",
")\n",
"\n",
"system_message = \"\"\"You are a helpful AI assistant who writes code and the user executes it.\n",
"Solve tasks using your coding and language skills.\n",
"In the following cases, suggest python code (in a python coding block) for the user to execute.\n",
"Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n",
"When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n",
"Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n",
"If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.\n",
"When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.\n",
"IMPORTANT: Wait for the user to execute your code and then you can reply with the word \"FINISH\". DO NOT OUTPUT \"FINISH\" after your code block.\"\"\"\n",
"\n",
"# The AssistantAgent, using Cohere's model, will take the coding request and return code\n",
"assistant_agent = AssistantAgent(\n",
" name=\"Cohere Assistant\",\n",
" system_message=system_message,\n",
" llm_config={\"config_list\": config_list},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mUser\u001b[0m (to Cohere Assistant):\n",
"\n",
"Provide code to count the number of prime numbers from 1 to 10000.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mCohere Assistant\u001b[0m (to User):\n",
"\n",
"Here's the code to count the number of prime numbers from 1 to 10,000:\n",
"```python\n",
"# Prime Number Counter\n",
"count = 0\n",
"for num in range(2, 10001):\n",
" if num > 1:\n",
" for div in range(2, num):\n",
" if (num % div) == 0:\n",
" break\n",
" else:\n",
" count += 1\n",
"print(count)\n",
"```\n",
"\n",
"My plan is to use two nested loops. The outer loop iterates through numbers from 2 to 10,000. The inner loop checks if there's any divisor for the current number in the range from 2 to the number itself. If there's no such divisor, the number is prime and the counter is incremented.\n",
"\n",
"Please execute the code and let me know the output.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK (inferred language is python)...\u001b[0m\n",
"\u001b[33mUser\u001b[0m (to Cohere Assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: 1229\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mCohere Assistant\u001b[0m (to User):\n",
"\n",
"That's correct! The code you executed successfully found 1229 prime numbers within the specified range.\n",
"\n",
"FINISH.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n"
]
}
],
"source": [
"# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n",
"chat_result = user_proxy_agent.initiate_chat(\n",
" assistant_agent,\n",
" message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Call Example\n",
"\n",
"In this example, instead of writing code, we will show how Cohere's Command R+ model can perform parallel tool calling, where it recommends calling more than one tool at a time.\n",
"\n",
"We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion.\n",
"\n",
"We start by importing libraries and setting up our configuration to use Command R+ and the `cohere` client class."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from typing import Literal\n",
"\n",
"from typing_extensions import Annotated\n",
"\n",
"import autogen\n",
"\n",
"config_list = [\n",
" {\"api_type\": \"cohere\", \"model\": \"command-r-plus\", \"api_key\": os.getenv(\"COHERE_API_KEY\"), \"cache_seed\": None}\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create our two agents."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Create the agent for tool calling\n",
"chatbot = autogen.AssistantAgent(\n",
" name=\"chatbot\",\n",
" system_message=\"\"\"For currency exchange and weather forecasting tasks,\n",
" only use the functions you have been provided with.\n",
" Output 'HAVE FUN!' when an answer has been provided.\"\"\",\n",
" llm_config={\"config_list\": config_list},\n",
")\n",
"\n",
"# Note that we have changed the termination string to be \"HAVE FUN!\"\n",
"user_proxy = autogen.UserProxyAgent(\n",
" name=\"user_proxy\",\n",
" is_termination_msg=lambda x: x.get(\"content\", \"\") and \"HAVE FUN!\" in x.get(\"content\", \"\"),\n",
" human_input_mode=\"NEVER\",\n",
" max_consecutive_auto_reply=1,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create the two functions, annotating them so that those descriptions can be passed through to the LLM.\n",
"\n",
"We associate them with the agents using `register_for_execution` for the user_proxy so it can execute the function and `register_for_llm` for the chatbot (powered by the LLM) so it can pass the function definitions to the LLM."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Currency Exchange function\n",
"\n",
"CurrencySymbol = Literal[\"USD\", \"EUR\"]\n",
"\n",
"# Define our function that we expect to call\n",
"\n",
"\n",
"def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n",
" if base_currency == quote_currency:\n",
" return 1.0\n",
" elif base_currency == \"USD\" and quote_currency == \"EUR\":\n",
" return 1 / 1.1\n",
" elif base_currency == \"EUR\" and quote_currency == \"USD\":\n",
" return 1.1\n",
" else:\n",
" raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n",
"\n",
"\n",
"# Register the function with the agent\n",
"\n",
"\n",
"@user_proxy.register_for_execution()\n",
"@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n",
"def currency_calculator(\n",
" base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n",
" base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n",
" quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n",
") -> str:\n",
" quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n",
" return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n",
"\n",
"\n",
"# Weather function\n",
"\n",
"\n",
"# Example function to make available to model\n",
"def get_current_weather(location, unit=\"fahrenheit\"):\n",
" \"\"\"Get the weather for some location\"\"\"\n",
" if \"chicago\" in location.lower():\n",
" return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n",
" elif \"san francisco\" in location.lower():\n",
" return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n",
" elif \"new york\" in location.lower():\n",
" return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n",
" else:\n",
" return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n",
"\n",
"\n",
"# Register the function with the agent\n",
"\n",
"\n",
"@user_proxy.register_for_execution()\n",
"@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n",
"def weather_forecast(\n",
" location: Annotated[str, \"City name\"],\n",
") -> str:\n",
" weather_details = get_current_weather(location=location)\n",
" weather = json.loads(weather_details)\n",
" return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We pass through our customers message and run the chat.\n",
"\n",
"Finally, we ask the LLM to summarise the chat and print that out."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"I will use the weather_forecast function to find out the weather in New York, and the currency_calculator function to convert 123.45 EUR to USD. I will then search for 'holiday tips' to find some extra information to include in my answer.\n",
"\u001b[32m***** Suggested tool call (45212): weather_forecast *****\u001b[0m\n",
"Arguments: \n",
"{\"location\": \"New York\"}\n",
"\u001b[32m*********************************************************\u001b[0m\n",
"\u001b[32m***** Suggested tool call (16564): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base_amount\": 123.45, \"base_currency\": \"EUR\", \"quote_currency\": \"USD\"}\n",
"\u001b[32m************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool (45212) *****\u001b[0m\n",
"New York will be 11 degrees fahrenheit\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool (16564) *****\u001b[0m\n",
"135.80 USD\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"The weather in New York is 11 degrees Fahrenheit. \n",
"\n",
"€123.45 is worth $135.80. \n",
"\n",
"Here are some holiday tips:\n",
"- Make sure to pack layers for the cold weather\n",
"- Try the local cuisine, New York is famous for its pizza\n",
"- Visit Central Park and take in the views from the top of the Rockefeller Centre\n",
"\n",
"HAVE FUN!\n",
"\n",
"--------------------------------------------------------------------------------\n",
"LLM SUMMARY: The weather in New York is 11 degrees Fahrenheit. 123.45 EUR is worth 135.80 USD. Holiday tips: make sure to pack warm clothes and have a great time!\n"
]
}
],
"source": [
"# start the conversation\n",
"res = user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\",\n",
" summary_method=\"reflection_with_llm\",\n",
")\n",
"\n",
"print(f\"LLM SUMMARY: {res.summary['content']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that Command R+ recommended we call both tools and passed through the right parameters. The `user_proxy` executed them and this was passed back to Command R+ to interpret them and respond. Finally, Command R+ was asked to summarise the whole conversation."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autogen",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -0,0 +1,524 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Groq\n",
"\n",
"[Groq](https://groq.com/) is a cloud based platform serving a number of popular open weight models at high inference speeds. Models include Meta's Llama 3, Mistral AI's Mixtral, and Google's Gemma.\n",
"\n",
"Although Groq's API is aligned well with OpenAI's, which is the native API used by AutoGen, this library provides the ability to set specific parameters as well as track API costs.\n",
"\n",
"You will need a Groq account and create an API key. [See their website for further details](https://groq.com/)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Groq provides a number of models to use, included below. See the list of [models here (requires login)](https://console.groq.com/docs/models).\n",
"\n",
"See the sample `OAI_CONFIG_LIST` below showing how the Groq client class is used by specifying the `api_type` as `groq`.\n",
"\n",
"```python\n",
"[\n",
" {\n",
" \"model\": \"gpt-35-turbo\",\n",
" \"api_key\": \"your OpenAI Key goes here\",\n",
" },\n",
" {\n",
" \"model\": \"gpt-4-vision-preview\",\n",
" \"api_key\": \"your OpenAI Key goes here\",\n",
" },\n",
" {\n",
" \"model\": \"dalle\",\n",
" \"api_key\": \"your OpenAI Key goes here\",\n",
" },\n",
" {\n",
" \"model\": \"llama3-8b-8192\",\n",
" \"api_key\": \"your Groq API Key goes here\",\n",
" \"api_type\": \"groq\"\n",
" },\n",
" {\n",
" \"model\": \"llama3-70b-8192\",\n",
" \"api_key\": \"your Groq API Key goes here\",\n",
" \"api_type\": \"groq\"\n",
" },\n",
" {\n",
" \"model\": \"Mixtral 8x7b\",\n",
" \"api_key\": \"your Groq API Key goes here\",\n",
" \"api_type\": \"groq\"\n",
" },\n",
" {\n",
" \"model\": \"gemma-7b-it\",\n",
" \"api_key\": \"your Groq API Key goes here\",\n",
" \"api_type\": \"groq\"\n",
" }\n",
"]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to the `api_key` key and value in the config, you can set the environment variable `GROQ_API_KEY` to your Groq key.\n",
"\n",
"Linux/Mac:\n",
"``` bash\n",
"export GROQ_API_KEY=\"your_groq_api_key_here\"\n",
"```\n",
"\n",
"Windows:\n",
"``` bash\n",
"set GROQ_API_KEY=your_groq_api_key_here\n",
"```\n",
"\n",
"## API parameters\n",
"\n",
"The following parameters can be added to your config for the Groq API. See [this link](https://console.groq.com/docs/text-chat) for further information on them.\n",
"\n",
"- frequency_penalty (number 0..1)\n",
"- max_tokens (integer >= 0)\n",
"- presence_penalty (number -2..2)\n",
"- seed (integer)\n",
"- temperature (number 0..2)\n",
"- top_p (number)\n",
"\n",
"Example:\n",
"```python\n",
"[\n",
" {\n",
" \"model\": \"llama3-8b-8192\",\n",
" \"api_key\": \"your Groq API Key goes here\",\n",
" \"api_type\": \"groq\",\n",
" \"frequency_penalty\": 0.5,\n",
" \"max_tokens\": 2048,\n",
" \"presence_penalty\": 0.2,\n",
" \"seed\": 42,\n",
" \"temperature\": 0.5,\n",
" \"top_p\": 0.2\n",
" }\n",
"]\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Two-Agent Coding Example\n",
"\n",
"In this example, we run a two-agent chat with an AssistantAgent (primarily a coding agent) to generate code to count the number of prime numbers between 1 and 10,000 and then it will be executed.\n",
"\n",
"We'll use Meta's Llama 3 model which is suitable for coding."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"config_list = [\n",
" {\n",
" # Let's choose the Llama 3 model\n",
" \"model\": \"llama3-8b-8192\",\n",
" # Put your Groq API key here or put it into the GROQ_API_KEY environment variable.\n",
" \"api_key\": os.environ.get(\"GROQ_API_KEY\"),\n",
" # We specify the API Type as 'groq' so it uses the Groq client class\n",
" \"api_type\": \"groq\",\n",
" }\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Importantly, we have tweaked the system message so that the model doesn't return the termination keyword, which we've changed to FINISH, with the code block."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.coding import LocalCommandLineCodeExecutor\n",
"\n",
"# Setting up the code executor\n",
"workdir = Path(\"coding\")\n",
"workdir.mkdir(exist_ok=True)\n",
"code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n",
"\n",
"# Setting up the agents\n",
"\n",
"# The UserProxyAgent will execute the code that the AssistantAgent provides\n",
"user_proxy_agent = UserProxyAgent(\n",
" name=\"User\",\n",
" code_execution_config={\"executor\": code_executor},\n",
" is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n",
")\n",
"\n",
"system_message = \"\"\"You are a helpful AI assistant who writes code and the user executes it.\n",
"Solve tasks using your coding and language skills.\n",
"In the following cases, suggest python code (in a python coding block) for the user to execute.\n",
"Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n",
"When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n",
"Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n",
"If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.\n",
"When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.\n",
"IMPORTANT: Wait for the user to execute your code and then you can reply with the word \"FINISH\". DO NOT OUTPUT \"FINISH\" after your code block.\"\"\"\n",
"\n",
"# The AssistantAgent, using Groq's model, will take the coding request and return code\n",
"assistant_agent = AssistantAgent(\n",
" name=\"Groq Assistant\",\n",
" system_message=system_message,\n",
" llm_config={\"config_list\": config_list},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mUser\u001b[0m (to Groq Assistant):\n",
"\n",
"Provide code to count the number of prime numbers from 1 to 10000.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mGroq Assistant\u001b[0m (to User):\n",
"\n",
"Here's the plan to count the number of prime numbers from 1 to 10000:\n",
"\n",
"First, we need to write a helper function to check if a number is prime. A prime number is a number that is divisible only by 1 and itself.\n",
"\n",
"Then, we can use a loop to iterate through all numbers from 1 to 10000, check if each number is prime using our helper function, and count the number of prime numbers found.\n",
"\n",
"Here's the Python code to implement this plan:\n",
"```python\n",
"def is_prime(n):\n",
" if n <= 1:\n",
" return False\n",
" for i in range(2, int(n ** 0.5) + 1):\n",
" if n % i == 0:\n",
" return False\n",
" return True\n",
"\n",
"count = 0\n",
"for i in range(2, 10001):\n",
" if is_prime(i):\n",
" count += 1\n",
"\n",
"print(count)\n",
"```\n",
"Please execute this code, and I'll wait for the result.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK (inferred language is python)...\u001b[0m\n",
"\u001b[33mUser\u001b[0m (to Groq Assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: 1229\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mGroq Assistant\u001b[0m (to User):\n",
"\n",
"FINISH\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n"
]
}
],
"source": [
"# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n",
"chat_result = user_proxy_agent.initiate_chat(\n",
" assistant_agent,\n",
" message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Call Example\n",
"\n",
"In this example, instead of writing code, we will show how we can use Meta's Llama 3 model to perform parallel tool calling, where it recommends calling more than one tool at a time, using Groq's cloud inference.\n",
"\n",
"We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion.\n",
"\n",
"We start by importing libraries and setting up our configuration to use Meta's Llama 3 model and the `groq` client class."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from typing import Literal\n",
"\n",
"from typing_extensions import Annotated\n",
"\n",
"import autogen\n",
"\n",
"config_list = [\n",
" {\"api_type\": \"groq\", \"model\": \"llama3-8b-8192\", \"api_key\": os.getenv(\"GROQ_API_KEY\"), \"cache_seed\": None}\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create our two agents."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Create the agent for tool calling\n",
"chatbot = autogen.AssistantAgent(\n",
" name=\"chatbot\",\n",
" system_message=\"\"\"For currency exchange and weather forecasting tasks,\n",
" only use the functions you have been provided with.\n",
" Output 'HAVE FUN!' when an answer has been provided.\"\"\",\n",
" llm_config={\"config_list\": config_list},\n",
")\n",
"\n",
"# Note that we have changed the termination string to be \"HAVE FUN!\"\n",
"user_proxy = autogen.UserProxyAgent(\n",
" name=\"user_proxy\",\n",
" is_termination_msg=lambda x: x.get(\"content\", \"\") and \"HAVE FUN!\" in x.get(\"content\", \"\"),\n",
" human_input_mode=\"NEVER\",\n",
" max_consecutive_auto_reply=1,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create the two functions, annotating them so that those descriptions can be passed through to the LLM.\n",
"\n",
"We associate them with the agents using `register_for_execution` for the user_proxy so it can execute the function and `register_for_llm` for the chatbot (powered by the LLM) so it can pass the function definitions to the LLM."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Currency Exchange function\n",
"\n",
"CurrencySymbol = Literal[\"USD\", \"EUR\"]\n",
"\n",
"# Define our function that we expect to call\n",
"\n",
"\n",
"def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n",
" if base_currency == quote_currency:\n",
" return 1.0\n",
" elif base_currency == \"USD\" and quote_currency == \"EUR\":\n",
" return 1 / 1.1\n",
" elif base_currency == \"EUR\" and quote_currency == \"USD\":\n",
" return 1.1\n",
" else:\n",
" raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n",
"\n",
"\n",
"# Register the function with the agent\n",
"\n",
"\n",
"@user_proxy.register_for_execution()\n",
"@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n",
"def currency_calculator(\n",
" base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n",
" base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n",
" quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n",
") -> str:\n",
" quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n",
" return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n",
"\n",
"\n",
"# Weather function\n",
"\n",
"\n",
"# Example function to make available to model\n",
"def get_current_weather(location, unit=\"fahrenheit\"):\n",
" \"\"\"Get the weather for some location\"\"\"\n",
" if \"chicago\" in location.lower():\n",
" return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n",
" elif \"san francisco\" in location.lower():\n",
" return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n",
" elif \"new york\" in location.lower():\n",
" return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n",
" else:\n",
" return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n",
"\n",
"\n",
"# Register the function with the agent\n",
"\n",
"\n",
"@user_proxy.register_for_execution()\n",
"@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n",
"def weather_forecast(\n",
" location: Annotated[str, \"City name\"],\n",
") -> str:\n",
" weather_details = get_current_weather(location=location)\n",
" weather = json.loads(weather_details)\n",
" return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We pass through our customers message and run the chat.\n",
"\n",
"Finally, we ask the LLM to summarise the chat and print that out."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool call (call_hg7g): weather_forecast *****\u001b[0m\n",
"Arguments: \n",
"{\"location\":\"New York\"}\n",
"\u001b[32m*************************************************************\u001b[0m\n",
"\u001b[32m***** Suggested tool call (call_hrsf): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base_amount\":123.45,\"base_currency\":\"EUR\",\"quote_currency\":\"USD\"}\n",
"\u001b[32m****************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool (call_hg7g) *****\u001b[0m\n",
"New York will be 11 degrees fahrenheit\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool (call_hrsf) *****\u001b[0m\n",
"135.80 USD\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool call (call_ahwk): weather_forecast *****\u001b[0m\n",
"Arguments: \n",
"{\"location\":\"New York\"}\n",
"\u001b[32m*************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"LLM SUMMARY: Based on the conversation, it's predicted that New York will be 11 degrees Fahrenheit. You also found out that 123.45 EUR is equal to 135.80 USD. Here are a few holiday tips:\n",
"\n",
"* Pack warm clothing for your visit to New York, as the temperature is expected to be quite chilly.\n",
"* Consider exchanging your money at a local currency exchange or an ATM since the exchange rate might not be as favorable in tourist areas.\n",
"* Make sure to check the estimated expenses for your holiday and adjust your budget accordingly.\n",
"\n",
"I hope you have a great trip!\n"
]
}
],
"source": [
"# start the conversation\n",
"res = user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\",\n",
" summary_method=\"reflection_with_llm\",\n",
")\n",
"\n",
"print(f\"LLM SUMMARY: {res.summary['content']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using its fast inference, Groq required less than 2 seconds for the whole chat!\n",
"\n",
"Additionally, Llama 3 was able to call both tools and pass through the right parameters. The `user_proxy` then executed them and this was passed back for Llama 3 to summarise the whole conversation."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autogen",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}