Add Cerebras Integration (#3585)

* Cerebras Integration

* Address feedback

* Fix typo

* Run formatter
This commit is contained in:
Henry Tu 2024-09-30 17:14:55 -04:00 committed by GitHub
parent b8d749daac
commit 3fdf8dea22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1083 additions and 0 deletions

View File

@ -474,6 +474,46 @@ jobs:
file: ./coverage.xml
flags: unittests
CerebrasTest:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
python-version: ["3.9", "3.10", "3.11", "3.12"]
exclude:
- os: macos-latest
python-version: "3.9"
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest-cov>=5
- name: Install packages and dependencies for Cerebras
run: |
pip install -e .[cerebras_cloud_sdk,test]
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Coverage
run: |
pytest test/oai/test_cerebras.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
MistralTest:
runs-on: ${{ matrix.os }}
strategy:

View File

@ -19,6 +19,7 @@ if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.bedrock import BedrockClient
from autogen.oai.cerebras import CerebrasClient
from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
@ -210,6 +211,7 @@ class FileLogger(BaseLogger):
client: (
AzureOpenAI
| OpenAI
| CerebrasClient
| GeminiClient
| AnthropicClient
| MistralAIClient

View File

@ -20,6 +20,7 @@ if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.bedrock import BedrockClient
from autogen.oai.cerebras import CerebrasClient
from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
@ -397,6 +398,7 @@ class SqliteLogger(BaseLogger):
client: Union[
AzureOpenAI,
OpenAI,
CerebrasClient,
GeminiClient,
AnthropicClient,
MistralAIClient,

270
autogen/oai/cerebras.py Normal file
View File

@ -0,0 +1,270 @@
"""Create an OpenAI-compatible client using Cerebras's API.
Example:
llm_config={
"config_list": [{
"api_type": "cerebras",
"model": "llama3.1-8b",
"api_key": os.environ.get("CEREBRAS_API_KEY")
}]
}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
Install Cerebras's python library using: pip install --upgrade cerebras_cloud_sdk
Resources:
- https://inference-docs.cerebras.ai/quickstart
"""
from __future__ import annotations
import copy
import os
import time
import warnings
from typing import Any, Dict, List
from cerebras.cloud.sdk import Cerebras, Stream
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from autogen.oai.client_utils import should_hide_tools, validate_parameter
CEREBRAS_PRICING_1K = {
# Convert pricing per million to per thousand tokens.
"llama3.1-8b": (0.10 / 1000, 0.10 / 1000),
"llama3.1-70b": (0.60 / 1000, 0.60 / 1000),
}
class CerebrasClient:
"""Client for Cerebras's API."""
def __init__(self, api_key=None, **kwargs):
"""Requires api_key or environment variable to be set
Args:
api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set)
"""
# Ensure we have the api_key upon instantiation
self.api_key = api_key
if not self.api_key:
self.api_key = os.getenv("CEREBRAS_API_KEY")
assert (
self.api_key
), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."
def message_retrieval(self, response: ChatCompletion) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
"""
return [choice.message for choice in response.choices]
def cost(self, response: ChatCompletion) -> float:
# Note: This field isn't explicitly in `ChatCompletion`, but is injected during chat creation.
return response.cost
@staticmethod
def get_usage(response: ChatCompletion) -> Dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
# ... # pragma: no cover
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost,
"model": response.model,
}
def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Loads the parameters for Cerebras API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
cerebras_params = {}
# Check that we have what we need to use Cerebras's API
# We won't enforce the available models as they are likely to change
cerebras_params["model"] = params.get("model", None)
assert cerebras_params[
"model"
], "Please specify the 'model' in your config list entry to nominate the Cerebras model to use."
# Validate allowed Cerebras parameters
# https://inference-docs.cerebras.ai/api-reference/chat-completions
cerebras_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
cerebras_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
cerebras_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
cerebras_params["temperature"] = validate_parameter(
params, "temperature", (int, float), True, 1, (0, 1.5), None
)
cerebras_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
return cerebras_params
def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
# Convert AutoGen messages to Cerebras messages
cerebras_messages = oai_messages_to_cerebras_messages(messages)
# Parse parameters to the Cerebras API's parameters
cerebras_params = self.parse_params(params)
# Add tools to the call if we have them and aren't hiding them
if "tools" in params:
hide_tools = validate_parameter(
params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
)
if not should_hide_tools(cerebras_messages, params["tools"], hide_tools):
cerebras_params["tools"] = params["tools"]
cerebras_params["messages"] = cerebras_messages
# We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
client = Cerebras(api_key=self.api_key, max_retries=5)
# Token counts will be returned
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
# Streaming tool call recommendations
streaming_tool_calls = []
ans = None
try:
response = client.chat.completions.create(**cerebras_params)
except Exception as e:
raise RuntimeError(f"Cerebras exception occurred: {e}")
else:
if cerebras_params["stream"]:
# Read in the chunks as they stream, taking in tool_calls which may be across
# multiple chunks if more than one suggested
ans = ""
for chunk in response:
# Grab first choice, which _should_ always be generated.
ans = ans + (chunk.choices[0].delta.content or "")
if chunk.choices[0].delta.tool_calls:
# We have a tool call recommendation
for tool_call in chunk.choices[0].delta.tool_calls:
streaming_tool_calls.append(
ChatCompletionMessageToolCall(
id=tool_call.id,
function={
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
type="function",
)
)
if chunk.choices[0].finish_reason:
prompt_tokens = chunk.x_cerebras.usage.prompt_tokens
completion_tokens = chunk.x_cerebras.usage.completion_tokens
total_tokens = chunk.x_cerebras.usage.total_tokens
else:
# Non-streaming finished
ans: str = response.choices[0].message.content
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
total_tokens = response.usage.total_tokens
if response is not None:
if isinstance(response, Stream):
# Streaming response
if chunk.choices[0].finish_reason == "tool_calls":
cerebras_finish = "tool_calls"
tool_calls = streaming_tool_calls
else:
cerebras_finish = "stop"
tool_calls = None
response_content = ans
response_id = chunk.id
else:
# Non-streaming response
# If we have tool calls as the response, populate completed tool calls for our return OAI response
if response.choices[0].finish_reason == "tool_calls":
cerebras_finish = "tool_calls"
tool_calls = []
for tool_call in response.choices[0].message.tool_calls:
tool_calls.append(
ChatCompletionMessageToolCall(
id=tool_call.id,
function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
type="function",
)
)
else:
cerebras_finish = "stop"
tool_calls = None
response_content = response.choices[0].message.content
response_id = response.id
else:
raise RuntimeError("Failed to get response from Cerebras after retrying 5 times.")
# 3. convert output
message = ChatCompletionMessage(
role="assistant",
content=response_content,
function_call=None,
tool_calls=tool_calls,
)
choices = [Choice(finish_reason=cerebras_finish, index=0, message=message)]
response_oai = ChatCompletion(
id=response_id,
model=cerebras_params["model"],
created=int(time.time()),
object="chat.completion",
choices=choices,
usage=CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
),
# Note: This seems to be a field that isn't in the schema of `ChatCompletion`, so Pydantic
# just adds it dynamically.
cost=calculate_cerebras_cost(prompt_tokens, completion_tokens, cerebras_params["model"]),
)
return response_oai
def oai_messages_to_cerebras_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert messages from OAI format to Cerebras's format.
We correct for any specific role orders and types.
"""
cerebras_messages = copy.deepcopy(messages)
# Remove the name field
for message in cerebras_messages:
if "name" in message:
message.pop("name", None)
return cerebras_messages
def calculate_cerebras_cost(input_tokens: int, output_tokens: int, model: str) -> float:
"""Calculate the cost of the completion using the Cerebras pricing."""
total = 0.0
if model in CEREBRAS_PRICING_1K:
input_cost_per_k, output_cost_per_k = CEREBRAS_PRICING_1K[model]
input_cost = (input_tokens / 1000) * input_cost_per_k
output_cost = (output_tokens / 1000) * output_cost_per_k
total = input_cost + output_cost
else:
warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
return total

View File

@ -44,6 +44,13 @@ else:
TOOL_ENABLED = True
ERROR = None
try:
from autogen.oai.cerebras import CerebrasClient
cerebras_import_exception: Optional[ImportError] = None
except ImportError as e:
cerebras_import_exception = e
try:
from autogen.oai.gemini import GeminiClient
@ -505,6 +512,11 @@ class OpenAIWrapper:
self._configure_azure_openai(config, openai_config)
client = AzureOpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
elif api_type is not None and api_type.startswith("cerebras"):
if cerebras_import_exception:
raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.")
client = CerebrasClient(**openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("google"):
if gemini_import_exception:
raise ImportError("Please install `google-generativeai` to use Google OpenAI API.")

View File

@ -15,6 +15,7 @@ if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.bedrock import BedrockClient
from autogen.oai.cerebras import CerebrasClient
from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
@ -116,6 +117,7 @@ def log_new_client(
client: Union[
AzureOpenAI,
OpenAI,
CerebrasClient,
GeminiClient,
AnthropicClient,
MistralAIClient,

View File

@ -126,6 +126,7 @@ class LLMConfig(SQLModel, table=False):
class ModelTypes(str, Enum):
openai = "open_ai"
cerebras = "cerebras"
google = "google"
azure = "azure"
anthropic = "anthropic"

View File

@ -101,6 +101,7 @@ extra_require = {
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
"long-context": ["llmlingua<0.3"],
"anthropic": ["anthropic>=0.23.1"],
"cerebras": ["cerebras_cloud_sdk>=1.0.0"],
"mistral": ["mistralai>=1.0.1"],
"groq": ["groq>=0.9.0"],
"cohere": ["cohere>=5.5.8"],

248
test/oai/test_cerebras.py Normal file
View File

@ -0,0 +1,248 @@
from unittest.mock import MagicMock, patch
import pytest
try:
from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost
skip = False
except ImportError:
CerebrasClient = object
InternalServerError = object
skip = True
# Fixtures for mock data
@pytest.fixture
def mock_response():
class MockResponse:
def __init__(self, text, choices, usage, cost, model):
self.text = text
self.choices = choices
self.usage = usage
self.cost = cost
self.model = model
return MockResponse
@pytest.fixture
def cerebras_client():
return CerebrasClient(api_key="fake_api_key")
skip_reason = "Cerebras dependency is not installed"
# Test initialization and configuration
@pytest.mark.skipif(skip, reason=skip_reason)
def test_initialization():
# Missing any api_key
with pytest.raises(AssertionError) as assertinfo:
CerebrasClient() # Should raise an AssertionError due to missing api_key
assert (
"Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."
in str(assertinfo.value)
)
# Creation works
CerebrasClient(api_key="fake_api_key") # Should create okay now.
# Test standard initialization
@pytest.mark.skipif(skip, reason=skip_reason)
def test_valid_initialization(cerebras_client):
assert cerebras_client.api_key == "fake_api_key", "Config api_key should be correctly set"
# Test parameters
@pytest.mark.skipif(skip, reason=skip_reason)
def test_parsing_params(cerebras_client):
# All parameters
params = {
"model": "llama3.1-8b",
"max_tokens": 1000,
"seed": 42,
"stream": False,
"temperature": 1,
"top_p": 0.8,
}
expected_params = {
"model": "llama3.1-8b",
"max_tokens": 1000,
"seed": 42,
"stream": False,
"temperature": 1,
"top_p": 0.8,
}
result = cerebras_client.parse_params(params)
assert result == expected_params
# Only model, others set as defaults
params = {
"model": "llama3.1-8b",
}
expected_params = {
"model": "llama3.1-8b",
"max_tokens": None,
"seed": None,
"stream": False,
"temperature": 1,
"top_p": None,
}
result = cerebras_client.parse_params(params)
assert result == expected_params
# Incorrect types, defaults should be set, will show warnings but not trigger assertions
params = {
"model": "llama3.1-8b",
"max_tokens": "1000",
"seed": "42",
"stream": "False",
"temperature": "1",
"top_p": "0.8",
}
result = cerebras_client.parse_params(params)
assert result == expected_params
# Values outside bounds, should warn and set to defaults
params = {
"model": "llama3.1-8b",
"temperature": 33123,
}
result = cerebras_client.parse_params(params)
assert result == expected_params
# No model
params = {
"temperature": 1,
}
with pytest.raises(AssertionError) as assertinfo:
result = cerebras_client.parse_params(params)
assert "Please specify the 'model' in your config list entry to nominate the Cerebras model to use." in str(
assertinfo.value
)
# Test cost calculation
@pytest.mark.skipif(skip, reason=skip_reason)
def test_cost_calculation(mock_response):
response = mock_response(
text="Example response",
choices=[{"message": "Test message 1"}],
usage={"prompt_tokens": 500, "completion_tokens": 300, "total_tokens": 800},
cost=None,
model="llama3.1-70b",
)
calculated_cost = calculate_cerebras_cost(
response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model
)
# Convert cost per milliion to cost per token.
expected_cost = (
response.usage["prompt_tokens"] * 0.6 / 1000000 + response.usage["completion_tokens"] * 0.6 / 1000000
)
assert calculated_cost == expected_cost, f"Cost for this should be ${expected_cost} but got ${calculated_cost}"
# Test text generation
@pytest.mark.skipif(skip, reason=skip_reason)
@patch("autogen.oai.cerebras.CerebrasClient.create")
def test_create_response(mock_chat, cerebras_client):
# Mock CerebrasClient.chat response
mock_cerebras_response = MagicMock()
mock_cerebras_response.choices = [
MagicMock(finish_reason="stop", message=MagicMock(content="Example Cerebras response", tool_calls=None))
]
mock_cerebras_response.id = "mock_cerebras_response_id"
mock_cerebras_response.model = "llama3.1-70b"
mock_cerebras_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20) # Example token usage
mock_chat.return_value = mock_cerebras_response
# Test parameters
params = {
"messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}],
"model": "llama3.1-70b",
}
# Call the create method
response = cerebras_client.create(params)
# Assertions to check if response is structured as expected
assert (
response.choices[0].message.content == "Example Cerebras response"
), "Response content should match expected output"
assert response.id == "mock_cerebras_response_id", "Response ID should match the mocked response ID"
assert response.model == "llama3.1-70b", "Response model should match the mocked response model"
assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage"
assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage"
# Test functions/tools
@pytest.mark.skipif(skip, reason=skip_reason)
@patch("autogen.oai.cerebras.CerebrasClient.create")
def test_create_response_with_tool_call(mock_chat, cerebras_client):
# Mock `cerebras_response = client.chat(**cerebras_params)`
mock_function = MagicMock(name="currency_calculator")
mock_function.name = "currency_calculator"
mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}'
mock_function_2 = MagicMock(name="get_weather")
mock_function_2.name = "get_weather"
mock_function_2.arguments = '{"location": "Chicago"}'
mock_chat.return_value = MagicMock(
choices=[
MagicMock(
finish_reason="tool_calls",
message=MagicMock(
content="Sample text about the functions",
tool_calls=[
MagicMock(id="gdRdrvnHh", function=mock_function),
MagicMock(id="abRdrvnHh", function=mock_function_2),
],
),
)
],
id="mock_cerebras_response_id",
model="llama3.1-70b",
usage=MagicMock(prompt_tokens=10, completion_tokens=20),
)
# Construct parameters
converted_functions = [
{
"type": "function",
"function": {
"description": "Currency exchange calculator.",
"name": "currency_calculator",
"parameters": {
"type": "object",
"properties": {
"base_amount": {"type": "number", "description": "Amount of currency in base_currency"},
},
"required": ["base_amount"],
},
},
}
]
cerebras_messages = [
{"role": "user", "content": "How much is 123.45 EUR in USD?"},
{"role": "assistant", "content": "World"},
]
# Call the create method
response = cerebras_client.create(
{"messages": cerebras_messages, "tools": converted_functions, "model": "llama3.1-70b"}
)
# Assertions to check if the functions and content are included in the response
assert response.choices[0].message.content == "Sample text about the functions"
assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator"
assert response.choices[0].message.tool_calls[1].function.name == "get_weather"

View File

@ -0,0 +1,505 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cerebras\n",
"\n",
"[Cerebras](https://cerebras.ai) has developed the world's largest and fastest AI processor, the Wafer-Scale Engine-3 (WSE-3). Notably, the CS-3 system can run large language models like Llama-3.1-8B and Llama-3.1-70B at extremely fast speeds, making it an ideal platform for demanding AI workloads.\n",
"\n",
"While it's technically possible to adapt AutoGen to work with Cerebras' API by updating the `base_url`, this approach may not fully account for minor differences in parameter support. Using this library will also allow for tracking of the API costs based on actual token usage.\n",
"\n",
"For more information about Cerebras Cloud, visit [cloud.cerebras.ai](https://cloud.cerebras.ai). Their API reference is available at [inference-docs.cerebras.ai](https://inference-docs.cerebras.ai)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Requirements\n",
"To use Cerebras with AutoGen, install the `pyautogen[cerebras]` package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install pyautogen[\"cerebras\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Getting Started\n",
"\n",
"Cerebras provides a number of models to use. See the list of [models here](https://inference-docs.cerebras.ai/introduction).\n",
"\n",
"See the sample `OAI_CONFIG_LIST` below showing how the Cerebras AI client class is used by specifying the `api_type` as `cerebras`.\n",
"```python\n",
"[\n",
" {\n",
" \"model\": \"llama3.1-8b\",\n",
" \"api_key\": \"your Cerebras API Key goes here\",\n",
" \"api_type\": \"cerebras\"\n",
" },\n",
" {\n",
" \"model\": \"llama3.1-70b\",\n",
" \"api_key\": \"your Cerebras API Key goes here\",\n",
" \"api_type\": \"cerebras\"\n",
" }\n",
"]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Credentials\n",
"\n",
"Get an API Key from [cloud.cerebras.ai](https://cloud.cerebras.ai/) and add it to your environment variables:\n",
"\n",
"```\n",
"export CEREBRAS_API_KEY=\"your-api-key-here\"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API parameters\n",
"\n",
"The following parameters can be added to your config for the Cerebras API. See [this link](https://inference-docs.cerebras.ai/api-reference/chat-completions) for further information on them and their default values.\n",
"\n",
"- max_tokens (null, integer >= 0)\n",
"- seed (number)\n",
"- stream (True or False)\n",
"- temperature (number 0..1.5)\n",
"- top_p (number)\n",
"\n",
"Example:\n",
"```python\n",
"[\n",
" {\n",
" \"model\": \"llama3.1-70b\",\n",
" \"api_key\": \"your Cerebras API Key goes here\",\n",
" \"api_type\": \"cerebras\"\n",
" \"max_tokens\": 10000,\n",
" \"seed\": 1234,\n",
" \"stream\" True,\n",
" \"temperature\": 0.5,\n",
" \"top_p\": 0.2, # Note: It is recommended to set temperature or top_p but not both.\n",
" }\n",
"]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Two-Agent Coding Example\n",
"\n",
"In this example, we run a two-agent chat with an AssistantAgent (primarily a coding agent) to generate code to count the number of prime numbers between 1 and 10,000 and then it will be executed.\n",
"\n",
"We'll use Meta's LLama-3.1-70B model which is suitable for coding."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost\n",
"\n",
"config_list = [{\"model\": \"llama3.1-70b\", \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"), \"api_type\": \"cerebras\"}]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Importantly, we have tweaked the system message so that the model doesn't return the termination keyword, which we've changed to FINISH, with the code block."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.coding import LocalCommandLineCodeExecutor\n",
"\n",
"# Setting up the code executor\n",
"workdir = Path(\"coding\")\n",
"workdir.mkdir(exist_ok=True)\n",
"code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n",
"\n",
"# Setting up the agents\n",
"\n",
"# The UserProxyAgent will execute the code that the AssistantAgent provides\n",
"user_proxy_agent = UserProxyAgent(\n",
" name=\"User\",\n",
" code_execution_config={\"executor\": code_executor},\n",
" is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n",
")\n",
"\n",
"system_message = \"\"\"You are a helpful AI assistant who writes code and the user executes it.\n",
"Solve tasks using your coding and language skills.\n",
"In the following cases, suggest python code (in a python coding block) for the user to execute.\n",
"Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n",
"When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n",
"Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n",
"If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.\n",
"When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.\n",
"IMPORTANT: Wait for the user to execute your code and then you can reply with the word \"FINISH\". DO NOT OUTPUT \"FINISH\" after your code block.\"\"\"\n",
"\n",
"# The AssistantAgent, using Cerebras AI's model, will take the coding request and return code\n",
"assistant_agent = AssistantAgent(\n",
" name=\"Cerebras Assistant\",\n",
" system_message=system_message,\n",
" llm_config={\"config_list\": config_list},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mUser\u001b[0m (to Cerebras Assistant):\n",
"\n",
"Provide code to count the number of prime numbers from 1 to 10000.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mCerebras Assistant\u001b[0m (to User):\n",
"\n",
"To count the number of prime numbers from 1 to 10000, we will utilize a simple algorithm that checks each number in the range to see if it is prime. A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.\n",
"\n",
"Here's how we can do it using a Python script:\n",
"\n",
"```python\n",
"def count_primes(n):\n",
" primes = 0\n",
" for possiblePrime in range(2, n + 1):\n",
" # Assume number is prime until shown it is not. \n",
" isPrime = True\n",
" for num in range(2, int(possiblePrime ** 0.5) + 1):\n",
" if possiblePrime % num == 0:\n",
" isPrime = False\n",
" break\n",
" if isPrime:\n",
" primes += 1\n",
" return primes\n",
"\n",
"# Counting prime numbers from 1 to 10000\n",
"count = count_primes(10000)\n",
"print(count)\n",
"```\n",
"\n",
"Please execute this code. I will respond with \"FINISH\" after you provide the result.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Replying as User. Provide feedback to Cerebras Assistant. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n"
]
}
],
"source": [
"# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n",
"chat_result = user_proxy_agent.initiate_chat(\n",
" assistant_agent,\n",
" message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Call Example\n",
"\n",
"In this example, instead of writing code, we will show how Meta's Llama-3.1-70B model can perform parallel tool calling, where it recommends calling more than one tool at a time.\n",
"\n",
"We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion.\n",
"\n",
"We start by importing libraries and setting up our configuration to use Llama-3.1-70B and the `cerebras` client class."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from typing import Literal\n",
"\n",
"from typing_extensions import Annotated\n",
"\n",
"import autogen\n",
"\n",
"config_list = [\n",
" {\n",
" \"model\": \"llama3.1-70b\",\n",
" \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"),\n",
" \"api_type\": \"cerebras\",\n",
" }\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create our two agents."
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# Create the agent for tool calling\n",
"chatbot = autogen.AssistantAgent(\n",
" name=\"chatbot\",\n",
" system_message=\"\"\"\n",
" For currency exchange and weather forecasting tasks,\n",
" only use the functions you have been provided with.\n",
" When you summarize, make sure you've considered ALL previous instructions.\n",
" Output 'HAVE FUN!' when an answer has been provided.\n",
" \"\"\",\n",
" llm_config={\"config_list\": config_list},\n",
")\n",
"\n",
"# Note that we have changed the termination string to be \"HAVE FUN!\"\n",
"user_proxy = autogen.UserProxyAgent(\n",
" name=\"user_proxy\",\n",
" is_termination_msg=lambda x: x.get(\"content\", \"\") and \"HAVE FUN!\" in x.get(\"content\", \"\"),\n",
" human_input_mode=\"NEVER\",\n",
" max_consecutive_auto_reply=1,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create the two functions, annotating them so that those descriptions can be passed through to the LLM.\n",
"\n",
"We associate them with the agents using `register_for_execution` for the user_proxy so it can execute the function and `register_for_llm` for the chatbot (powered by the LLM) so it can pass the function definitions to the LLM."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# Currency Exchange function\n",
"\n",
"CurrencySymbol = Literal[\"USD\", \"EUR\"]\n",
"\n",
"# Define our function that we expect to call\n",
"\n",
"\n",
"def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n",
" if base_currency == quote_currency:\n",
" return 1.0\n",
" elif base_currency == \"USD\" and quote_currency == \"EUR\":\n",
" return 1 / 1.1\n",
" elif base_currency == \"EUR\" and quote_currency == \"USD\":\n",
" return 1.1\n",
" else:\n",
" raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n",
"\n",
"\n",
"# Register the function with the agent\n",
"\n",
"\n",
"@user_proxy.register_for_execution()\n",
"@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n",
"def currency_calculator(\n",
" base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n",
" base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n",
" quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n",
") -> str:\n",
" quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n",
" return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n",
"\n",
"\n",
"# Weather function\n",
"\n",
"\n",
"# Example function to make available to model\n",
"def get_current_weather(location, unit=\"fahrenheit\"):\n",
" \"\"\"Get the weather for some location\"\"\"\n",
" if \"chicago\" in location.lower():\n",
" return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n",
" elif \"san francisco\" in location.lower():\n",
" return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n",
" elif \"new york\" in location.lower():\n",
" return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n",
" else:\n",
" return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n",
"\n",
"\n",
"# Register the function with the agent\n",
"\n",
"\n",
"@user_proxy.register_for_execution()\n",
"@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n",
"def weather_forecast(\n",
" location: Annotated[str, \"City name\"],\n",
") -> str:\n",
" weather_details = get_current_weather(location=location)\n",
" weather = json.loads(weather_details)\n",
" return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We pass through our customer's message and run the chat.\n",
"\n",
"Finally, we ask the LLM to summarise the chat and print that out."
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool call (210f6ac6d): weather_forecast *****\u001b[0m\n",
"Arguments: \n",
"{\"location\": \"New York\"}\n",
"\u001b[32m*************************************************************\u001b[0m\n",
"\u001b[32m***** Suggested tool call (3c00ac7d5): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base_amount\": 123.45, \"base_currency\": \"EUR\", \"quote_currency\": \"USD\"}\n",
"\u001b[32m****************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool (210f6ac6d) *****\u001b[0m\n",
"New York will be 11 degrees fahrenheit\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool (3c00ac7d5) *****\u001b[0m\n",
"135.80 USD\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"New York will be 11 degrees fahrenheit.\n",
"123.45 EUR is equivalent to 135.80 USD.\n",
" \n",
"For a great holiday, explore the Statue of Liberty, take a walk through Central Park, or visit one of the many world-class museums. Also, you'll find great food ranging from bagels to fine dining experiences. HAVE FUN!\n",
"\n",
"--------------------------------------------------------------------------------\n",
"LLM SUMMARY: New York will be 11 degrees fahrenheit. 123.45 EUR is equivalent to 135.80 USD. Explore the Statue of Liberty, walk through Central Park, or visit one of the many world-class museums for a great holiday in New York.\n",
"\n",
"Duration: 73.97937774658203ms\n"
]
}
],
"source": [
"import time\n",
"\n",
"start_time = time.time()\n",
"\n",
"# start the conversation\n",
"res = user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\",\n",
" summary_method=\"reflection_with_llm\",\n",
")\n",
"\n",
"end_time = time.time()\n",
"\n",
"print(f\"LLM SUMMARY: {res.summary['content']}\\n\\nDuration: {(end_time - start_time) * 1000}ms\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the Cerebras Wafer-Scale Engine-3 (WSE-3) completed the query in 74ms -- faster than the blink of an eye!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}