| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | import os | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  | from openai import OpenAIError | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  | from haystack.utils.auth import Secret | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  | from haystack.components.generators.chat import OpenAIChatGenerator | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  | from haystack.components.generators.utils import print_streaming_chunk | 
					
						
							| 
									
										
										
										
											2023-11-24 14:48:43 +01:00
										 |  |  | from haystack.dataclasses import ChatMessage, StreamingChunk | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.fixture | 
					
						
							|  |  |  | def chat_messages(): | 
					
						
							|  |  |  |     return [ | 
					
						
							|  |  |  |         ChatMessage.from_system("You are a helpful assistant"), | 
					
						
							|  |  |  |         ChatMessage.from_user("What's the capital of France"), | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  | class TestOpenAIChatGenerator: | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |     def test_init_default(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") | 
					
						
							|  |  |  |         component = OpenAIChatGenerator() | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         assert component.client.api_key == "test-api-key" | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |         assert component.model == "gpt-3.5-turbo" | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         assert component.streaming_callback is None | 
					
						
							|  |  |  |         assert not component.generation_kwargs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_init_fail_wo_api_key(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.delenv("OPENAI_API_KEY", raising=False) | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         with pytest.raises(ValueError, match="None of the .* environment variables are set"): | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |             OpenAIChatGenerator() | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_init_with_parameters(self): | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |         component = OpenAIChatGenerator( | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |             api_key=Secret.from_token("test-api-key"), | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |             model="gpt-4", | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |             streaming_callback=print_streaming_chunk, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             api_base_url="test-base-url", | 
					
						
							| 
									
										
										
										
											2023-11-22 10:40:48 +01:00
										 |  |  |             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         assert component.client.api_key == "test-api-key" | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |         assert component.model == "gpt-4" | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |         assert component.streaming_callback is print_streaming_chunk | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |     def test_to_dict_default(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") | 
					
						
							|  |  |  |         component = OpenAIChatGenerator() | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         data = component.to_dict() | 
					
						
							|  |  |  |         assert data == { | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             "init_parameters": { | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |                 "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |                 "model": "gpt-3.5-turbo", | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |                 "organization": None, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |                 "streaming_callback": None, | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |                 "api_base_url": None, | 
					
						
							| 
									
										
										
										
											2023-11-22 10:40:48 +01:00
										 |  |  |                 "generation_kwargs": {}, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |     def test_to_dict_with_parameters(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.setenv("ENV_VAR", "test-api-key") | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |         component = OpenAIChatGenerator( | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |             api_key=Secret.from_env_var("ENV_VAR"), | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |             model="gpt-4", | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |             streaming_callback=print_streaming_chunk, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             api_base_url="test-base-url", | 
					
						
							| 
									
										
										
										
											2023-11-22 10:40:48 +01:00
										 |  |  |             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         ) | 
					
						
							|  |  |  |         data = component.to_dict() | 
					
						
							|  |  |  |         assert data == { | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             "init_parameters": { | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |                 "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}, | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |                 "model": "gpt-4", | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |                 "organization": None, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |                 "api_base_url": "test-base-url", | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", | 
					
						
							| 
									
										
										
										
											2023-11-22 10:40:48 +01:00
										 |  |  |                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |     def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |         component = OpenAIChatGenerator( | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |             model="gpt-4", | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             streaming_callback=lambda x: x, | 
					
						
							|  |  |  |             api_base_url="test-base-url", | 
					
						
							| 
									
										
										
										
											2023-11-22 10:40:48 +01:00
										 |  |  |             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         ) | 
					
						
							|  |  |  |         data = component.to_dict() | 
					
						
							|  |  |  |         assert data == { | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             "init_parameters": { | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |                 "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |                 "model": "gpt-4", | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |                 "organization": None, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |                 "api_base_url": "test-base-url", | 
					
						
							|  |  |  |                 "streaming_callback": "chat.test_openai.<lambda>", | 
					
						
							| 
									
										
										
										
											2023-11-22 10:40:48 +01:00
										 |  |  |                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |     def test_from_dict(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         data = { | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             "init_parameters": { | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |                 "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |                 "model": "gpt-4", | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |                 "api_base_url": "test-base-url", | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", | 
					
						
							| 
									
										
										
										
											2023-11-22 10:40:48 +01:00
										 |  |  |                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |         component = OpenAIChatGenerator.from_dict(data) | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |         assert component.model == "gpt-4" | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |         assert component.streaming_callback is print_streaming_chunk | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         assert component.api_base_url == "test-base-url" | 
					
						
							|  |  |  |         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         assert component.api_key == Secret.from_env_var("OPENAI_API_KEY") | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_from_dict_fail_wo_env_var(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.delenv("OPENAI_API_KEY", raising=False) | 
					
						
							|  |  |  |         data = { | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             "init_parameters": { | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |                 "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, | 
					
						
							| 
									
										
										
										
											2024-01-12 17:28:01 +05:30
										 |  |  |                 "model": "gpt-4", | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |                 "organization": None, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |                 "api_base_url": "test-base-url", | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", | 
					
						
							| 
									
										
										
										
											2023-11-22 10:40:48 +01:00
										 |  |  |                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         with pytest.raises(ValueError, match="None of the .* environment variables are set"): | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |             OpenAIChatGenerator.from_dict(data) | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_run(self, chat_messages, mock_chat_completion): | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         response = component.run(chat_messages) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # check that the component returns the correct ChatMessage response | 
					
						
							|  |  |  |         assert isinstance(response, dict) | 
					
						
							|  |  |  |         assert "replies" in response | 
					
						
							|  |  |  |         assert isinstance(response["replies"], list) | 
					
						
							|  |  |  |         assert len(response["replies"]) == 1 | 
					
						
							|  |  |  |         assert [isinstance(reply, ChatMessage) for reply in response["replies"]] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_run_with_params(self, chat_messages, mock_chat_completion): | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         component = OpenAIChatGenerator( | 
					
						
							|  |  |  |             api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         response = component.run(chat_messages) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # check that the component calls the OpenAI API with the correct parameters | 
					
						
							|  |  |  |         _, kwargs = mock_chat_completion.call_args | 
					
						
							|  |  |  |         assert kwargs["max_tokens"] == 10 | 
					
						
							|  |  |  |         assert kwargs["temperature"] == 0.5 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # check that the component returns the correct response | 
					
						
							|  |  |  |         assert isinstance(response, dict) | 
					
						
							|  |  |  |         assert "replies" in response | 
					
						
							|  |  |  |         assert isinstance(response["replies"], list) | 
					
						
							|  |  |  |         assert len(response["replies"]) == 1 | 
					
						
							|  |  |  |         assert [isinstance(reply, ChatMessage) for reply in response["replies"]] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |     def test_run_with_params_streaming(self, chat_messages, mock_chat_completion_chunk): | 
					
						
							|  |  |  |         streaming_callback_called = False | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         def streaming_callback(chunk: StreamingChunk) -> None: | 
					
						
							|  |  |  |             nonlocal streaming_callback_called | 
					
						
							|  |  |  |             streaming_callback_called = True | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         component = OpenAIChatGenerator( | 
					
						
							|  |  |  |             api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         response = component.run(chat_messages) | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         # check we called the streaming callback | 
					
						
							|  |  |  |         assert streaming_callback_called | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         # check that the component still returns the correct response | 
					
						
							|  |  |  |         assert isinstance(response, dict) | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         assert "replies" in response | 
					
						
							|  |  |  |         assert isinstance(response["replies"], list) | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         assert len(response["replies"]) == 1 | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         assert [isinstance(reply, ChatMessage) for reply in response["replies"]] | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         assert "Hello" in response["replies"][0].content  # see mock_chat_completion_chunk | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_check_abnormal_completions(self, caplog): | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         messages = [ | 
					
						
							|  |  |  |             ChatMessage.from_assistant( | 
					
						
							| 
									
										
										
										
											2023-12-21 17:09:58 +05:30
										 |  |  |                 "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i} | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             ) | 
					
						
							|  |  |  |             for i, _ in enumerate(range(4)) | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for m in messages: | 
					
						
							|  |  |  |             component._check_finish_reason(m) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # check truncation warning | 
					
						
							|  |  |  |         message_template = ( | 
					
						
							|  |  |  |             "The completion for index {index} has been truncated before reaching a natural stopping point. " | 
					
						
							|  |  |  |             "Increase the max_tokens parameter to allow for longer completions." | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for index in [1, 3]: | 
					
						
							|  |  |  |             assert caplog.records[index].message == message_template.format(index=index) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # check content filter warning | 
					
						
							|  |  |  |         message_template = "The completion for index {index} has been truncated due to the content filter." | 
					
						
							|  |  |  |         for index in [0, 2]: | 
					
						
							|  |  |  |             assert caplog.records[index].message == message_template.format(index=index) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.skipif( | 
					
						
							|  |  |  |         not os.environ.get("OPENAI_API_KEY", None), | 
					
						
							|  |  |  |         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_live_run(self): | 
					
						
							|  |  |  |         chat_messages = [ChatMessage.from_user("What's the capital of France")] | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         component = OpenAIChatGenerator(generation_kwargs={"n": 1}) | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         results = component.run(chat_messages) | 
					
						
							|  |  |  |         assert len(results["replies"]) == 1 | 
					
						
							|  |  |  |         message: ChatMessage = results["replies"][0] | 
					
						
							|  |  |  |         assert "Paris" in message.content | 
					
						
							| 
									
										
										
										
											2023-12-21 14:09:31 +01:00
										 |  |  |         assert "gpt-3.5" in message.meta["model"] | 
					
						
							|  |  |  |         assert message.meta["finish_reason"] == "stop" | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.skipif( | 
					
						
							|  |  |  |         not os.environ.get("OPENAI_API_KEY", None), | 
					
						
							|  |  |  |         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_live_run_wrong_model(self, chat_messages): | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         component = OpenAIChatGenerator(model="something-obviously-wrong") | 
					
						
							| 
									
										
										
										
											2023-12-21 16:21:24 +01:00
										 |  |  |         with pytest.raises(OpenAIError): | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |             component.run(chat_messages) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.skipif( | 
					
						
							|  |  |  |         not os.environ.get("OPENAI_API_KEY", None), | 
					
						
							|  |  |  |         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_live_run_streaming(self): | 
					
						
							|  |  |  |         class Callback: | 
					
						
							|  |  |  |             def __init__(self): | 
					
						
							|  |  |  |                 self.responses = "" | 
					
						
							|  |  |  |                 self.counter = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def __call__(self, chunk: StreamingChunk) -> None: | 
					
						
							|  |  |  |                 self.counter += 1 | 
					
						
							|  |  |  |                 self.responses += chunk.content if chunk.content else "" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         callback = Callback() | 
					
						
							| 
									
										
										
										
											2023-12-22 19:37:29 +01:00
										 |  |  |         component = OpenAIChatGenerator(streaming_callback=callback) | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  |         results = component.run([ChatMessage.from_user("What's the capital of France?")]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert len(results["replies"]) == 1 | 
					
						
							|  |  |  |         message: ChatMessage = results["replies"][0] | 
					
						
							|  |  |  |         assert "Paris" in message.content | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-21 14:09:31 +01:00
										 |  |  |         assert "gpt-3.5" in message.meta["model"] | 
					
						
							|  |  |  |         assert message.meta["finish_reason"] == "stop" | 
					
						
							| 
									
										
										
										
											2023-11-09 10:45:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |         assert callback.counter > 1 | 
					
						
							|  |  |  |         assert "Paris" in callback.responses |