| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  | import os | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | from openai import OpenAIError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from haystack.components.generators.chat import AzureOpenAIChatGenerator | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  | from haystack.components.generators.utils import print_streaming_chunk | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  | from haystack.dataclasses import ChatMessage | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  | from haystack.utils.auth import Secret | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TestOpenAIChatGenerator: | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |     def test_init_default(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") | 
					
						
							|  |  |  |         component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |         assert component.client.api_key == "test-api-key" | 
					
						
							|  |  |  |         assert component.azure_deployment == "gpt-35-turbo" | 
					
						
							|  |  |  |         assert component.streaming_callback is None | 
					
						
							|  |  |  |         assert not component.generation_kwargs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_init_fail_wo_api_key(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |         monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False) | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |         with pytest.raises(OpenAIError): | 
					
						
							|  |  |  |             AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_init_with_parameters(self): | 
					
						
							|  |  |  |         component = AzureOpenAIChatGenerator( | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |             api_key=Secret.from_token("test-api-key"), | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |             azure_endpoint="some-non-existing-endpoint", | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |             streaming_callback=print_streaming_chunk, | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         assert component.client.api_key == "test-api-key" | 
					
						
							|  |  |  |         assert component.azure_deployment == "gpt-35-turbo" | 
					
						
							| 
									
										
										
										
											2024-01-26 16:00:02 +01:00
										 |  |  |         assert component.streaming_callback is print_streaming_chunk | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |     def test_to_dict_default(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") | 
					
						
							|  |  |  |         component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |         data = component.to_dict() | 
					
						
							|  |  |  |         assert data == { | 
					
						
							|  |  |  |             "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator", | 
					
						
							|  |  |  |             "init_parameters": { | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |                 "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"}, | 
					
						
							|  |  |  |                 "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"}, | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |                 "api_version": "2023-05-15", | 
					
						
							|  |  |  |                 "azure_endpoint": "some-non-existing-endpoint", | 
					
						
							|  |  |  |                 "azure_deployment": "gpt-35-turbo", | 
					
						
							|  |  |  |                 "organization": None, | 
					
						
							|  |  |  |                 "streaming_callback": None, | 
					
						
							|  |  |  |                 "generation_kwargs": {}, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |     def test_to_dict_with_parameters(self, monkeypatch): | 
					
						
							|  |  |  |         monkeypatch.setenv("ENV_VAR", "test-api-key") | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |         component = AzureOpenAIChatGenerator( | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |             api_key=Secret.from_env_var("ENV_VAR", strict=False), | 
					
						
							|  |  |  |             azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False), | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |             azure_endpoint="some-non-existing-endpoint", | 
					
						
							|  |  |  |             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         data = component.to_dict() | 
					
						
							|  |  |  |         assert data == { | 
					
						
							|  |  |  |             "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator", | 
					
						
							|  |  |  |             "init_parameters": { | 
					
						
							| 
									
										
										
										
											2024-02-05 13:17:01 +01:00
										 |  |  |                 "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, | 
					
						
							|  |  |  |                 "azure_ad_token": {"env_vars": ["ENV_VAR1"], "strict": False, "type": "env_var"}, | 
					
						
							| 
									
										
										
										
											2024-01-05 15:48:28 +01:00
										 |  |  |                 "api_version": "2023-05-15", | 
					
						
							|  |  |  |                 "azure_endpoint": "some-non-existing-endpoint", | 
					
						
							|  |  |  |                 "azure_deployment": "gpt-35-turbo", | 
					
						
							|  |  |  |                 "organization": None, | 
					
						
							|  |  |  |                 "streaming_callback": None, | 
					
						
							|  |  |  |                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     @pytest.mark.skipif( | 
					
						
							|  |  |  |         not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), | 
					
						
							|  |  |  |         reason=( | 
					
						
							|  |  |  |             "Please export env variables called AZURE_OPENAI_API_KEY containing " | 
					
						
							|  |  |  |             "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing " | 
					
						
							|  |  |  |             "the Azure OpenAI endpoint URL to run this test." | 
					
						
							|  |  |  |         ), | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     def test_live_run(self): | 
					
						
							|  |  |  |         chat_messages = [ChatMessage.from_user("What's the capital of France")] | 
					
						
							|  |  |  |         component = AzureOpenAIChatGenerator(organization="HaystackCI") | 
					
						
							|  |  |  |         results = component.run(chat_messages) | 
					
						
							|  |  |  |         assert len(results["replies"]) == 1 | 
					
						
							|  |  |  |         message: ChatMessage = results["replies"][0] | 
					
						
							|  |  |  |         assert "Paris" in message.content | 
					
						
							|  |  |  |         assert "gpt-35-turbo" in message.meta["model"] | 
					
						
							|  |  |  |         assert message.meta["finish_reason"] == "stop" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # additional tests intentionally omitted as they are covered by test_openai.py |