feat: Add CLI prompt cache command (#5050)

* Add CLI prompt cache command

* Rename prompt cache to prompt fetch
Silvano Cerza 2023-05-30 18:04:52 +02:00 committed by GitHub
parent 6249e65bc8
commit 3fd9e0fd89
5 changed files with 112 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
import click

from haystack import __version__
from haystack.preview.cli.prompt import prompt


@click.group()
@@ -9,5 +10,8 @@ def main_cli():
    pass


main_cli.add_command(prompt)


def main():
    main_cli()

View File

@@ -0,0 +1,11 @@
import click

from haystack.preview.cli.prompt import fetch


@click.group(short_help="Prompts related commands")
def prompt():
    pass


prompt.add_command(fetch.fetch)

View File

@@ -0,0 +1,30 @@
import click

from haystack.nodes.prompt.prompt_template import PromptNotFoundError, fetch_from_prompthub, cache_prompt


@click.command(
    short_help="Downloads and saves prompts from Haystack PromptHub",
    help="""
    Downloads a prompt from the official Haystack PromptHub and saves
    it locally to ease use in environments with no network.

    PROMPT_NAME can be specified multiple times.

    The PROMPTHUB_CACHE_PATH environment variable can be set to change the
    default folder in which the prompts will be saved.

    If a custom PROMPTHUB_CACHE_PATH is used, remember to also use it
    for Haystack invocations.

    The Haystack PromptHub is https://prompthub.deepset.ai/
    """,
)
@click.argument("prompt_name", nargs=-1)
def fetch(prompt_name):
    for name in prompt_name:
        try:
            data = fetch_from_prompthub(name)
        except PromptNotFoundError as err:
            raise click.ClickException(str(err)) from err
        cache_prompt(data)
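
For reference, a minimal usage sketch (not part of this diff) of driving the new command programmatically through click's CliRunner, mirroring what the tests below do. The prompt name is taken from the tests, network access to the PromptHub is assumed, and PROMPTHUB_CACHE_PATH would be set in the environment beforehand if a custom cache folder is wanted, as described in the command's help text:

# Sketch only: exercise the new `prompt fetch` command via click's test runner.
# Assumes the PromptHub is reachable; the prompt name comes from the tests below.
from click.testing import CliRunner

from haystack.preview.cli.entry_point import main_cli

runner = CliRunner()
result = runner.invoke(main_cli, ["prompt", "fetch", "deepset/question-generation"])
print(result.exit_code)  # 0 on success, 1 if the prompt is not found on the Hub
print(result.output)     # the ClickException message is printed here on failure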

View File

@@ -0,0 +1,7 @@
import pytest
from click.testing import CliRunner


@pytest.fixture()
def cli_runner():
    yield CliRunner()

View File

@@ -0,0 +1,60 @@
from unittest.mock import patch

import pytest

from haystack.preview.cli.entry_point import main_cli
from haystack.nodes.prompt.prompt_template import PromptNotFoundError


@pytest.mark.unit
@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub")
@patch("haystack.preview.cli.prompt.fetch.cache_prompt")
def test_prompt_fetch_no_args(mock_cache, mock_fetch, cli_runner):
    response = cli_runner.invoke(main_cli, ["prompt", "fetch"])
    assert response.exit_code == 0
    mock_fetch.assert_not_called()
    mock_cache.assert_not_called()


@pytest.mark.unit
@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub")
@patch("haystack.preview.cli.prompt.fetch.cache_prompt")
def test_prompt_fetch(mock_cache, mock_fetch, cli_runner):
    response = cli_runner.invoke(main_cli, ["prompt", "fetch", "deepset/question-generation"])
    assert response.exit_code == 0
    mock_fetch.assert_called_once_with("deepset/question-generation")
    mock_cache.assert_called_once()


@pytest.mark.unit
@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub")
@patch("haystack.preview.cli.prompt.fetch.cache_prompt")
def test_prompt_fetch_with_multiple_prompts(mock_cache, mock_fetch, cli_runner):
    response = cli_runner.invoke(
        main_cli, ["prompt", "fetch", "deepset/question-generation", "deepset/conversational-agent"]
    )
    assert response.exit_code == 0
    assert mock_fetch.call_count == 2
    mock_fetch.assert_any_call("deepset/question-generation")
    mock_fetch.assert_any_call("deepset/conversational-agent")
    assert mock_cache.call_count == 2


@pytest.mark.unit
@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub")
@patch("haystack.preview.cli.prompt.fetch.cache_prompt")
def test_prompt_fetch_with_unexisting_prompt(mock_cache, mock_fetch, cli_runner):
    prompt_name = "deepset/martian-speak"
    error_message = f"Prompt template named '{prompt_name}' not available in the Prompt Hub."
    mock_fetch.side_effect = PromptNotFoundError(error_message)

    response = cli_runner.invoke(main_cli, ["prompt", "fetch", prompt_name])
    assert response.exit_code == 1
    assert error_message in response.output
    mock_fetch.assert_called_once_with(prompt_name)
    mock_cache.assert_not_called()