Groq Client (#3003)

* Groq Client Class - main class and setup, except tests

* Change pricing per K, added tests

* Streaming support, including with tool calling

* Used Groq retries instead of loop, thanks Gal-Gilor!

* Fixed bug when using logging.

---------

Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
This commit is contained in:
Mark Sze 2024-06-28 15:58:42 +10:00 committed by GitHub
parent 55cc542bcf
commit 23c1dec206
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 599 additions and 4 deletions

View File

@ -598,3 +598,43 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
GroqTest:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
python-version: ["3.9", "3.10", "3.11", "3.12"]
exclude:
- os: macos-latest
python-version: "3.9"
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest-cov>=5
- name: Install packages and dependencies for Groq
run: |
pip install -e .[groq,test]
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Coverage
run: |
pytest test/oai/test_groq.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests

View File

@ -19,6 +19,7 @@ if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@ -204,7 +205,7 @@ class FileLogger(BaseLogger):
def log_new_client(
self,
client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient,
client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient,
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:

View File

@ -20,6 +20,7 @@ if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@ -391,7 +392,7 @@ class SqliteLogger(BaseLogger):
def log_new_client(
self,
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:

View File

@ -70,6 +70,13 @@ try:
except ImportError as e:
together_import_exception = e
try:
from autogen.oai.groq import GroqClient
groq_import_exception: Optional[ImportError] = None
except ImportError as e:
groq_import_exception = e
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@ -483,7 +490,13 @@ class OpenAIWrapper:
elif api_type is not None and api_type.startswith("together"):
if together_import_exception:
raise ImportError("Please install `together` to use the Together.AI API.")
self._clients.append(TogetherClient(**config))
client = TogetherClient(**openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("groq"):
if groq_import_exception:
raise ImportError("Please install `groq` to use the Groq API.")
client = GroqClient(**openai_config)
self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))

289
autogen/oai/groq.py Normal file
View File

@ -0,0 +1,289 @@
"""Create an OpenAI-compatible client using Groq's API.
Example:
llm_config={
"config_list": [{
"api_type": "groq",
"model": "mixtral-8x7b-32768",
"api_key": os.environ.get("GROQ_API_KEY")
}
]}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
Install Groq's python library using: pip install --upgrade groq
Resources:
- https://console.groq.com/docs/quickstart
"""
from __future__ import annotations
import copy
import os
import time
import warnings
from typing import Any, Dict, List
from groq import Groq, Stream
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from autogen.oai.client_utils import should_hide_tools, validate_parameter
# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
GROQ_PRICING_1K = {
"llama3-70b-8192": (0.00059, 0.00079),
"mixtral-8x7b-32768": (0.00024, 0.00024),
"llama3-8b-8192": (0.00005, 0.00008),
"gemma-7b-it": (0.00007, 0.00007),
}
class GroqClient:
"""Client for Groq's API."""
def __init__(self, **kwargs):
"""Requires api_key or environment variable to be set
Args:
api_key (str): The API key for using Groq (or environment variable GROQ_API_KEY needs to be set)
"""
# Ensure we have the api_key upon instantiation
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
self.api_key = os.getenv("GROQ_API_KEY")
assert (
self.api_key
), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."
def message_retrieval(self, response) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
"""
return [choice.message for choice in response.choices]
def cost(self, response) -> float:
return response.cost
@staticmethod
def get_usage(response) -> Dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
# ... # pragma: no cover
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost,
"model": response.model,
}
def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
groq_params = {}
# Check that we have what we need to use Groq's API
# We won't enforce the available models as they are likely to change
groq_params["model"] = params.get("model", None)
assert groq_params[
"model"
], "Please specify the 'model' in your config list entry to nominate the Groq model to use."
# Validate allowed Groq parameters
# https://console.groq.com/docs/api-reference#chat
groq_params["frequency_penalty"] = validate_parameter(
params, "frequency_penalty", (int, float), True, None, (-2, 2), None
)
groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
groq_params["presence_penalty"] = validate_parameter(
params, "presence_penalty", (int, float), True, None, (-2, 2), None
)
groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
# Groq parameters not supported by their models yet, ignoring
# logit_bias, logprobs, top_logprobs
# Groq parameters we are ignoring:
# n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
# parallel_tool_calls (defaults to True), stop
# function_call (deprecated), functions (deprecated)
# tool_choice (none if no tools, auto if there are tools)
return groq_params
def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
# Convert AutoGen messages to Groq messages
groq_messages = oai_messages_to_groq_messages(messages)
# Parse parameters to the Groq API's parameters
groq_params = self.parse_params(params)
# Add tools to the call if we have them and aren't hiding them
if "tools" in params:
hide_tools = validate_parameter(
params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
)
if not should_hide_tools(groq_messages, params["tools"], hide_tools):
groq_params["tools"] = params["tools"]
groq_params["messages"] = groq_messages
# We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
client = Groq(api_key=self.api_key, max_retries=5)
# Token counts will be returned
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
# Streaming tool call recommendations
streaming_tool_calls = []
ans = None
try:
response = client.chat.completions.create(**groq_params)
except Exception as e:
raise RuntimeError(f"Groq exception occurred: {e}")
else:
if groq_params["stream"]:
# Read in the chunks as they stream, taking in tool_calls which may be across
# multiple chunks if more than one suggested
ans = ""
for chunk in response:
ans = ans + (chunk.choices[0].delta.content or "")
if chunk.choices[0].delta.tool_calls:
# We have a tool call recommendation
for tool_call in chunk.choices[0].delta.tool_calls:
streaming_tool_calls.append(
ChatCompletionMessageToolCall(
id=tool_call.id,
function={
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
type="function",
)
)
if chunk.choices[0].finish_reason:
prompt_tokens = chunk.x_groq.usage.prompt_tokens
completion_tokens = chunk.x_groq.usage.completion_tokens
total_tokens = chunk.x_groq.usage.total_tokens
else:
# Non-streaming finished
ans: str = response.choices[0].message.content
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
total_tokens = response.usage.total_tokens
if response is not None:
if isinstance(response, Stream):
# Streaming response
if chunk.choices[0].finish_reason == "tool_calls":
groq_finish = "tool_calls"
tool_calls = streaming_tool_calls
else:
groq_finish = "stop"
tool_calls = None
response_content = ans
response_id = chunk.id
else:
# Non-streaming response
# If we have tool calls as the response, populate completed tool calls for our return OAI response
if response.choices[0].finish_reason == "tool_calls":
groq_finish = "tool_calls"
tool_calls = []
for tool_call in response.choices[0].message.tool_calls:
tool_calls.append(
ChatCompletionMessageToolCall(
id=tool_call.id,
function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
type="function",
)
)
else:
groq_finish = "stop"
tool_calls = None
response_content = response.choices[0].message.content
response_id = response.id
else:
raise RuntimeError("Failed to get response from Groq after retrying 5 times.")
# 3. convert output
message = ChatCompletionMessage(
role="assistant",
content=response_content,
function_call=None,
tool_calls=tool_calls,
)
choices = [Choice(finish_reason=groq_finish, index=0, message=message)]
response_oai = ChatCompletion(
id=response_id,
model=groq_params["model"],
created=int(time.time()),
object="chat.completion",
choices=choices,
usage=CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
),
cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
)
return response_oai
def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert messages from OAI format to Groq's format.
We correct for any specific role orders and types.
"""
groq_messages = copy.deepcopy(messages)
# If we have a message with role='tool', which occurs when a function is executed, change it to 'user'
"""
for msg in together_messages:
if "role" in msg and msg["role"] == "tool":
msg["role"] = "user"
"""
# Remove the name field
for message in groq_messages:
if "name" in message:
message.pop("name", None)
return groq_messages
def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
"""Calculate the cost of the completion using the Groq pricing."""
total = 0.0
if model in GROQ_PRICING_1K:
input_cost_per_k, output_cost_per_k = GROQ_PRICING_1K[model]
input_cost = (input_tokens / 1000) * input_cost_per_k
output_cost = (output_tokens / 1000) * output_cost_per_k
total = input_cost + output_cost
else:
warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
return total

View File

@ -15,6 +15,7 @@ if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@ -110,7 +111,7 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig
def log_new_client(
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:

View File

@ -91,6 +91,7 @@ extra_require = {
"long-context": ["llmlingua<0.3"],
"anthropic": ["anthropic>=0.23.1"],
"mistral": ["mistralai>=0.2.0"],
"groq": ["groq>=0.9.0"],
}
setuptools.setup(

249
test/oai/test_groq.py Normal file
View File

@ -0,0 +1,249 @@
from unittest.mock import MagicMock, patch
import pytest
try:
from autogen.oai.groq import GroqClient, calculate_groq_cost
skip = False
except ImportError:
GroqClient = object
InternalServerError = object
skip = True
# Fixtures for mock data
@pytest.fixture
def mock_response():
class MockResponse:
def __init__(self, text, choices, usage, cost, model):
self.text = text
self.choices = choices
self.usage = usage
self.cost = cost
self.model = model
return MockResponse
@pytest.fixture
def groq_client():
return GroqClient(api_key="fake_api_key")
skip_reason = "Groq dependency is not installed"
# Test initialization and configuration
@pytest.mark.skipif(skip, reason=skip_reason)
def test_initialization():
# Missing any api_key
with pytest.raises(AssertionError) as assertinfo:
GroqClient() # Should raise an AssertionError due to missing api_key
assert "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable." in str(
assertinfo.value
)
# Creation works
GroqClient(api_key="fake_api_key") # Should create okay now.
# Test standard initialization
@pytest.mark.skipif(skip, reason=skip_reason)
def test_valid_initialization(groq_client):
assert groq_client.api_key == "fake_api_key", "Config api_key should be correctly set"
# Test parameters
@pytest.mark.skipif(skip, reason=skip_reason)
def test_parsing_params(groq_client):
# All parameters
params = {
"model": "llama3-8b-8192",
"frequency_penalty": 1.5,
"presence_penalty": 1.5,
"max_tokens": 1000,
"seed": 42,
"stream": False,
"temperature": 1,
"top_p": 0.8,
}
expected_params = {
"model": "llama3-8b-8192",
"frequency_penalty": 1.5,
"presence_penalty": 1.5,
"max_tokens": 1000,
"seed": 42,
"stream": False,
"temperature": 1,
"top_p": 0.8,
}
result = groq_client.parse_params(params)
assert result == expected_params
# Only model, others set as defaults
params = {
"model": "llama3-8b-8192",
}
expected_params = {
"model": "llama3-8b-8192",
"frequency_penalty": None,
"presence_penalty": None,
"max_tokens": None,
"seed": None,
"stream": False,
"temperature": 1,
"top_p": None,
}
result = groq_client.parse_params(params)
assert result == expected_params
# Incorrect types, defaults should be set, will show warnings but not trigger assertions
params = {
"model": "llama3-8b-8192",
"frequency_penalty": "1.5",
"presence_penalty": "1.5",
"max_tokens": "1000",
"seed": "42",
"stream": "False",
"temperature": "1",
"top_p": "0.8",
}
result = groq_client.parse_params(params)
assert result == expected_params
# Values outside bounds, should warn and set to defaults
params = {
"model": "llama3-8b-8192",
"frequency_penalty": 5000,
"presence_penalty": -500,
"temperature": 3,
}
result = groq_client.parse_params(params)
assert result == expected_params
# No model
params = {
"frequency_penalty": 1,
}
with pytest.raises(AssertionError) as assertinfo:
result = groq_client.parse_params(params)
assert "Please specify the 'model' in your config list entry to nominate the Groq model to use." in str(
assertinfo.value
)
# Test cost calculation
@pytest.mark.skipif(skip, reason=skip_reason)
def test_cost_calculation(mock_response):
response = mock_response(
text="Example response",
choices=[{"message": "Test message 1"}],
usage={"prompt_tokens": 500, "completion_tokens": 300, "total_tokens": 800},
cost=None,
model="llama3-70b-8192",
)
assert (
calculate_groq_cost(response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model)
== 0.000532
), "Cost for this should be $0.000532"
# Test text generation
@pytest.mark.skipif(skip, reason=skip_reason)
@patch("autogen.oai.groq.GroqClient.create")
def test_create_response(mock_chat, groq_client):
# Mock GroqClient.chat response
mock_groq_response = MagicMock()
mock_groq_response.choices = [
MagicMock(finish_reason="stop", message=MagicMock(content="Example Groq response", tool_calls=None))
]
mock_groq_response.id = "mock_groq_response_id"
mock_groq_response.model = "llama3-70b-8192"
mock_groq_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20) # Example token usage
mock_chat.return_value = mock_groq_response
# Test parameters
params = {
"messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}],
"model": "llama3-70b-8192",
}
# Call the create method
response = groq_client.create(params)
# Assertions to check if response is structured as expected
assert (
response.choices[0].message.content == "Example Groq response"
), "Response content should match expected output"
assert response.id == "mock_groq_response_id", "Response ID should match the mocked response ID"
assert response.model == "llama3-70b-8192", "Response model should match the mocked response model"
assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage"
assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage"
# Test functions/tools
@pytest.mark.skipif(skip, reason=skip_reason)
@patch("autogen.oai.groq.GroqClient.create")
def test_create_response_with_tool_call(mock_chat, groq_client):
# Mock `groq_response = client.chat(**groq_params)`
mock_function = MagicMock(name="currency_calculator")
mock_function.name = "currency_calculator"
mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}'
mock_function_2 = MagicMock(name="get_weather")
mock_function_2.name = "get_weather"
mock_function_2.arguments = '{"location": "Chicago"}'
mock_chat.return_value = MagicMock(
choices=[
MagicMock(
finish_reason="tool_calls",
message=MagicMock(
content="Sample text about the functions",
tool_calls=[
MagicMock(id="gdRdrvnHh", function=mock_function),
MagicMock(id="abRdrvnHh", function=mock_function_2),
],
),
)
],
id="mock_groq_response_id",
model="llama3-70b-8192",
usage=MagicMock(prompt_tokens=10, completion_tokens=20),
)
# Construct parameters
converted_functions = [
{
"type": "function",
"function": {
"description": "Currency exchange calculator.",
"name": "currency_calculator",
"parameters": {
"type": "object",
"properties": {
"base_amount": {"type": "number", "description": "Amount of currency in base_currency"},
},
"required": ["base_amount"],
},
},
}
]
groq_messages = [
{"role": "user", "content": "How much is 123.45 EUR in USD?"},
{"role": "assistant", "content": "World"},
]
# Call the create method
response = groq_client.create({"messages": groq_messages, "tools": converted_functions, "model": "llama3-70b-8192"})
# Assertions to check if the functions and content are included in the response
assert response.choices[0].message.content == "Sample text about the functions"
assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator"
assert response.choices[0].message.tool_calls[1].function.name == "get_weather"