| 
									
										
										
										
											2023-04-06 14:47:44 +02:00
										 |  |  | from unittest.mock import patch, Mock | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-03 11:49:49 +02:00
										 |  |  | import pytest | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from haystack.nodes.prompt.prompt_model import PromptModel | 
					
						
							| 
									
										
										
										
											2023-04-26 12:10:02 +01:00
										 |  |  | from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer, HFLocalInvocationLayer | 
					
						
							| 
									
										
										
										
											2023-04-03 11:49:49 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-06 14:47:44 +02:00
										 |  |  | from .conftest import create_mock_layer_that_supports | 
					
						
							| 
									
										
										
										
											2023-04-03 11:49:49 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-06 14:47:44 +02:00
										 |  |  | @pytest.mark.unit | 
					
						
							|  |  |  | def test_constructor_with_default_model(): | 
					
						
							|  |  |  |     mock_layer = create_mock_layer_that_supports("google/flan-t5-base") | 
					
						
							|  |  |  |     another_layer = create_mock_layer_that_supports("another-model") | 
					
						
							| 
									
										
										
										
											2023-04-03 11:49:49 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-06 14:47:44 +02:00
										 |  |  |     with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]): | 
					
						
							|  |  |  |         model = PromptModel() | 
					
						
							|  |  |  |         mock_layer.assert_called_once() | 
					
						
							|  |  |  |         another_layer.assert_not_called() | 
					
						
							|  |  |  |         model.model_invocation_layer.model_name_or_path = "google/flan-t5-base" | 
					
						
							| 
									
										
										
										
											2023-04-03 11:49:49 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-06 14:47:44 +02:00
										 |  |  | @pytest.mark.unit | 
					
						
							|  |  |  | def test_construtor_with_custom_model(): | 
					
						
							|  |  |  |     mock_layer = create_mock_layer_that_supports("some-model") | 
					
						
							|  |  |  |     another_layer = create_mock_layer_that_supports("another-model") | 
					
						
							| 
									
										
										
										
											2023-04-03 11:49:49 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-06 14:47:44 +02:00
										 |  |  |     with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]): | 
					
						
							|  |  |  |         model = PromptModel("another-model") | 
					
						
							|  |  |  |         mock_layer.assert_not_called() | 
					
						
							|  |  |  |         another_layer.assert_called_once() | 
					
						
							|  |  |  |         model.model_invocation_layer.model_name_or_path = "another-model" | 
					
						
							| 
									
										
										
										
											2023-04-03 11:49:49 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-06 14:47:44 +02:00
										 |  |  | @pytest.mark.unit | 
					
						
							|  |  |  | def test_constructor_with_no_supported_model(): | 
					
						
							|  |  |  |     with pytest.raises(ValueError, match="Model some-random-model is not supported"): | 
					
						
							|  |  |  |         PromptModel("some-random-model") |