# SPDX-FileCopyrightText: 2022-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0
"""Unit and integration tests for ``OpenAIChatGenerator``."""

import json
import logging
import os
from datetime import datetime
from typing import Iterator
from unittest.mock import MagicMock, patch

import pytest
from openai import OpenAIError
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
from openai.types.chat import chat_completion_chunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message_tool_call import Function

from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, Tool, ToolCall, ChatRole, TextContent
from haystack.dataclasses import StreamingChunk
from haystack.utils.auth import Secret


@pytest.fixture
def chat_messages():
    """Return a minimal conversation: one system message followed by one user question."""
    return [
        ChatMessage.from_system("You are a helpful assistant"),
        ChatMessage.from_user("What's the capital of France"),
    ]


@pytest.fixture
def mock_chat_completion_chunk_with_tools(openai_mock_stream):
    """
    Mock the OpenAI API completion chunk response and reuse it for tests

    Patches ``Completions.create`` so that it returns a stream built by the ``openai_mock_stream``
    fixture (defined outside this module) around a single chunk carrying one ``weather`` tool call.
    Yields the patched ``create`` mock; the patch is active for the duration of the test.
    """
    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
        completion = ChatCompletionChunk(
            id="foo",
            model="gpt-4",
            object="chat.completion.chunk",
            choices=[
                chat_completion_chunk.Choice(
                    finish_reason="tool_calls",
                    logprobs=None,
                    index=0,
                    delta=chat_completion_chunk.ChoiceDelta(
                        role="assistant",
                        tool_calls=[
                            chat_completion_chunk.ChoiceDeltaToolCall(
                                index=0,
                                id="123",
                                type="function",
                                function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
                                    name="weather", arguments='{"city": "Paris"}'
                                ),
                            )
                        ],
                    ),
                )
            ],
            created=int(datetime.now().timestamp()),
            usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
        )
        mock_chat_completion_create.return_value = openai_mock_stream(
            completion, cast_to=None, response=None, client=None
        )
        yield mock_chat_completion_create


@pytest.fixture
def tools():
    """Return a one-element list holding a ``weather`` tool that requires a string ``city`` parameter."""
    tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
    tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters=tool_parameters,
        function=lambda x: x,
    )
    return [tool]


class TestOpenAIChatGenerator:
    """Tests covering construction, (de)serialization and ``run`` behaviour of ``OpenAIChatGenerator``."""

    def test_init_default(self, monkeypatch):
        """With only ``OPENAI_API_KEY`` set, the component exposes the expected default attributes."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        component = OpenAIChatGenerator()
        assert component.client.api_key == "test-api-key"
        assert component.model == "gpt-4o-mini"
        assert component.streaming_callback is None
        assert not component.generation_kwargs
        assert component.client.timeout == 30
        assert component.client.max_retries == 5
        assert component.tools is None
        assert not component.tools_strict

    def test_init_fail_wo_api_key(self, monkeypatch):
        """Construction raises ``ValueError`` when ``OPENAI_API_KEY`` is not set."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        with pytest.raises(ValueError):
            OpenAIChatGenerator()

    def test_init_fail_with_duplicate_tool_names(self, monkeypatch, tools):
        """Passing the same tool twice raises ``ValueError``."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")

        duplicate_tools = [tools[0], tools[0]]
        with pytest.raises(ValueError):
            OpenAIChatGenerator(tools=duplicate_tools)

    def test_init_with_parameters(self, monkeypatch):
        """Explicit ``timeout``/``max_retries`` arguments win over the corresponding env vars."""
        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=lambda x: x)

        # Env vars are set on purpose: the asserts below check that the explicit arguments are used instead.
        monkeypatch.setenv("OPENAI_TIMEOUT", "100")
        monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
        component = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"),
            model="gpt-4o-mini",
            streaming_callback=print_streaming_chunk,
            api_base_url="test-base-url",
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
            timeout=40.0,
            max_retries=1,
            tools=[tool],
            tools_strict=True,
        )
        assert component.client.api_key == "test-api-key"
        assert component.model == "gpt-4o-mini"
        assert component.streaming_callback is print_streaming_chunk
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
        assert component.client.timeout == 40.0
        assert component.client.max_retries == 1
        assert component.tools == [tool]
        assert component.tools_strict

    def test_init_with_parameters_and_env_vars(self, monkeypatch):
        """``OPENAI_TIMEOUT`` and ``OPENAI_MAX_RETRIES`` are used when the arguments are not passed."""
        monkeypatch.setenv("OPENAI_TIMEOUT", "100")
        monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
        component = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"),
            model="gpt-4o-mini",
            streaming_callback=print_streaming_chunk,
            api_base_url="test-base-url",
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        assert component.client.api_key == "test-api-key"
        assert component.model == "gpt-4o-mini"
        assert component.streaming_callback is print_streaming_chunk
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
        assert component.client.timeout == 100.0
        assert component.client.max_retries == 10

    def test_to_dict_default(self, monkeypatch):
        """``to_dict`` on a default component yields the expected serialized form."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        component = OpenAIChatGenerator()
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
            "init_parameters": {
                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                "model": "gpt-4o-mini",
                "organization": None,
                "streaming_callback": None,
                "api_base_url": None,
                "generation_kwargs": {},
                "tools": None,
                "tools_strict": False,
                "max_retries": None,
                "timeout": None,
            },
        }

    def test_to_dict_with_parameters(self, monkeypatch):
        """``to_dict`` serializes custom parameters, the streaming callback and tools."""
        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)

        monkeypatch.setenv("ENV_VAR", "test-api-key")
        component = OpenAIChatGenerator(
            api_key=Secret.from_env_var("ENV_VAR"),
            model="gpt-4o-mini",
            streaming_callback=print_streaming_chunk,
            api_base_url="test-base-url",
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
            tools=[tool],
            tools_strict=True,
            max_retries=10,
            timeout=100.0,
        )
        data = component.to_dict()

        assert data == {
            "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
            "init_parameters": {
                "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
                "model": "gpt-4o-mini",
                "organization": None,
                "api_base_url": "test-base-url",
                "max_retries": 10,
                "timeout": 100.0,
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "tools": [
                    {
                        "description": "description",
                        "function": "builtins.print",
                        "name": "name",
                        "parameters": {"x": {"type": "string"}},
                    }
                ],
                "tools_strict": True,
            },
        }

    def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
        """``to_dict`` serializes a lambda streaming callback as a dotted path string."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        component = OpenAIChatGenerator(
            model="gpt-4o-mini",
            streaming_callback=lambda x: x,
            api_base_url="test-base-url",
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
            "init_parameters": {
                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                "model": "gpt-4o-mini",
                "organization": None,
                "api_base_url": "test-base-url",
                "max_retries": None,
                "timeout": None,
                # A lambda's __name__ is always "<lambda>", so the expected path has to end with it.
                # NOTE(review): the "<lambda>" suffix was missing from this expectation (looks like it was
                # stripped as markup); confirm against the callable serializer's actual output.
                "streaming_callback": "chat.test_openai.<lambda>",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "tools": None,
                "tools_strict": False,
            },
        }

    def test_from_dict(self, monkeypatch):
        """``from_dict`` rebuilds a component, including callback, tools, timeout and retries."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        data = {
            "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
            "init_parameters": {
                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                "model": "gpt-4o-mini",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "max_retries": 10,
                "timeout": 100.0,
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "tools": [
                    {
                        "description": "description",
                        "function": "builtins.print",
                        "name": "name",
                        "parameters": {"x": {"type": "string"}},
                    }
                ],
                "tools_strict": True,
            },
        }
        component = OpenAIChatGenerator.from_dict(data)

        assert isinstance(component, OpenAIChatGenerator)
        assert component.model == "gpt-4o-mini"
        assert component.streaming_callback is print_streaming_chunk
        assert component.api_base_url == "test-base-url"
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
        assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert component.tools == [
            Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
        ]
        assert component.tools_strict
        assert component.client.timeout == 100.0
        assert component.client.max_retries == 10

    def test_from_dict_fail_wo_env_var(self, monkeypatch):
        """``from_dict`` raises ``ValueError`` when the referenced API key env var is unset."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        data = {
            "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
            "init_parameters": {
                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                "model": "gpt-4",
                "organization": None,
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
            },
        }
        with pytest.raises(ValueError):
            OpenAIChatGenerator.from_dict(data)

    def test_run(self, chat_messages, openai_mock_chat_completion):
        """``run`` returns a dict with exactly one ``ChatMessage`` reply."""
        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        response = component.run(chat_messages)

        # check that the component returns the correct ChatMessage response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        # all(...) is required here: asserting a bare non-empty list is always true.
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_params(self, chat_messages, openai_mock_chat_completion):
        """``generation_kwargs`` are forwarded as keyword arguments to the mocked API call."""
        component = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
        )
        response = component.run(chat_messages)

        # check that the component calls the OpenAI API with the correct parameters
        _, kwargs = openai_mock_chat_completion.call_args
        assert kwargs["max_tokens"] == 10
        assert kwargs["temperature"] == 0.5

        # check that the component returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_params_streaming(self, chat_messages, openai_mock_chat_completion_chunk):
        """A streaming callback given at construction time is invoked during ``run``."""
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        component = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
        )
        response = component.run(chat_messages)

        # check we called the streaming callback
        assert streaming_callback_called

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
        assert "Hello" in response["replies"][0].text  # see openai_mock_chat_completion_chunk

    def test_run_with_streaming_callback_in_run_method(self, chat_messages, openai_mock_chat_completion_chunk):
        """A streaming callback passed directly to ``run`` is invoked."""
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        response = component.run(chat_messages, streaming_callback=streaming_callback)

        # check we called the streaming callback
        assert streaming_callback_called

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
        assert "Hello" in response["replies"][0].text  # see openai_mock_chat_completion_chunk

    def test_check_abnormal_completions(self, caplog):
        """``_check_finish_reason`` logs one message per ``length`` / ``content_filter`` finish reason."""
        caplog.set_level(logging.INFO)
        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        # Even indices get "content_filter", odd indices get "length".
        messages = [
            ChatMessage.from_assistant(
                "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
            )
            for i in range(4)
        ]

        for m in messages:
            component._check_finish_reason(m.meta)

        # check truncation warning
        message_template = (
            "The completion for index {index} has been truncated before reaching a natural stopping point. "
            "Increase the max_tokens parameter to allow for longer completions."
        )

        for index in [1, 3]:
            assert caplog.records[index].message == message_template.format(index=index)

        # check content filter warning
        message_template = "The completion for index {index} has been truncated due to the content filter."
        for index in [0, 2]:
            assert caplog.records[index].message == message_template.format(index=index)

    def test_run_with_tools(self, tools):
        """A non-streaming completion containing a tool call is converted into a ``ToolCall`` reply."""
        with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
            completion = ChatCompletion(
                id="foo",
                model="gpt-4",
                object="chat.completion",
                choices=[
                    Choice(
                        finish_reason="tool_calls",
                        logprobs=None,
                        index=0,
                        message=ChatCompletionMessage(
                            role="assistant",
                            tool_calls=[
                                ChatCompletionMessageToolCall(
                                    id="123",
                                    type="function",
                                    function=Function(name="weather", arguments='{"city": "Paris"}'),
                                )
                            ],
                        ),
                    )
                ],
                created=int(datetime.now().timestamp()),
                usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
            )

            mock_chat_completion_create.return_value = completion

            component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools)
            response = component.run([ChatMessage.from_user("What's the weather like in Paris?")])

        assert len(response["replies"]) == 1
        message = response["replies"][0]

        assert not message.texts
        assert not message.text
        assert message.tool_calls
        tool_call = message.tool_call
        assert isinstance(tool_call, ToolCall)
        assert tool_call.tool_name == "weather"
        assert tool_call.arguments == {"city": "Paris"}
        assert message.meta["finish_reason"] == "tool_calls"

    def test_run_with_tools_streaming(self, mock_chat_completion_chunk_with_tools, tools):
        """A streamed tool-call chunk triggers the callback and yields a ``ToolCall`` reply."""
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        component = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
        )
        chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
        response = component.run(chat_messages, tools=tools)

        # check we called the streaming callback
        assert streaming_callback_called

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

        message = response["replies"][0]

        assert message.tool_calls
        tool_call = message.tool_call
        assert isinstance(tool_call, ToolCall)
        assert tool_call.tool_name == "weather"
        assert tool_call.arguments == {"city": "Paris"}
        assert message.meta["finish_reason"] == "tool_calls"

    def test_invalid_tool_call_json(self, tools, caplog):
        """Tool-call arguments that are not valid JSON produce no tool calls and a logged warning."""
        caplog.set_level(logging.WARNING)

        with patch("openai.resources.chat.completions.Completions.create") as mock_create:
            mock_create.return_value = ChatCompletion(
                id="test",
                model="gpt-4o-mini",
                object="chat.completion",
                choices=[
                    Choice(
                        finish_reason="tool_calls",
                        index=0,
                        message=ChatCompletionMessage(
                            role="assistant",
                            tool_calls=[
                                ChatCompletionMessageToolCall(
                                    id="1",
                                    type="function",
                                    # Deliberately malformed: the surrounding braces are missing.
                                    function=Function(name="weather", arguments='"invalid": "json"'),
                                )
                            ],
                        ),
                    )
                ],
                created=1234567890,
                usage={"prompt_tokens": 50, "completion_tokens": 30, "total_tokens": 80},
            )

            component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools)
            response = component.run([ChatMessage.from_user("What's the weather in Paris?")])

        assert len(response["replies"]) == 1
        message = response["replies"][0]
        assert len(message.tool_calls) == 0
        assert "OpenAI returned a malformed JSON string for tool call arguments" in caplog.text

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_live_run(self):
        """Live call: a simple question gets a single reply mentioning Paris."""
        chat_messages = [ChatMessage.from_user("What's the capital of France")]
        component = OpenAIChatGenerator(generation_kwargs={"n": 1})
        results = component.run(chat_messages)
        assert len(results["replies"]) == 1
        message: ChatMessage = results["replies"][0]
        assert "Paris" in message.text
        assert "gpt-4o" in message.meta["model"]
        assert message.meta["finish_reason"] == "stop"

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_live_run_wrong_model(self, chat_messages):
        """Live call: an unknown model name makes ``run`` raise ``OpenAIError``."""
        component = OpenAIChatGenerator(model="something-obviously-wrong")
        with pytest.raises(OpenAIError):
            component.run(chat_messages)

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_live_run_streaming(self):
        """Live call: streaming delivers more than one chunk and the concatenated text mentions Paris."""

        class Callback:
            """Callable that counts received chunks and concatenates their content."""

            def __init__(self):
                self.responses = ""
                self.counter = 0

            def __call__(self, chunk: StreamingChunk) -> None:
                self.counter += 1
                self.responses += chunk.content if chunk.content else ""

        callback = Callback()
        component = OpenAIChatGenerator(streaming_callback=callback)
        results = component.run([ChatMessage.from_user("What's the capital of France?")])

        assert len(results["replies"]) == 1
        message: ChatMessage = results["replies"][0]
        assert "Paris" in message.text

        assert "gpt-4o" in message.meta["model"]
        assert message.meta["finish_reason"] == "stop"

        assert callback.counter > 1
        assert "Paris" in callback.responses

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_live_run_with_tools(self, tools):
        """Live call: a weather question with tools available yields a ``weather`` tool call for Paris."""
        chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
        component = OpenAIChatGenerator(tools=tools)
        results = component.run(chat_messages)
        assert len(results["replies"]) == 1
        message = results["replies"][0]

        assert not message.texts
        assert not message.text
        assert message.tool_calls
        tool_call = message.tool_call
        assert isinstance(tool_call, ToolCall)
        assert tool_call.tool_name == "weather"
        assert tool_call.arguments == {"city": "Paris"}
        assert message.meta["finish_reason"] == "tool_calls"