Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-08 13:06:29 +00:00.
feat: prompts caching from PromptHub (#5048)

* split up prompttemplate init
* caching
* docstring
* add platformdirs
* use user_data_dir
* fix tests
* add tests
* pylint
* mypy

Parent: 76a6eefe5e
Commit: 6249e65bc8
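In short: per-node prompt caching is dropped from PromptNode, fetching moves into module-level helpers in prompt_template, and prompts fetched from the PromptHub are now cached on disk. Two environment variables control the cache; a usage sketch (paths are illustrative, variable names are taken from the diff below):

import os

# Configure before importing Haystack, since the module-level constants read
# the environment at import time.
os.environ["PROMPTHUB_CACHE_PATH"] = "/tmp/prompthub_cache"   # default: platformdirs user data dir
os.environ["PROMPTHUB_CACHE_ENABLED"] = "false"               # "0", "false" or "f" disable the write-back

from haystack.nodes.prompt import PromptTemplate

# Resolved against the local cache first; fetched from the Hub (with retries) on a miss.
template = PromptTemplate("deepset/question-answering")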
haystack/nodes/prompt/prompt_node.py

@@ -93,7 +93,6 @@ class PromptNode(BaseComponent):
             },
         )
         super().__init__()
-        self._prompt_templates_cache: Dict[str, PromptTemplate] = {}

         # If we don't set _default_template here Pylint fails with error W0201 because it can't see that
         # it's set in default_prompt_template, so we set it explicitly to None to avoid it failing
@@ -214,16 +213,10 @@ class PromptNode(BaseComponent):
         if isinstance(prompt_template, PromptTemplate):
             return prompt_template

-        # If it's the name of a template that was used already
-        if prompt_template in self._prompt_templates_cache:
-            return self._prompt_templates_cache[prompt_template]
-
         output_parser = None
         if self.default_prompt_template:
             output_parser = self.default_prompt_template.output_parser
-        template = PromptTemplate(prompt_template, output_parser=output_parser)
-        self._prompt_templates_cache[prompt_template] = template
-        return template
+        return PromptTemplate(prompt_template, output_parser=output_parser)

     def prompt_template_params(self, prompt_template: str) -> List[str]:
         """
haystack/nodes/prompt/prompt_template.py

@@ -10,6 +10,7 @@ from abc import ABC
 from uuid import uuid4

 import yaml
+from platformdirs import user_data_dir
 import tenacity
 import prompthub
 from requests import HTTPError, RequestException, JSONDecodeError

@@ -46,6 +47,10 @@ PROMPTHUB_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30.0))
 PROMPTHUB_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10.0))
 PROMPTHUB_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))

+PROMPTHUB_CACHE_PATH = os.environ.get(
+    "PROMPTHUB_CACHE_PATH", Path(user_data_dir("haystack", "deepset")) / "prompthub_cache"
+)
+

 #############################################################################
 # This templates were hardcoded in the prompt_template module. When adding
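For reference, platformdirs' user_data_dir resolves to an OS-specific per-user directory, so the default cache location needs no configuration. A quick illustration (the commented paths are typical results, not guaranteed):

from pathlib import Path
from platformdirs import user_data_dir

print(Path(user_data_dir("haystack", "deepset")) / "prompthub_cache")
# Linux:   ~/.local/share/haystack/prompthub_cache
# macOS:   ~/Library/Application Support/haystack/prompthub_cache
# Windows: C:\Users\<user>\AppData\Local\deepset\haystack\prompthub_cache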
@@ -65,7 +70,7 @@ PROMPTHUB_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
 #############################################################################


-LEGACY_DEFAULT_TEMPLATES: Dict[str, Dict[str, Any]] = {
+LEGACY_DEFAULT_TEMPLATES: Dict[str, Dict] = {
     # DO NOT ADD ANY NEW TEMPLATE IN HERE!
     "question-answering": {
         "prompt": "Given the context please answer the question. Context: {join(documents)}; Question: "
@@ -285,6 +290,39 @@ class _FstringParamsTransformer(ast.NodeTransformer):
         )


+@tenacity.retry(
+    reraise=True,
+    retry=tenacity.retry_if_exception_type((HTTPError, RequestException, JSONDecodeError)),
+    wait=tenacity.wait_exponential(multiplier=PROMPTHUB_BACKOFF),
+    stop=tenacity.stop_after_attempt(PROMPTHUB_MAX_RETRIES),
+)
+def fetch_from_prompthub(name: str) -> prompthub.Prompt:
+    """
+    Looks for the given prompt in the PromptHub.
+
+    :param name: the name of the prompt on the Hub.
+    :returns: the Prompt object.
+    """
+    try:
+        prompt_data: prompthub.Prompt = prompthub.fetch(name, timeout=PROMPTHUB_TIMEOUT)
+    except HTTPError as http_error:
+        if http_error.response.status_code != 404:
+            raise http_error
+        raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")
+    return prompt_data
+
+
+def cache_prompt(data: prompthub.Prompt):
+    """
+    Saves the prompt to the cache. Helps avoiding naming mismatches in the cache folder.
+
+    :param data: the prompthub.Prompt object from PromptHub.
+    """
+    path = Path(PROMPTHUB_CACHE_PATH) / f"{data.name}.yml"
+    path.parent.mkdir(parents=True, exist_ok=True)
+    data.to_yaml(path)
+
+
 class PromptTemplate(BasePromptTemplate, ABC):
     """
     PromptTemplate is a template for the prompt you feed to the model to instruct it what to do. For example, if you want the model to perform sentiment analysis, you simply tell it to do that in a prompt. Here's what a prompt template may look like:
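Two details worth noting here: fetch_from_prompthub retries transient network errors with exponential backoff (up to PROMPTHUB_MAX_RETRIES attempts) and turns a 404 into PromptNotFoundError, while cache_prompt derives the file path from the prompt name. Since Hub names contain slashes, f"{data.name}.yml" yields a nested path, which is why the parent directory is created first. A small sketch of that naming (the cache root is illustrative):

from pathlib import Path

cache_root = Path("/tmp/prompthub_cache")             # stand-in for PROMPTHUB_CACHE_PATH
path = cache_root / "deepset/question-answering.yml"  # from name "deepset/question-answering"
path.parent.mkdir(parents=True, exist_ok=True)        # creates the "deepset" subfolder
print(path)                                           # /tmp/prompthub_cache/deepset/question-answering.yml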
@@ -325,33 +363,17 @@ class PromptTemplate(BasePromptTemplate, ABC):
         name, prompt_text = "", ""

         if prompt in LEGACY_DEFAULT_TEMPLATES:
-            warnings.warn(
-                f"You're using a legacy prompt template '{prompt}', "
-                "we strongly suggest you use prompts from the official Haystack PromptHub: "
-                "https://prompthub.deepset.ai/"
-            )
             name = prompt
-            prompt_text = LEGACY_DEFAULT_TEMPLATES[prompt]["prompt"]
-            output_parser = LEGACY_DEFAULT_TEMPLATES[prompt].get("output_parser")
+            prompt_text, output_parser = self._load_from_legacy_template(prompt)

         # if it looks like a prompt template name
         elif re.fullmatch(r"[-a-zA-Z0-9_/]+", prompt):
             name = prompt
-            try:
-                prompt_text = self._fetch_from_prompthub(prompt)
-            except HTTPError as http_error:
-                if http_error.response.status_code != 404:
-                    raise http_error
-                raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")
+            prompt_text = self._load_from_prompthub(prompt)

         # if it's a path to a YAML file
         elif len(prompt) < 255 and Path(prompt).exists():
-            with open(prompt, "r", encoding="utf-8") as yaml_file:
-                prompt_template_parsed = yaml.safe_load(yaml_file.read())
-            if not isinstance(prompt_template_parsed, dict):
-                raise ValueError("The prompt loaded is not a prompt YAML file.")
-            name = prompt_template_parsed["name"]
-            prompt_text = prompt_template_parsed["prompt_text"]
+            name, prompt_text = self._load_from_file(prompt)

         # Otherwise it's a on-the-fly prompt text
         else:
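With the branches extracted into helpers, __init__ is now a plain dispatcher over the four accepted input forms. An illustrative summary (the Hub call assumes network access or a warm cache, the YAML call an existing file):

from haystack.nodes.prompt import PromptTemplate

PromptTemplate("question-answering")          # key in LEGACY_DEFAULT_TEMPLATES -> _load_from_legacy_template (warns)
PromptTemplate("deepset/question-answering")  # matches [-a-zA-Z0-9_/]+ -> _load_from_prompthub (cache first)
PromptTemplate("prompts/qa.yml")              # "." fails the name regex; existing file -> _load_from_file
PromptTemplate("Answer: {query}")             # anything else is used directly as the prompt text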
@@ -395,27 +417,43 @@ class PromptTemplate(BasePromptTemplate, ABC):
             output_parser_params = output_parser.get("params", {})
             self.output_parser = BaseComponent._create_instance(output_parser_type, output_parser_params)

-    @property
-    def output_variable(self) -> Optional[str]:
-        return self.output_parser.output_variable if self.output_parser else None
+    def _load_from_legacy_template(self, name: str) -> Tuple[str, Any]:
+        warnings.warn(
+            f"You're using a legacy prompt template '{name}', "
+            "we strongly suggest you use prompts from the official Haystack PromptHub: "
+            "https://prompthub.deepset.ai/"
+        )
+        prompt_text = LEGACY_DEFAULT_TEMPLATES[name]["prompt"]
+        output_parser = LEGACY_DEFAULT_TEMPLATES[name].get("output_parser")
+        return prompt_text, output_parser

-    @tenacity.retry(
-        reraise=True,
-        retry=tenacity.retry_if_exception_type((HTTPError, RequestException, JSONDecodeError)),
-        wait=tenacity.wait_exponential(multiplier=PROMPTHUB_BACKOFF),
-        stop=tenacity.stop_after_attempt(PROMPTHUB_MAX_RETRIES),
-    )
-    def _fetch_from_prompthub(self, name) -> str:
+    def _load_from_prompthub(self, name: str) -> str:
         """
         Looks for the given prompt in the PromptHub if the prompt is not in the local cache.
         """
-        try:
-            prompt_data: prompthub.Prompt = prompthub.fetch(name, timeout=PROMPTHUB_TIMEOUT)
-        except HTTPError as http_error:
-            if http_error.response.status_code != 404:
-                raise http_error
-            raise PromptNotFoundError(f"Prompt template named '{name}' not available in the Prompt Hub.")
-        return prompt_data.text
+        prompt_path = Path(PROMPTHUB_CACHE_PATH) / f"{name}.yml"
+        if Path(prompt_path).exists():
+            return self._load_from_file(prompt_path)[1]
+
+        data = fetch_from_prompthub(name)
+        if os.environ.get("PROMPTHUB_CACHE_ENABLED", "true").lower() not in ("0", "false", "f"):
+            cache_prompt(data)
+
+        return data.text
+
+    def _load_from_file(self, path: Union[Path, str]) -> Tuple[str, str]:
+        with open(path, "r", encoding="utf-8") as yaml_file:
+            prompt_template_parsed = yaml.safe_load(yaml_file.read())
+        if not isinstance(prompt_template_parsed, dict):
+            raise ValueError("The prompt loaded is not a prompt YAML file.")
+        name = prompt_template_parsed["name"]
+        prompt_text = prompt_template_parsed["text"]
+        return name, prompt_text
+
+    @property
+    def output_variable(self) -> Optional[str]:
+        return self.output_parser.output_variable if self.output_parser else None

     def prepare(self, *args, **kwargs) -> Dict[str, Any]:
         """
pyproject.toml

@@ -54,6 +54,7 @@ dependencies = [
     "scikit-learn>=1.0.0",   # TF-IDF, SklearnQueryClassifier and metrics
     "generalimport",         # Optional imports
     "prompthub-py==3.0.1",
+    "platformdirs",

     # Utils
     "tqdm",  # progress bars in model download and training scripts
test/agents/test_conversational.py

@@ -8,7 +8,7 @@ from haystack.nodes import PromptNode

 @pytest.mark.unit
 def test_init():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
         mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
         prompt_node = PromptNode()
         agent = ConversationalAgent(prompt_node)

@@ -25,7 +25,7 @@ def test_init():

 @pytest.mark.unit
 def test_init_with_summary_memory():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
         mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
         prompt_node = PromptNode(default_prompt_template="this is a test")
         # Test with summary memory

@@ -35,7 +35,7 @@ def test_init_with_summary_memory():

 @pytest.mark.unit
 def test_init_with_no_memory():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
         mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
         prompt_node = PromptNode()
         # Test with no memory

@@ -45,7 +45,7 @@ def test_init_with_no_memory():

 @pytest.mark.unit
 def test_run():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
         mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
         prompt_node = PromptNode()
         agent = ConversationalAgent(prompt_node)
test/conftest.py

@@ -78,6 +78,9 @@ META_FIELDS = [
 # Disable telemetry reports when running tests
 posthog.disabled = True

+# Disable caching from prompthub to avoid polluting the local environment.
+os.environ["PROMPTHUB_CACHE_ENABLED"] = "false"
+
 # Cache requests (e.g. huggingface model) to circumvent load protection
 # See https://requests-cache.readthedocs.io/en/stable/user_guide/filtering.html
 requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE})
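Because the suite now exports PROMPTHUB_CACHE_ENABLED=false globally, any test that exercises the cache has to opt back in; the enable_prompthub_cache fixture added further down does exactly that. A minimal standalone equivalent (hypothetical test, not part of the diff):

from haystack.nodes.prompt import prompt_template

def test_with_cache_enabled(monkeypatch, tmp_path):
    monkeypatch.setenv("PROMPTHUB_CACHE_ENABLED", "true")                             # override the global "false"
    monkeypatch.setattr(prompt_template, "PROMPTHUB_CACHE_PATH", tmp_path / "cache")  # write to a throwaway dir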
test/prompt/test_prompt_node.py

@@ -4,6 +4,7 @@ from typing import Optional, Union, List, Dict, Any, Tuple
 from unittest.mock import patch, Mock, MagicMock

 import pytest
+from prompthub import Prompt
 from transformers import GenerationConfig, TextStreamer

 from haystack import Document, Pipeline, BaseComponent, MultiLabel
@@ -14,9 +15,14 @@ from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, Defau

 @pytest.fixture
 def mock_prompthub():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = (
-            lambda name: "This is a test prompt. Use your knowledge to answer this question: {question}"
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
+        mock_prompthub.return_value = Prompt(
+            name="deepset/test",
+            tags=["test"],
+            meta={"author": "test"},
+            version="v0.0.0",
+            text="This is a test prompt. Use your knowledge to answer this question: {question}",
+            description="test prompt",
         )
         yield mock_prompthub
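The fixture now has to return a full prompthub.Prompt rather than a bare string, because _load_from_prompthub reads .text from the fetched object (and cache_prompt would need .name). A hypothetical test consuming it:

@pytest.mark.unit
def test_node_resolves_hub_prompt(mock_prompthub):
    node = PromptNode()
    template = node.get_prompt_template("deepset/test")  # hits the patched fetch_from_prompthub
    assert "This is a test prompt" in template.prompt_text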
@@ -132,10 +138,10 @@ def test_get_prompt_template_local_file(mock_model, tmp_path, mock_prompthub):
     ptf.write(
         """
         name: my_prompts/question-answering
-        prompt_text: |
-            Given the context please answer the question. Context: {join(documents)};
-            Question: {query};
-            Answer:
+        text: |
+            Given the context please answer the question. Context: {join(documents)};
+            Question: {query};
+            Answer:
         description: A simple prompt to answer a question given a set of documents
         tags:
         - question-answering
test/prompt/test_prompt_template.py

@@ -1,5 +1,6 @@
 from typing import Set, Type, List
 import textwrap
+import os
 from unittest.mock import patch, MagicMock

 import pytest

@@ -8,17 +9,35 @@ import prompthub
 from haystack.nodes.prompt import PromptTemplate
 from haystack.nodes.prompt.prompt_node import PromptNode
 from haystack.nodes.prompt.prompt_template import PromptTemplateValidationError, LEGACY_DEFAULT_TEMPLATES
+from haystack.nodes.prompt import prompt_template
 from haystack.nodes.prompt.shapers import AnswerParser
 from haystack.pipelines.base import Pipeline
 from haystack.schema import Answer, Document


+@pytest.fixture
+def enable_prompthub_cache(monkeypatch):
+    monkeypatch.setenv("PROMPTHUB_CACHE_ENABLED", True)
+
+
+@pytest.fixture
+def prompthub_cache_path(monkeypatch, tmp_path):
+    cache_path = tmp_path / "cache"
+    monkeypatch.setattr(prompt_template, "PROMPTHUB_CACHE_PATH", cache_path)
+    yield cache_path
+
+
 @pytest.fixture
 def mock_prompthub():
-    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [
-            ("deepset/test-prompt", "This is a test prompt. Use your knowledge to answer this question: {question}")
-        ]
+    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
+        mock_prompthub.return_value = prompthub.Prompt(
+            name="deepset/test-prompt",
+            tags=["test"],
+            meta={"author": "test"},
+            version="v0.0.0",
+            text="This is a test prompt. Use your knowledge to answer this question: {question}",
+            description="test prompt",
+        )
         yield mock_prompthub
@@ -29,6 +48,39 @@ def test_prompt_templates_from_hub():
         mock_prompthub.fetch.assert_called_with("deepset/question-answering", timeout=30)


+@pytest.mark.unit
+def test_prompt_templates_from_hub_prompts_are_cached(prompthub_cache_path, enable_prompthub_cache, mock_prompthub):
+    PromptTemplate("deepset/test-prompt")
+    assert (prompthub_cache_path / "deepset" / "test-prompt.yml").exists()
+
+
+@pytest.mark.unit
+def test_prompt_templates_from_hub_prompts_are_not_cached_if_disabled(prompthub_cache_path, mock_prompthub):
+    PromptTemplate("deepset/test-prompt")
+    assert not (prompthub_cache_path / "deepset" / "test-prompt.yml").exists()
+
+
+@pytest.mark.unit
+def test_prompt_templates_from_hub_cached_prompts_are_used(
+    prompthub_cache_path, enable_prompthub_cache, mock_prompthub
+):
+    test_path = prompthub_cache_path / "deepset" / "another-test.yml"
+    test_path.parent.mkdir(parents=True, exist_ok=True)
+    data = prompthub.Prompt(
+        name="deepset/another-test",
+        text="this is the prompt text",
+        description="test prompt description",
+        tags=["another-test"],
+        meta={"authors": ["vblagoje"]},
+        version="v0.1.1",
+    )
+    data.to_yaml(test_path)
+
+    template = PromptTemplate("deepset/another-test")
+    mock_prompthub.fetch.assert_not_called()
+    assert template.prompt_text == "this is the prompt text"
+
+
 @pytest.mark.unit
 def test_prompt_templates_from_legacy_set(mock_prompthub):
     p = PromptTemplate("question-answering")
@@ -45,10 +97,10 @@ def test_prompt_templates_from_file(tmp_path):
     textwrap.dedent(
         """
         name: deepset/question-answering
-        prompt_text: |
-            Given the context please answer the question. Context: {join(documents)};
-            Question: {query};
-            Answer:
+        text: |
+            Given the context please answer the question. Context: {join(documents)};
+            Question: {query};
+            Answer:
         description: A simple prompt to answer a question given a set of documents
         tags:
         - question-answering