feat: prompts caching from PromptHub (#5048)

* split up prompttemplate init

* caching

* docstring

* add platformdirs

* use user_data_dir

* fix tests

* add tests

* pylint

* mypy
Authored by ZanSara on 2023-05-30 16:55:48 +02:00, committed by GitHub
parent 76a6eefe5e
commit 6249e65bc8
7 changed files with 155 additions and 62 deletions
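
In short: instead of each PromptNode keeping an in-memory dict of templates, PromptTemplate now caches prompts fetched from PromptHub on disk and reloads them from there. A minimal usage sketch of the new behaviour (assuming network access for the first call; the exact default cache folder is platform-dependent):

import os
from haystack.nodes.prompt import PromptTemplate

# Both knobs introduced by this PR are plain environment variables:
# os.environ["PROMPTHUB_CACHE_PATH"] = "/tmp/prompthub_cache"  # override the cache folder
# os.environ["PROMPTHUB_CACHE_ENABLED"] = "false"              # fetch without writing the cache

# First construction fetches from PromptHub and writes <cache>/deepset/question-answering.yml
template = PromptTemplate("deepset/question-answering")

# Subsequent constructions load that YAML from disk instead of calling the Hub again
template = PromptTemplate("deepset/question-answering")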

View File

@@ -93,7 +93,6 @@ class PromptNode(BaseComponent):
             },
         )
         super().__init__()
-        self._prompt_templates_cache: Dict[str, PromptTemplate] = {}

         # If we don't set _default_template here Pylint fails with error W0201 because it can't see that
         # it's set in default_prompt_template, so we set it explicitly to None to avoid it failing
@@ -214,16 +213,10 @@ class PromptNode(BaseComponent):
         if isinstance(prompt_template, PromptTemplate):
             return prompt_template

-        # If it's the name of a template that was used already
-        if prompt_template in self._prompt_templates_cache:
-            return self._prompt_templates_cache[prompt_template]

         output_parser = None
         if self.default_prompt_template:
             output_parser = self.default_prompt_template.output_parser
-        template = PromptTemplate(prompt_template, output_parser=output_parser)
-        self._prompt_templates_cache[prompt_template] = template
-        return template
+        return PromptTemplate(prompt_template, output_parser=output_parser)

     def prompt_template_params(self, prompt_template: str) -> List[str]:
         """

View File

@@ -10,6 +10,7 @@ from abc import ABC
 from uuid import uuid4

 import yaml
+from platformdirs import user_data_dir
 import tenacity
 import prompthub
 from requests import HTTPError, RequestException, JSONDecodeError
@@ -46,6 +47,10 @@ PROMPTHUB_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30.0))
 PROMPTHUB_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10.0))
 PROMPTHUB_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
+PROMPTHUB_CACHE_PATH = os.environ.get(
+    "PROMPTHUB_CACHE_PATH", Path(user_data_dir("haystack", "deepset")) / "prompthub_cache"
+)

 #############################################################################
 # These templates were hardcoded in the prompt_template module. When adding
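
The default above relies on platformdirs to pick a per-user data directory. A quick sketch of how the default resolves when PROMPTHUB_CACHE_PATH is not set (the printed paths are illustrative; platformdirs picks OS-specific locations):

from pathlib import Path
from platformdirs import user_data_dir

# Same expression as the default in the diff above
cache_root = Path(user_data_dir("haystack", "deepset")) / "prompthub_cache"
print(cache_root)
# Typically something like ~/.local/share/haystack/prompthub_cache on Linux
# or ~/Library/Application Support/haystack/prompthub_cache on macOS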
@@ -65,7 +70,7 @@ PROMPTHUB_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
 #############################################################################
-LEGACY_DEFAULT_TEMPLATES: Dict[str, Dict[str, Any]] = {
+LEGACY_DEFAULT_TEMPLATES: Dict[str, Dict] = {
     # DO NOT ADD ANY NEW TEMPLATE IN HERE!
     "question-answering": {
         "prompt": "Given the context please answer the question. Context: {join(documents)}; Question: "
@@ -285,6 +290,39 @@ class _FstringParamsTransformer(ast.NodeTransformer):
         )

+
+@tenacity.retry(
+    reraise=True,
+    retry=tenacity.retry_if_exception_type((HTTPError, RequestException, JSONDecodeError)),
+    wait=tenacity.wait_exponential(multiplier=PROMPTHUB_BACKOFF),
+    stop=tenacity.stop_after_attempt(PROMPTHUB_MAX_RETRIES),
+)
+def fetch_from_prompthub(name: str) -> prompthub.Prompt:
+    """
+    Looks for the given prompt in the PromptHub.
+
+    :param name: the name of the prompt on the Hub.
+    :returns: the Prompt object.
+    """
+    try:
+        prompt_data: prompthub.Prompt = prompthub.fetch(name, timeout=PROMPTHUB_TIMEOUT)
+    except HTTPError as http_error:
+        if http_error.response.status_code != 404:
+            raise http_error
+        raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")
+    return prompt_data
+
+
+def cache_prompt(data: prompthub.Prompt):
+    """
+    Saves the prompt to the cache. Helps avoid naming mismatches in the cache folder.
+
+    :param data: the prompthub.Prompt object from PromptHub.
+    """
+    path = Path(PROMPTHUB_CACHE_PATH) / f"{data.name}.yml"
+    path.parent.mkdir(parents=True, exist_ok=True)
+    data.to_yaml(path)
+
+
 class PromptTemplate(BasePromptTemplate, ABC):
     """
     PromptTemplate is a template for the prompt you feed to the model to instruct it what to do. For example, if you want the model to perform sentiment analysis, you simply tell it to do that in a prompt. Here's what a prompt template may look like:
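
The two new module-level helpers compose into a fetch-then-cache round trip; PromptTemplate._load_from_prompthub below does exactly this. A minimal sketch of calling them directly (requires network access to PromptHub; the wait between retries grows exponentially from the PROMPTHUB_BACKOFF multiplier, up to PROMPTHUB_MAX_RETRIES attempts):

from pathlib import Path
from haystack.nodes.prompt.prompt_template import (
    PROMPTHUB_CACHE_PATH,
    cache_prompt,
    fetch_from_prompthub,
)

# Retried fetch: a 404 surfaces immediately as PromptNotFoundError (not retried),
# while other HTTP/request/JSON errors are retried before being re-raised.
data = fetch_from_prompthub("deepset/question-answering")

# Write the prompt under the name-derived path that the cache lookup expects.
cache_prompt(data)
assert (Path(PROMPTHUB_CACHE_PATH) / f"{data.name}.yml").exists()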
@@ -325,33 +363,17 @@ class PromptTemplate(BasePromptTemplate, ABC):
         name, prompt_text = "", ""
         if prompt in LEGACY_DEFAULT_TEMPLATES:
-            warnings.warn(
-                f"You're using a legacy prompt template '{prompt}', "
-                "we strongly suggest you use prompts from the official Haystack PromptHub: "
-                "https://prompthub.deepset.ai/"
-            )
             name = prompt
-            prompt_text = LEGACY_DEFAULT_TEMPLATES[prompt]["prompt"]
-            output_parser = LEGACY_DEFAULT_TEMPLATES[prompt].get("output_parser")
+            prompt_text, output_parser = self._load_from_legacy_template(prompt)

         # if it looks like a prompt template name
         elif re.fullmatch(r"[-a-zA-Z0-9_/]+", prompt):
             name = prompt
-            try:
-                prompt_text = self._fetch_from_prompthub(prompt)
-            except HTTPError as http_error:
-                if http_error.response.status_code != 404:
-                    raise http_error
-                raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")
+            prompt_text = self._load_from_prompthub(prompt)

         # if it's a path to a YAML file
         elif len(prompt) < 255 and Path(prompt).exists():
-            with open(prompt, "r", encoding="utf-8") as yaml_file:
-                prompt_template_parsed = yaml.safe_load(yaml_file.read())
-                if not isinstance(prompt_template_parsed, dict):
-                    raise ValueError("The prompt loaded is not a prompt YAML file.")
-                name = prompt_template_parsed["name"]
-                prompt_text = prompt_template_parsed["prompt_text"]
+            name, prompt_text = self._load_from_file(prompt)

         # Otherwise it's an on-the-fly prompt text
         else:
@@ -395,27 +417,43 @@ class PromptTemplate(BasePromptTemplate, ABC):
             output_parser_params = output_parser.get("params", {})
             self.output_parser = BaseComponent._create_instance(output_parser_type, output_parser_params)

-    @property
-    def output_variable(self) -> Optional[str]:
-        return self.output_parser.output_variable if self.output_parser else None
+    def _load_from_legacy_template(self, name: str) -> Tuple[str, Any]:
+        warnings.warn(
+            f"You're using a legacy prompt template '{name}', "
+            "we strongly suggest you use prompts from the official Haystack PromptHub: "
+            "https://prompthub.deepset.ai/"
+        )
+        prompt_text = LEGACY_DEFAULT_TEMPLATES[name]["prompt"]
+        output_parser = LEGACY_DEFAULT_TEMPLATES[name].get("output_parser")
+        return prompt_text, output_parser

-    @tenacity.retry(
-        reraise=True,
-        retry=tenacity.retry_if_exception_type((HTTPError, RequestException, JSONDecodeError)),
-        wait=tenacity.wait_exponential(multiplier=PROMPTHUB_BACKOFF),
-        stop=tenacity.stop_after_attempt(PROMPTHUB_MAX_RETRIES),
-    )
-    def _fetch_from_prompthub(self, name) -> str:
+    def _load_from_prompthub(self, name: str) -> str:
+        """
+        Looks for the given prompt in the PromptHub if the prompt is not in the local cache.
+        """
+        prompt_path = Path(PROMPTHUB_CACHE_PATH) / f"{name}.yml"
+        if Path(prompt_path).exists():
+            return self._load_from_file(prompt_path)[1]
         try:
-            prompt_data: prompthub.Prompt = prompthub.fetch(name, timeout=PROMPTHUB_TIMEOUT)
+            data = fetch_from_prompthub(name)
+            if os.environ.get("PROMPTHUB_CACHE_ENABLED", "true").lower() not in ("0", "false", "f"):
+                cache_prompt(data)
         except HTTPError as http_error:
             if http_error.response.status_code != 404:
                 raise http_error
             raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")
-        return prompt_data.text
+        return data.text

+    def _load_from_file(self, path: Union[Path, str]) -> Tuple[str, str]:
+        with open(path, "r", encoding="utf-8") as yaml_file:
+            prompt_template_parsed = yaml.safe_load(yaml_file.read())
+            if not isinstance(prompt_template_parsed, dict):
+                raise ValueError("The prompt loaded is not a prompt YAML file.")
+            name = prompt_template_parsed["name"]
+            prompt_text = prompt_template_parsed["text"]
+        return name, prompt_text
+
+    @property
+    def output_variable(self) -> Optional[str]:
+        return self.output_parser.output_variable if self.output_parser else None
+
     def prepare(self, *args, **kwargs) -> Dict[str, Any]:
         """

View File

@@ -54,6 +54,7 @@ dependencies = [
     "scikit-learn>=1.0.0",  # TF-IDF, SklearnQueryClassifier and metrics
     "generalimport",  # Optional imports
     "prompthub-py==3.0.1",
+    "platformdirs",

     # Utils
     "tqdm",  # progress bars in model download and training scripts

View File

@@ -8,7 +8,7 @@ from haystack.nodes import PromptNode

 @pytest.mark.unit
 def test_init():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
         mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
         prompt_node = PromptNode()
         agent = ConversationalAgent(prompt_node)

@@ -25,7 +25,7 @@ def test_init():

 @pytest.mark.unit
 def test_init_with_summary_memory():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
         mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
         prompt_node = PromptNode(default_prompt_template="this is a test")
         # Test with summary memory

@@ -35,7 +35,7 @@ def test_init_with_summary_memory():

 @pytest.mark.unit
 def test_init_with_no_memory():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
         mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
         prompt_node = PromptNode()
         # Test with no memory

@@ -45,7 +45,7 @@ def test_init_with_no_memory():

 @pytest.mark.unit
 def test_run():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
         mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
         prompt_node = PromptNode()
         agent = ConversationalAgent(prompt_node)

View File

@@ -78,6 +78,9 @@ META_FIELDS = [

 # Disable telemetry reports when running tests
 posthog.disabled = True

+# Disable caching from prompthub to avoid polluting the local environment.
+os.environ["PROMPTHUB_CACHE_ENABLED"] = "false"
+
 # Cache requests (e.g. huggingface model) to circumvent load protection
 # See https://requests-cache.readthedocs.io/en/stable/user_guide/filtering.html
 requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE})

View File

@@ -4,6 +4,7 @@ from typing import Optional, Union, List, Dict, Any, Tuple
 from unittest.mock import patch, Mock, MagicMock

 import pytest
+from prompthub import Prompt
 from transformers import GenerationConfig, TextStreamer

 from haystack import Document, Pipeline, BaseComponent, MultiLabel
@@ -14,9 +15,14 @@ from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, Defau

 @pytest.fixture
 def mock_prompthub():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = (
-            lambda name: "This is a test prompt. Use your knowledge to answer this question: {question}"
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
+        mock_prompthub.return_value = Prompt(
+            name="deepset/test",
+            tags=["test"],
+            meta={"author": "test"},
+            version="v0.0.0",
+            text="This is a test prompt. Use your knowledge to answer this question: {question}",
+            description="test prompt",
         )
         yield mock_prompthub
@@ -132,10 +138,10 @@ def test_get_prompt_template_local_file(mock_model, tmp_path, mock_prompthub):
         ptf.write(
             """
             name: my_prompts/question-answering
-            prompt_text: |
-                Given the context please answer the question. Context: {join(documents)};
-                Question: {query};
-                Answer:
+            text: |
+                Given the context please answer the question. Context: {join(documents)};
+                Question: {query};
+                Answer:
             description: A simple prompt to answer a question given a set of documents
             tags:
             - question-answering

View File

@@ -1,5 +1,6 @@
 from typing import Set, Type, List
 import textwrap
+import os
 from unittest.mock import patch, MagicMock

 import pytest
@@ -8,17 +9,35 @@ import prompthub

 from haystack.nodes.prompt import PromptTemplate
 from haystack.nodes.prompt.prompt_node import PromptNode
 from haystack.nodes.prompt.prompt_template import PromptTemplateValidationError, LEGACY_DEFAULT_TEMPLATES
+from haystack.nodes.prompt import prompt_template
 from haystack.nodes.prompt.shapers import AnswerParser
 from haystack.pipelines.base import Pipeline
 from haystack.schema import Answer, Document


+@pytest.fixture
+def enable_prompthub_cache(monkeypatch):
+    monkeypatch.setenv("PROMPTHUB_CACHE_ENABLED", "true")
+
+
+@pytest.fixture
+def prompthub_cache_path(monkeypatch, tmp_path):
+    cache_path = tmp_path / "cache"
+    monkeypatch.setattr(prompt_template, "PROMPTHUB_CACHE_PATH", cache_path)
+    yield cache_path
+
+
 @pytest.fixture
 def mock_prompthub():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [
-            ("deepset/test-prompt", "This is a test prompt. Use your knowledge to answer this question: {question}")
-        ]
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
+        mock_prompthub.return_value = prompthub.Prompt(
+            name="deepset/test-prompt",
+            tags=["test"],
+            meta={"author": "test"},
+            version="v0.0.0",
+            text="This is a test prompt. Use your knowledge to answer this question: {question}",
+            description="test prompt",
+        )
         yield mock_prompthub
@@ -29,6 +48,39 @@ def test_prompt_templates_from_hub():
     mock_prompthub.fetch.assert_called_with("deepset/question-answering", timeout=30)


+@pytest.mark.unit
+def test_prompt_templates_from_hub_prompts_are_cached(prompthub_cache_path, enable_prompthub_cache, mock_prompthub):
+    PromptTemplate("deepset/test-prompt")
+    assert (prompthub_cache_path / "deepset" / "test-prompt.yml").exists()
+
+
+@pytest.mark.unit
+def test_prompt_templates_from_hub_prompts_are_not_cached_if_disabled(prompthub_cache_path, mock_prompthub):
+    PromptTemplate("deepset/test-prompt")
+    assert not (prompthub_cache_path / "deepset" / "test-prompt.yml").exists()
+
+
+@pytest.mark.unit
+def test_prompt_templates_from_hub_cached_prompts_are_used(
+    prompthub_cache_path, enable_prompthub_cache, mock_prompthub
+):
+    test_path = prompthub_cache_path / "deepset" / "another-test.yml"
+    test_path.parent.mkdir(parents=True, exist_ok=True)
+    data = prompthub.Prompt(
+        name="deepset/another-test",
+        text="this is the prompt text",
+        description="test prompt description",
+        tags=["another-test"],
+        meta={"authors": ["vblagoje"]},
+        version="v0.1.1",
+    )
+    data.to_yaml(test_path)
+
+    template = PromptTemplate("deepset/another-test")
+    mock_prompthub.fetch.assert_not_called()
+    assert template.prompt_text == "this is the prompt text"
+
+
 @pytest.mark.unit
 def test_prompt_templates_from_legacy_set(mock_prompthub):
     p = PromptTemplate("question-answering")
@@ -45,10 +97,10 @@ def test_prompt_templates_from_file(tmp_path):
         textwrap.dedent(
             """
             name: deepset/question-answering
-            prompt_text: |
-                Given the context please answer the question. Context: {join(documents)};
-                Question: {query};
-                Answer:
+            text: |
+                Given the context please answer the question. Context: {join(documents)};
+                Question: {query};
+                Answer:
             description: A simple prompt to answer a question given a set of documents
             tags:
             - question-answering