Mirror of https://github.com/microsoft/autogen.git, synced 2025-11-05 04:09:51 +00:00.
Groq Client (#3003)

* Groq Client class: main class and setup, except tests
* Changed pricing to per-1K tokens; added tests
* Streaming support, including with tool calling
* Used Groq's built-in retries instead of a retry loop, thanks Gal-Gilor!
* Fixed a bug when using logging.

Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
This commit is contained in:
parent 55cc542bcf
commit 23c1dec206
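In practice, after this commit a config list entry with api_type "groq" is all that is needed to point an agent at Groq. A minimal sketch based on the docstring example in the new autogen/oai/groq.py below (assumes GROQ_API_KEY is set in the environment):

import os

import autogen

# "mixtral-8x7b-32768" is one of the models priced in GROQ_PRICING_1K below.
llm_config = {
    "config_list": [
        {
            "api_type": "groq",  # routes OpenAIWrapper to the new GroqClient
            "model": "mixtral-8x7b-32768",
            "api_key": os.environ.get("GROQ_API_KEY"),
        }
    ]
}

agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)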
.github/workflows/contrib-tests.yml (vendored): 40 lines changed

@@ -598,3 +598,43 @@ jobs:
         with:
           file: ./coverage.xml
           flags: unittests
+
+  GroqTest:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest, windows-2019]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        exclude:
+          - os: macos-latest
+            python-version: "3.9"
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          lfs: true
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install packages and dependencies for all tests
+        run: |
+          python -m pip install --upgrade pip wheel
+          pip install pytest-cov>=5
+      - name: Install packages and dependencies for Groq
+        run: |
+          pip install -e .[groq,test]
+      - name: Set AUTOGEN_USE_DOCKER based on OS
+        shell: bash
+        run: |
+          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+          fi
+      - name: Coverage
+        run: |
+          pytest test/oai/test_groq.py --skip-openai
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v3
+        with:
+          file: ./coverage.xml
+          flags: unittests
FileLogger changes:

@@ -19,6 +19,7 @@ if TYPE_CHECKING:
     from autogen import Agent, ConversableAgent, OpenAIWrapper
     from autogen.oai.anthropic import AnthropicClient
     from autogen.oai.gemini import GeminiClient
+    from autogen.oai.groq import GroqClient
     from autogen.oai.mistral import MistralAIClient
     from autogen.oai.together import TogetherClient

@@ -204,7 +205,7 @@ class FileLogger(BaseLogger):
     def log_new_client(
         self,
-        client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient,
+        client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient,
         wrapper: OpenAIWrapper,
         init_args: Dict[str, Any],
     ) -> None:
SqliteLogger changes:

@@ -20,6 +20,7 @@ if TYPE_CHECKING:
     from autogen import Agent, ConversableAgent, OpenAIWrapper
     from autogen.oai.anthropic import AnthropicClient
     from autogen.oai.gemini import GeminiClient
+    from autogen.oai.groq import GroqClient
     from autogen.oai.mistral import MistralAIClient
     from autogen.oai.together import TogetherClient

@@ -391,7 +392,7 @@ class SqliteLogger(BaseLogger):
     def log_new_client(
         self,
-        client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
+        client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
         wrapper: OpenAIWrapper,
         init_args: Dict[str, Any],
     ) -> None:
OpenAIWrapper (oai client) changes:

@@ -70,6 +70,13 @@ try:
 except ImportError as e:
     together_import_exception = e

+try:
+    from autogen.oai.groq import GroqClient
+
+    groq_import_exception: Optional[ImportError] = None
+except ImportError as e:
+    groq_import_exception = e
+
 logger = logging.getLogger(__name__)
 if not logger.handlers:
     # Add the console handler.

@@ -483,7 +490,13 @@ class OpenAIWrapper:
         elif api_type is not None and api_type.startswith("together"):
             if together_import_exception:
                 raise ImportError("Please install `together` to use the Together.AI API.")
-            self._clients.append(TogetherClient(**config))
+            client = TogetherClient(**openai_config)
+            self._clients.append(client)
+        elif api_type is not None and api_type.startswith("groq"):
+            if groq_import_exception:
+                raise ImportError("Please install `groq` to use the Groq API.")
+            client = GroqClient(**openai_config)
+            self._clients.append(client)
         else:
             client = OpenAI(**openai_config)
             self._clients.append(OpenAIClient(client))
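With the dispatch above in place, the wrapper can be exercised directly. A minimal sketch (hedged, not taken from this commit; any api_type beginning with "groq" takes the new branch, and a missing `groq` package raises the ImportError shown above):

import os

from autogen import OpenAIWrapper

wrapper = OpenAIWrapper(
    config_list=[
        {
            "api_type": "groq",
            "model": "llama3-8b-8192",
            "api_key": os.environ.get("GROQ_API_KEY"),
        }
    ]
)
response = wrapper.create(messages=[{"role": "user", "content": "Hello"}])
print(response.choices[0].message.content)  # GroqClient returns an OpenAI-style ChatCompletion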
autogen/oai/groq.py (new file, 289 lines)

"""Create an OpenAI-compatible client using Groq's API.

Example:
    llm_config={
        "config_list": [{
            "api_type": "groq",
            "model": "mixtral-8x7b-32768",
            "api_key": os.environ.get("GROQ_API_KEY")
            }
    ]}

    agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)

Install Groq's python library using: pip install --upgrade groq

Resources:
- https://console.groq.com/docs/quickstart
"""

from __future__ import annotations

import copy
import os
import time
import warnings
from typing import Any, Dict, List

from groq import Groq, Stream
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage

from autogen.oai.client_utils import should_hide_tools, validate_parameter

# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
GROQ_PRICING_1K = {
    "llama3-70b-8192": (0.00059, 0.00079),
    "mixtral-8x7b-32768": (0.00024, 0.00024),
    "llama3-8b-8192": (0.00005, 0.00008),
    "gemma-7b-it": (0.00007, 0.00007),
}


class GroqClient:
    """Client for Groq's API."""

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Groq (or environment variable GROQ_API_KEY needs to be set)
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key", None)
        if not self.api_key:
            self.api_key = os.getenv("GROQ_API_KEY")

        assert (
            self.api_key
        ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."

    def message_retrieval(self, response) -> List:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        return response.cost

    @staticmethod
    def get_usage(response) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ...  # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        groq_params = {}

        # Check that we have what we need to use Groq's API
        # We won't enforce the available models as they are likely to change
        groq_params["model"] = params.get("model", None)
        assert groq_params[
            "model"
        ], "Please specify the 'model' in your config list entry to nominate the Groq model to use."

        # Validate allowed Groq parameters
        # https://console.groq.com/docs/api-reference#chat
        groq_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        groq_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
        groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)

        # Groq parameters not supported by their models yet, ignoring
        # logit_bias, logprobs, top_logprobs

        # Groq parameters we are ignoring:
        # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
        # parallel_tool_calls (defaults to True), stop
        # function_call (deprecated), functions (deprecated)
        # tool_choice (none if no tools, auto if there are tools)

        return groq_params

    def create(self, params: Dict) -> ChatCompletion:

        messages = params.get("messages", [])

        # Convert AutoGen messages to Groq messages
        groq_messages = oai_messages_to_groq_messages(messages)

        # Parse parameters to the Groq API's parameters
        groq_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(groq_messages, params["tools"], hide_tools):
                groq_params["tools"] = params["tools"]

        groq_params["messages"] = groq_messages

        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
        client = Groq(api_key=self.api_key, max_retries=5)

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        # Streaming tool call recommendations
        streaming_tool_calls = []

        ans = None
        try:
            response = client.chat.completions.create(**groq_params)
        except Exception as e:
            raise RuntimeError(f"Groq exception occurred: {e}")
        else:

            if groq_params["stream"]:
                # Read in the chunks as they stream, taking in tool_calls which may be across
                # multiple chunks if more than one suggested
                ans = ""
                for chunk in response:
                    ans = ans + (chunk.choices[0].delta.content or "")

                    if chunk.choices[0].delta.tool_calls:
                        # We have a tool call recommendation
                        for tool_call in chunk.choices[0].delta.tool_calls:
                            streaming_tool_calls.append(
                                ChatCompletionMessageToolCall(
                                    id=tool_call.id,
                                    function={
                                        "name": tool_call.function.name,
                                        "arguments": tool_call.function.arguments,
                                    },
                                    type="function",
                                )
                            )

                    if chunk.choices[0].finish_reason:
                        prompt_tokens = chunk.x_groq.usage.prompt_tokens
                        completion_tokens = chunk.x_groq.usage.completion_tokens
                        total_tokens = chunk.x_groq.usage.total_tokens
            else:
                # Non-streaming finished
                ans: str = response.choices[0].message.content

                prompt_tokens = response.usage.prompt_tokens
                completion_tokens = response.usage.completion_tokens
                total_tokens = response.usage.total_tokens

        if response is not None:

            if isinstance(response, Stream):
                # Streaming response
                if chunk.choices[0].finish_reason == "tool_calls":
                    groq_finish = "tool_calls"
                    tool_calls = streaming_tool_calls
                else:
                    groq_finish = "stop"
                    tool_calls = None

                response_content = ans
                response_id = chunk.id
            else:
                # Non-streaming response
                # If we have tool calls as the response, populate completed tool calls for our return OAI response
                if response.choices[0].finish_reason == "tool_calls":
                    groq_finish = "tool_calls"
                    tool_calls = []
                    for tool_call in response.choices[0].message.tool_calls:
                        tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call.id,
                                function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                                type="function",
                            )
                        )
                else:
                    groq_finish = "stop"
                    tool_calls = None

                response_content = response.choices[0].message.content
                response_id = response.id
        else:
            raise RuntimeError("Failed to get response from Groq after retrying 5 times.")

        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=groq_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response_id,
            model=groq_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
        )

        return response_oai


def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Groq's format.

    We correct for any specific role orders and types.
    """

    groq_messages = copy.deepcopy(messages)

    # If we have a message with role='tool', which occurs when a function is executed, change it to 'user'
    # (disabled for now):
    """
    for msg in groq_messages:
        if "role" in msg and msg["role"] == "tool":
            msg["role"] = "user"
    """

    # Remove the name field
    for message in groq_messages:
        if "name" in message:
            message.pop("name", None)

    return groq_messages


def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Groq pricing."""
    total = 0.0

    if model in GROQ_PRICING_1K:
        input_cost_per_k, output_cost_per_k = GROQ_PRICING_1K[model]
        input_cost = (input_tokens / 1000) * input_cost_per_k
        output_cost = (output_tokens / 1000) * output_cost_per_k
        total = input_cost + output_cost
    else:
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)

    return total
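The per-1K pricing table makes the cost arithmetic easy to check by hand. A small worked example using the same numbers as the assertion in test/oai/test_groq.py below:

from autogen.oai.groq import calculate_groq_cost

# llama3-70b-8192 in GROQ_PRICING_1K: $0.00059 per 1K input tokens, $0.00079 per 1K output tokens.
# (500 / 1000) * 0.00059 + (300 / 1000) * 0.00079 = 0.000295 + 0.000237 = 0.000532
cost = calculate_groq_cost(input_tokens=500, output_tokens=300, model="llama3-70b-8192")
assert round(cost, 6) == 0.000532

# An unpriced model warns and returns 0.0 rather than raising.
assert calculate_groq_cost(500, 300, "some-future-model") == 0.0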
Runtime logging changes:

@@ -15,6 +15,7 @@ if TYPE_CHECKING:
     from autogen import Agent, ConversableAgent, OpenAIWrapper
     from autogen.oai.anthropic import AnthropicClient
     from autogen.oai.gemini import GeminiClient
+    from autogen.oai.groq import GroqClient
     from autogen.oai.mistral import MistralAIClient
     from autogen.oai.together import TogetherClient

@@ -110,7 +111,7 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig
 def log_new_client(
-    client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
+    client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
     wrapper: OpenAIWrapper,
     init_args: Dict[str, Any],
 ) -> None:
setup.py: 1 line changed

@@ -91,6 +91,7 @@ extra_require = {
     "long-context": ["llmlingua<0.3"],
     "anthropic": ["anthropic>=0.23.1"],
     "mistral": ["mistralai>=0.2.0"],
+    "groq": ["groq>=0.9.0"],
 }

 setuptools.setup(
test/oai/test_groq.py (new file, 249 lines)

from unittest.mock import MagicMock, patch

import pytest

try:
    from autogen.oai.groq import GroqClient, calculate_groq_cost

    skip = False
except ImportError:
    GroqClient = object
    InternalServerError = object
    skip = True


# Fixtures for mock data
@pytest.fixture
def mock_response():
    class MockResponse:
        def __init__(self, text, choices, usage, cost, model):
            self.text = text
            self.choices = choices
            self.usage = usage
            self.cost = cost
            self.model = model

    return MockResponse


@pytest.fixture
def groq_client():
    return GroqClient(api_key="fake_api_key")


skip_reason = "Groq dependency is not installed"


# Test initialization and configuration
@pytest.mark.skipif(skip, reason=skip_reason)
def test_initialization():

    # Missing any api_key
    with pytest.raises(AssertionError) as assertinfo:
        GroqClient()  # Should raise an AssertionError due to missing api_key

    assert "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable." in str(
        assertinfo.value
    )

    # Creation works
    GroqClient(api_key="fake_api_key")  # Should create okay now.


# Test standard initialization
@pytest.mark.skipif(skip, reason=skip_reason)
def test_valid_initialization(groq_client):
    assert groq_client.api_key == "fake_api_key", "Config api_key should be correctly set"


# Test parameters
@pytest.mark.skipif(skip, reason=skip_reason)
def test_parsing_params(groq_client):
    # All parameters
    params = {
        "model": "llama3-8b-8192",
        "frequency_penalty": 1.5,
        "presence_penalty": 1.5,
        "max_tokens": 1000,
        "seed": 42,
        "stream": False,
        "temperature": 1,
        "top_p": 0.8,
    }
    expected_params = {
        "model": "llama3-8b-8192",
        "frequency_penalty": 1.5,
        "presence_penalty": 1.5,
        "max_tokens": 1000,
        "seed": 42,
        "stream": False,
        "temperature": 1,
        "top_p": 0.8,
    }
    result = groq_client.parse_params(params)
    assert result == expected_params

    # Only model, others set as defaults
    params = {
        "model": "llama3-8b-8192",
    }
    expected_params = {
        "model": "llama3-8b-8192",
        "frequency_penalty": None,
        "presence_penalty": None,
        "max_tokens": None,
        "seed": None,
        "stream": False,
        "temperature": 1,
        "top_p": None,
    }
    result = groq_client.parse_params(params)
    assert result == expected_params

    # Incorrect types, defaults should be set, will show warnings but not trigger assertions
    params = {
        "model": "llama3-8b-8192",
        "frequency_penalty": "1.5",
        "presence_penalty": "1.5",
        "max_tokens": "1000",
        "seed": "42",
        "stream": "False",
        "temperature": "1",
        "top_p": "0.8",
    }
    result = groq_client.parse_params(params)
    assert result == expected_params

    # Values outside bounds, should warn and set to defaults
    params = {
        "model": "llama3-8b-8192",
        "frequency_penalty": 5000,
        "presence_penalty": -500,
        "temperature": 3,
    }
    result = groq_client.parse_params(params)
    assert result == expected_params

    # No model
    params = {
        "frequency_penalty": 1,
    }

    with pytest.raises(AssertionError) as assertinfo:
        result = groq_client.parse_params(params)

    assert "Please specify the 'model' in your config list entry to nominate the Groq model to use." in str(
        assertinfo.value
    )


# Test cost calculation
@pytest.mark.skipif(skip, reason=skip_reason)
def test_cost_calculation(mock_response):
    response = mock_response(
        text="Example response",
        choices=[{"message": "Test message 1"}],
        usage={"prompt_tokens": 500, "completion_tokens": 300, "total_tokens": 800},
        cost=None,
        model="llama3-70b-8192",
    )
    assert (
        calculate_groq_cost(response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model)
        == 0.000532
    ), "Cost for this should be $0.000532"


# Test text generation
@pytest.mark.skipif(skip, reason=skip_reason)
@patch("autogen.oai.groq.GroqClient.create")
def test_create_response(mock_chat, groq_client):
    # Mock GroqClient.chat response
    mock_groq_response = MagicMock()
    mock_groq_response.choices = [
        MagicMock(finish_reason="stop", message=MagicMock(content="Example Groq response", tool_calls=None))
    ]
    mock_groq_response.id = "mock_groq_response_id"
    mock_groq_response.model = "llama3-70b-8192"
    mock_groq_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)  # Example token usage

    mock_chat.return_value = mock_groq_response

    # Test parameters
    params = {
        "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}],
        "model": "llama3-70b-8192",
    }

    # Call the create method
    response = groq_client.create(params)

    # Assertions to check if response is structured as expected
    assert (
        response.choices[0].message.content == "Example Groq response"
    ), "Response content should match expected output"
    assert response.id == "mock_groq_response_id", "Response ID should match the mocked response ID"
    assert response.model == "llama3-70b-8192", "Response model should match the mocked response model"
    assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage"
    assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage"


# Test functions/tools
@pytest.mark.skipif(skip, reason=skip_reason)
@patch("autogen.oai.groq.GroqClient.create")
def test_create_response_with_tool_call(mock_chat, groq_client):
    # Mock `groq_response = client.chat(**groq_params)`
    mock_function = MagicMock(name="currency_calculator")
    mock_function.name = "currency_calculator"
    mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}'

    mock_function_2 = MagicMock(name="get_weather")
    mock_function_2.name = "get_weather"
    mock_function_2.arguments = '{"location": "Chicago"}'

    mock_chat.return_value = MagicMock(
        choices=[
            MagicMock(
                finish_reason="tool_calls",
                message=MagicMock(
                    content="Sample text about the functions",
                    tool_calls=[
                        MagicMock(id="gdRdrvnHh", function=mock_function),
                        MagicMock(id="abRdrvnHh", function=mock_function_2),
                    ],
                ),
            )
        ],
        id="mock_groq_response_id",
        model="llama3-70b-8192",
        usage=MagicMock(prompt_tokens=10, completion_tokens=20),
    )

    # Construct parameters
    converted_functions = [
        {
            "type": "function",
            "function": {
                "description": "Currency exchange calculator.",
                "name": "currency_calculator",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "base_amount": {"type": "number", "description": "Amount of currency in base_currency"},
                    },
                    "required": ["base_amount"],
                },
            },
        }
    ]
    groq_messages = [
        {"role": "user", "content": "How much is 123.45 EUR in USD?"},
        {"role": "assistant", "content": "World"},
    ]

    # Call the create method
    response = groq_client.create({"messages": groq_messages, "tools": converted_functions, "model": "llama3-70b-8192"})

    # Assertions to check if the functions and content are included in the response
    assert response.choices[0].message.content == "Sample text about the functions"
    assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator"
    assert response.choices[0].message.tool_calls[1].function.name == "get_weather"
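Beyond what the tests above cover, GroqClient.create also honors the hide_tools parameter (validated in create with the values "if_all_run", "if_any_run", and "never"). A hedged sketch of setting it in a config entry; whether it is better set per-entry or per-call depends on how the surrounding config is assembled:

import os

# "if_any_run" stops re-sending tool definitions once any suggested tool has run;
# "never" (the default) always sends them.
config_list = [
    {
        "api_type": "groq",
        "model": "llama3-70b-8192",
        "api_key": os.environ.get("GROQ_API_KEY"),
        "hide_tools": "if_any_run",
    }
]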