From 3fd9e0fd89e8ff89348bbf638b3610f63e9940d0 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Tue, 30 May 2023 18:04:52 +0200 Subject: [PATCH] feat: Add CLI prompt cache command (#5050) * Add CLI prompt cache command * Rename prompt cache to prompt fetch --- haystack/preview/cli/entry_point.py | 4 ++ haystack/preview/cli/prompt/__init__.py | 11 +++++ haystack/preview/cli/prompt/fetch.py | 30 +++++++++++++ test/preview/cli/conftest.py | 7 +++ test/preview/cli/test_prompt_fetch.py | 60 +++++++++++++++++++++++++ 5 files changed, 112 insertions(+) create mode 100644 haystack/preview/cli/prompt/__init__.py create mode 100644 haystack/preview/cli/prompt/fetch.py create mode 100644 test/preview/cli/conftest.py create mode 100644 test/preview/cli/test_prompt_fetch.py diff --git a/haystack/preview/cli/entry_point.py b/haystack/preview/cli/entry_point.py index d6f93f13c..7c12c2445 100644 --- a/haystack/preview/cli/entry_point.py +++ b/haystack/preview/cli/entry_point.py @@ -1,6 +1,7 @@ import click from haystack import __version__ +from haystack.preview.cli.prompt import prompt @click.group() @@ -9,5 +10,8 @@ def main_cli(): pass +main_cli.add_command(prompt) + + def main(): main_cli() diff --git a/haystack/preview/cli/prompt/__init__.py b/haystack/preview/cli/prompt/__init__.py new file mode 100644 index 000000000..f282f3c4c --- /dev/null +++ b/haystack/preview/cli/prompt/__init__.py @@ -0,0 +1,11 @@ +import click + +from haystack.preview.cli.prompt import fetch + + +@click.group(short_help="Prompts related commands") +def prompt(): + pass + + +prompt.add_command(fetch.fetch) diff --git a/haystack/preview/cli/prompt/fetch.py b/haystack/preview/cli/prompt/fetch.py new file mode 100644 index 000000000..dba9f8ebc --- /dev/null +++ b/haystack/preview/cli/prompt/fetch.py @@ -0,0 +1,30 @@ +import click + +from haystack.nodes.prompt.prompt_template import PromptNotFoundError, fetch_from_prompthub, cache_prompt + + +@click.command( + short_help="Downloads and saves prompts from Haystack PromptHub", + help=""" + Downloads a prompt from the official Haystack PromptHub and saves + it locally to ease use in enviroments with no network. + + PROMPT_NAME can be specified multiple times. + + PROMPTHUB_CACHE_PATH environment variable can be set to change the + default folder in which the prompts will be saved in. + + If a custom PROMPTHUB_CACHE_PATH is used remember to also used it + for Haystack invocations. + + The Haystack PromptHub is https://prompthub.deepset.ai/ + """, +) +@click.argument("prompt_name", nargs=-1) +def fetch(prompt_name): + for name in prompt_name: + try: + data = fetch_from_prompthub(name) + except PromptNotFoundError as err: + raise click.ClickException(str(err)) from err + cache_prompt(data) diff --git a/test/preview/cli/conftest.py b/test/preview/cli/conftest.py new file mode 100644 index 000000000..33d1416fb --- /dev/null +++ b/test/preview/cli/conftest.py @@ -0,0 +1,7 @@ +import pytest +from click.testing import CliRunner + + +@pytest.fixture() +def cli_runner(): + yield CliRunner() diff --git a/test/preview/cli/test_prompt_fetch.py b/test/preview/cli/test_prompt_fetch.py new file mode 100644 index 000000000..2eba20d64 --- /dev/null +++ b/test/preview/cli/test_prompt_fetch.py @@ -0,0 +1,60 @@ +from unittest.mock import patch + +import pytest + +from haystack.preview.cli.entry_point import main_cli +from haystack.nodes.prompt.prompt_template import PromptNotFoundError + + +@pytest.mark.unit +@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub") +@patch("haystack.preview.cli.prompt.fetch.cache_prompt") +def test_prompt_fetch_no_args(mock_cache, mock_fetch, cli_runner): + response = cli_runner.invoke(main_cli, ["prompt", "fetch"]) + assert response.exit_code == 0 + + mock_fetch.assert_not_called() + mock_cache.assert_not_called() + + +@pytest.mark.unit +@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub") +@patch("haystack.preview.cli.prompt.fetch.cache_prompt") +def test_prompt_fetch(mock_cache, mock_fetch, cli_runner): + response = cli_runner.invoke(main_cli, ["prompt", "fetch", "deepset/question-generation"]) + assert response.exit_code == 0 + + mock_fetch.assert_called_once_with("deepset/question-generation") + mock_cache.assert_called_once() + + +@pytest.mark.unit +@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub") +@patch("haystack.preview.cli.prompt.fetch.cache_prompt") +def test_prompt_fetch_with_multiple_prompts(mock_cache, mock_fetch, cli_runner): + response = cli_runner.invoke( + main_cli, ["prompt", "fetch", "deepset/question-generation", "deepset/conversational-agent"] + ) + assert response.exit_code == 0 + + assert mock_fetch.call_count == 2 + mock_fetch.assert_any_call("deepset/question-generation") + mock_fetch.assert_any_call("deepset/conversational-agent") + + assert mock_cache.call_count == 2 + + +@pytest.mark.unit +@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub") +@patch("haystack.preview.cli.prompt.fetch.cache_prompt") +def test_prompt_fetch_with_unexisting_prompt(mock_cache, mock_fetch, cli_runner): + prompt_name = "deepset/martian-speak" + error_message = f"Prompt template named '{prompt_name}' not available in the Prompt Hub." + mock_fetch.side_effect = PromptNotFoundError(error_message) + + response = cli_runner.invoke(main_cli, ["prompt", "fetch", prompt_name]) + assert response.exit_code == 1 + assert error_message in response.output + + mock_fetch.assert_called_once_with(prompt_name) + mock_cache.assert_not_called()