mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 08:33:51 +00:00
feat: Add CLI prompt cache command (#5050)
* Add CLI prompt cache command * Rename prompt cache to prompt fetch
This commit is contained in:
parent
6249e65bc8
commit
3fd9e0fd89
@ -1,6 +1,7 @@
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from haystack import __version__
|
from haystack import __version__
|
||||||
|
from haystack.preview.cli.prompt import prompt
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@ -9,5 +10,8 @@ def main_cli():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
main_cli.add_command(prompt)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
main_cli()
|
main_cli()
|
||||||
|
11
haystack/preview/cli/prompt/__init__.py
Normal file
11
haystack/preview/cli/prompt/__init__.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import click
|
||||||
|
|
||||||
|
from haystack.preview.cli.prompt import fetch
|
||||||
|
|
||||||
|
|
||||||
|
@click.group(short_help="Prompts related commands")
|
||||||
|
def prompt():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
prompt.add_command(fetch.fetch)
|
30
haystack/preview/cli/prompt/fetch.py
Normal file
30
haystack/preview/cli/prompt/fetch.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import click
|
||||||
|
|
||||||
|
from haystack.nodes.prompt.prompt_template import PromptNotFoundError, fetch_from_prompthub, cache_prompt
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(
|
||||||
|
short_help="Downloads and saves prompts from Haystack PromptHub",
|
||||||
|
help="""
|
||||||
|
Downloads a prompt from the official Haystack PromptHub and saves
|
||||||
|
it locally to ease use in enviroments with no network.
|
||||||
|
|
||||||
|
PROMPT_NAME can be specified multiple times.
|
||||||
|
|
||||||
|
PROMPTHUB_CACHE_PATH environment variable can be set to change the
|
||||||
|
default folder in which the prompts will be saved in.
|
||||||
|
|
||||||
|
If a custom PROMPTHUB_CACHE_PATH is used remember to also used it
|
||||||
|
for Haystack invocations.
|
||||||
|
|
||||||
|
The Haystack PromptHub is https://prompthub.deepset.ai/
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
@click.argument("prompt_name", nargs=-1)
|
||||||
|
def fetch(prompt_name):
|
||||||
|
for name in prompt_name:
|
||||||
|
try:
|
||||||
|
data = fetch_from_prompthub(name)
|
||||||
|
except PromptNotFoundError as err:
|
||||||
|
raise click.ClickException(str(err)) from err
|
||||||
|
cache_prompt(data)
|
7
test/preview/cli/conftest.py
Normal file
7
test/preview/cli/conftest.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
import pytest
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def cli_runner():
|
||||||
|
yield CliRunner()
|
60
test/preview/cli/test_prompt_fetch.py
Normal file
60
test/preview/cli/test_prompt_fetch.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from haystack.preview.cli.entry_point import main_cli
|
||||||
|
from haystack.nodes.prompt.prompt_template import PromptNotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub")
|
||||||
|
@patch("haystack.preview.cli.prompt.fetch.cache_prompt")
|
||||||
|
def test_prompt_fetch_no_args(mock_cache, mock_fetch, cli_runner):
|
||||||
|
response = cli_runner.invoke(main_cli, ["prompt", "fetch"])
|
||||||
|
assert response.exit_code == 0
|
||||||
|
|
||||||
|
mock_fetch.assert_not_called()
|
||||||
|
mock_cache.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub")
|
||||||
|
@patch("haystack.preview.cli.prompt.fetch.cache_prompt")
|
||||||
|
def test_prompt_fetch(mock_cache, mock_fetch, cli_runner):
|
||||||
|
response = cli_runner.invoke(main_cli, ["prompt", "fetch", "deepset/question-generation"])
|
||||||
|
assert response.exit_code == 0
|
||||||
|
|
||||||
|
mock_fetch.assert_called_once_with("deepset/question-generation")
|
||||||
|
mock_cache.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub")
|
||||||
|
@patch("haystack.preview.cli.prompt.fetch.cache_prompt")
|
||||||
|
def test_prompt_fetch_with_multiple_prompts(mock_cache, mock_fetch, cli_runner):
|
||||||
|
response = cli_runner.invoke(
|
||||||
|
main_cli, ["prompt", "fetch", "deepset/question-generation", "deepset/conversational-agent"]
|
||||||
|
)
|
||||||
|
assert response.exit_code == 0
|
||||||
|
|
||||||
|
assert mock_fetch.call_count == 2
|
||||||
|
mock_fetch.assert_any_call("deepset/question-generation")
|
||||||
|
mock_fetch.assert_any_call("deepset/conversational-agent")
|
||||||
|
|
||||||
|
assert mock_cache.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@patch("haystack.preview.cli.prompt.fetch.fetch_from_prompthub")
|
||||||
|
@patch("haystack.preview.cli.prompt.fetch.cache_prompt")
|
||||||
|
def test_prompt_fetch_with_unexisting_prompt(mock_cache, mock_fetch, cli_runner):
|
||||||
|
prompt_name = "deepset/martian-speak"
|
||||||
|
error_message = f"Prompt template named '{prompt_name}' not available in the Prompt Hub."
|
||||||
|
mock_fetch.side_effect = PromptNotFoundError(error_message)
|
||||||
|
|
||||||
|
response = cli_runner.invoke(main_cli, ["prompt", "fetch", prompt_name])
|
||||||
|
assert response.exit_code == 1
|
||||||
|
assert error_message in response.output
|
||||||
|
|
||||||
|
mock_fetch.assert_called_once_with(prompt_name)
|
||||||
|
mock_cache.assert_not_called()
|
Loading…
x
Reference in New Issue
Block a user