| 
									
										
										
										
											2023-05-30 18:04:52 +02:00
										 |  |  | from unittest.mock import patch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-31 18:34:14 +02:00
										 |  |  | from haystack.cli.entry_point import main_cli | 
					
						
							| 
									
										
										
										
											2023-05-30 18:04:52 +02:00
										 |  |  | from haystack.nodes.prompt.prompt_template import PromptNotFoundError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.unit | 
					
						
							| 
									
										
										
										
											2023-05-31 18:34:14 +02:00
										 |  |  | @patch("haystack.cli.prompt.fetch.fetch_from_prompthub") | 
					
						
							|  |  |  | @patch("haystack.cli.prompt.fetch.cache_prompt") | 
					
						
							| 
									
										
										
										
											2023-05-30 18:04:52 +02:00
										 |  |  | def test_prompt_fetch_no_args(mock_cache, mock_fetch, cli_runner): | 
					
						
							|  |  |  |     response = cli_runner.invoke(main_cli, ["prompt", "fetch"]) | 
					
						
							|  |  |  |     assert response.exit_code == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     mock_fetch.assert_not_called() | 
					
						
							|  |  |  |     mock_cache.assert_not_called() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.unit | 
					
						
							| 
									
										
										
										
											2023-05-31 18:34:14 +02:00
										 |  |  | @patch("haystack.cli.prompt.fetch.fetch_from_prompthub") | 
					
						
							|  |  |  | @patch("haystack.cli.prompt.fetch.cache_prompt") | 
					
						
							| 
									
										
										
										
											2023-05-30 18:04:52 +02:00
										 |  |  | def test_prompt_fetch(mock_cache, mock_fetch, cli_runner): | 
					
						
							|  |  |  |     response = cli_runner.invoke(main_cli, ["prompt", "fetch", "deepset/question-generation"]) | 
					
						
							|  |  |  |     assert response.exit_code == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     mock_fetch.assert_called_once_with("deepset/question-generation") | 
					
						
							|  |  |  |     mock_cache.assert_called_once() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.unit | 
					
						
							| 
									
										
										
										
											2023-05-31 18:34:14 +02:00
										 |  |  | @patch("haystack.cli.prompt.fetch.fetch_from_prompthub") | 
					
						
							|  |  |  | @patch("haystack.cli.prompt.fetch.cache_prompt") | 
					
						
							| 
									
										
										
										
											2023-05-30 18:04:52 +02:00
										 |  |  | def test_prompt_fetch_with_multiple_prompts(mock_cache, mock_fetch, cli_runner): | 
					
						
							|  |  |  |     response = cli_runner.invoke( | 
					
						
							|  |  |  |         main_cli, ["prompt", "fetch", "deepset/question-generation", "deepset/conversational-agent"] | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert response.exit_code == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert mock_fetch.call_count == 2 | 
					
						
							|  |  |  |     mock_fetch.assert_any_call("deepset/question-generation") | 
					
						
							|  |  |  |     mock_fetch.assert_any_call("deepset/conversational-agent") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert mock_cache.call_count == 2 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.unit | 
					
						
							| 
									
										
										
										
											2023-05-31 18:34:14 +02:00
										 |  |  | @patch("haystack.cli.prompt.fetch.fetch_from_prompthub") | 
					
						
							|  |  |  | @patch("haystack.cli.prompt.fetch.cache_prompt") | 
					
						
							| 
									
										
										
										
											2023-05-30 18:04:52 +02:00
										 |  |  | def test_prompt_fetch_with_unexisting_prompt(mock_cache, mock_fetch, cli_runner): | 
					
						
							|  |  |  |     prompt_name = "deepset/martian-speak" | 
					
						
							|  |  |  |     error_message = f"Prompt template named '{prompt_name}' not available in the Prompt Hub." | 
					
						
							|  |  |  |     mock_fetch.side_effect = PromptNotFoundError(error_message) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = cli_runner.invoke(main_cli, ["prompt", "fetch", prompt_name]) | 
					
						
							|  |  |  |     assert response.exit_code == 1 | 
					
						
							|  |  |  |     assert error_message in response.output | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     mock_fetch.assert_called_once_with(prompt_name) | 
					
						
							|  |  |  |     mock_cache.assert_not_called() |