2023-04-03 11:49:49 +02:00
from typing import Set , Type , List
2023-05-23 15:22:58 +02:00
import textwrap
from unittest . mock import patch , MagicMock
2023-04-03 11:49:49 +02:00
import pytest
2023-05-23 15:22:58 +02:00
import prompthub
2023-04-03 11:49:49 +02:00
from haystack . nodes . prompt import PromptTemplate
2023-04-26 13:56:51 +02:00
from haystack . nodes . prompt . prompt_node import PromptNode
2023-04-03 11:49:49 +02:00
from haystack . nodes . prompt . prompt_template import PromptTemplateValidationError
2023-04-26 13:56:51 +02:00
from haystack . nodes . prompt . shapers import AnswerParser
from haystack . pipelines . base import Pipeline
2023-04-03 11:49:49 +02:00
from haystack . schema import Answer , Document
@pytest.fixture
def mock_prompthub():
    """Patch ``PromptTemplate._fetch_from_prompthub`` for the duration of the requesting test.

    The patched method returns the canned tuple
    ``("deepset/test-prompt", "This is a test prompt. ... {question}")`` on its first call. The mock itself
    is yielded so the test can inspect how it was called.
    """
    # @pytest.fixture is required: a bare generator function is never advanced by pytest,
    # so the ``with patch(...)`` below would never be entered for any test requesting it.
    with patch("haystack.nodes.prompt.prompt_template.PromptTemplate._fetch_from_prompthub") as fetch_mock:
        # An iterable side_effect yields one value per call; a second call would raise StopIteration.
        fetch_mock.side_effect = [
            ("deepset/test-prompt", "This is a test prompt. Use your knowledge to answer this question: {question}")
        ]
        yield fetch_mock
@pytest.mark.unit
def test_prompt_templates_from_hub():
    """A PromptHub-style name makes PromptTemplate fetch the prompt via ``prompthub.fetch`` with timeout=30."""
    hub_patch = patch("haystack.nodes.prompt.prompt_template.prompthub")
    with hub_patch as fake_hub:
        PromptTemplate("deepset/question-answering")
    # The mock keeps its call history after the patch is undone, so the check can live outside the block.
    fake_hub.fetch.assert_called_with("deepset/question-answering", timeout=30)
@pytest.mark.unit
def test_prompt_templates_from_file(tmp_path):
    """Passing the path of a YAML prompt file loads the template's name and prompt text from that file."""
    prompt_file = tmp_path / "test-prompt.yml"
    prompt_yaml = textwrap.dedent(
        """
        name: deepset/question-answering
        prompt_text: |
            Given the context please answer the question. Context: {join(documents)};
            Question: {query};
            Answer:
        description: A simple prompt to answer a question given a set of documents
        tags:
          - question-answering
        meta:
          authors:
            - vblagoje
        version: v0.1.1
        """
    )
    with open(prompt_file, "a") as yaml_file:
        yaml_file.write(prompt_yaml)

    template = PromptTemplate(str(prompt_file.absolute()))
    assert template.name == "deepset/question-answering"
    assert "Given the context please answer the question" in template.prompt_text
@pytest.mark.unit
def test_prompt_templates_on_the_fly():
    """A literal prompt string is used directly: it is neither fetched from PromptHub nor parsed as YAML."""
    yaml_patch = patch("haystack.nodes.prompt.prompt_template.yaml")
    hub_patch = patch("haystack.nodes.prompt.prompt_template.prompthub")
    with yaml_patch as fake_yaml, hub_patch as fake_hub:
        template = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
        assert template.name == "custom-at-query-time"
        fake_hub.fetch.assert_not_called()
        fake_yaml.safe_load.assert_not_called()
@pytest.mark.unit
def test_custom_prompt_templates():
    """Placeholder names are extracted from ad-hoc template strings, and wrapping quotes are stripped."""
    placeholder_cases = [
        ("Here is some fake template with variable {foo}", {"foo"}),
        ("Here is some fake template with variable {foo} and {bar}", {"foo", "bar"}),
        ("Here is some fake template with variable {foo1} and {bar2}", {"foo1", "bar2"}),
        ("Here is some fake template with variable {foo_1} and {bar_2}", {"foo_1", "bar_2"}),
        ("Here is some fake template with variable {Foo_1} and {Bar_2}", {"Foo_1", "Bar_2"}),
    ]
    for template_text, expected_params in placeholder_cases:
        template = PromptTemplate(template_text)
        assert set(template.prompt_params) == expected_params

    # Templates written in YAML often arrive wrapped in quotes; both single- and
    # double-quote wrapping must be stripped from the stored prompt text.
    quoted_variants = (
        "'Here is some fake template with variable {baz}'",
        '"Here is some fake template with variable {baz}"',
    )
    for quoted_text in quoted_variants:
        template = PromptTemplate(quoted_text)
        assert set(template.prompt_params) == {"baz"}
        assert template.prompt_text == "Here is some fake template with variable {baz}"
@pytest.mark.unit
def test_missing_prompt_template_params():
    """``prepare`` requires every template parameter; extra keyword arguments are tolerated."""
    template = PromptTemplate("Here is some fake template with variable {foo} and {bar}")
    only_foo_given = r".*parameters \['bar', 'foo'\] to be provided but got only \['foo'\].*"
    nothing_given = r".*parameters \['bar', 'foo'\] to be provided but got none of these parameters.*"

    # All required parameters supplied: accepted.
    template.prepare(foo="foo", bar="bar")

    # One of the two required parameters missing.
    with pytest.raises(ValueError, match=only_foo_given):
        template.prepare(foo="foo")

    # Neither required parameter supplied, only an unrelated one.
    with pytest.raises(ValueError, match=nothing_given):
        template.prepare(lets="go")

    # Required parameters plus an extra one: still accepted.
    template.prepare(foo="foo", bar="bar", lets="go")
@pytest.mark.unit
def test_prompt_template_repr():
    """``repr()`` and ``str()`` of an ad-hoc template both yield the same summary string."""
    template = PromptTemplate("Here is variable {baz}")
    expected = "PromptTemplate(name=custom-at-query-time, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
    for rendered in (repr(template), str(template)):
        assert rendered == expected
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_template_deserialization(mock_prompt_model):
    """A PromptNode's custom default template, including its output parser, survives a pipeline config round trip."""
    prompt_text = "Given the context please answer the question. Context: {context}; Question: {query}; Answer:"
    node = PromptNode(default_prompt_template=PromptTemplate(prompt_text, output_parser=AnswerParser()))

    original_pipeline = Pipeline()
    original_pipeline.add_node(component=node, name="Generator", inputs=["Query"])
    restored_pipeline = Pipeline.load_from_config(original_pipeline.get_config())
    restored_node = restored_pipeline.get_node("Generator")

    assert isinstance(restored_node, PromptNode)
    restored_template = restored_node.default_prompt_template
    assert isinstance(restored_template, PromptTemplate)
    assert restored_template.prompt_text == prompt_text
    assert isinstance(restored_template.output_parser, AnswerParser)
class TestPromptTemplateSyntax:
    """Tests for the expression syntax accepted inside ``PromptTemplate`` prompt text.

    Covers:
    - extraction of prompt parameters and called helper functions;
    - filling templates with documents and queries;
    - the ``join`` and ``to_strings`` helpers;
    - rejection of disallowed or malformed expressions.
    """

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, expected_prompt_params, expected_used_functions",
        [
            ("{documents}", {"documents"}, set()),
            ("Please answer the question: {documents} Question: how?", {"documents"}, set()),
            ("Please answer the question: {documents} Question: {query}", {"documents", "query"}, set()),
            ("Please answer the question: {documents} {{Question}}: {query}", {"documents", "query"}, set()),
            (
                "Please answer the question: {join(documents)} Question: {query.replace('A', 'a')}",
                {"documents", "query"},
                {"join", "replace"},
            ),
            (
                "Please answer the question: {join(documents, ' delim ', {'{': '('})} Question: {query.replace('A', 'a')}",
                {"documents", "query"},
                {"join", "replace"},
            ),
            (
                'Please answer the question: {join(documents, " delim ", {"{": "("})} Question: {query.replace("A", "a")}',
                {"documents", "query"},
                {"join", "replace"},
            ),
            (
                "Please answer the question: {join(documents, ' delim ', {'a': {'b': 'c'}})} Question: {query.replace('A', 'a')}",
                {"documents", "query"},
                {"join", "replace"},
            ),
            (
                # Keyword form; ``documents=`` is join's keyword as exercised in test_prompt_template_syntax_fill.
                "Please answer the question: {join(documents=documents, delimiter=' delim ', str_replace={'{': '('})} Question: {query.replace('A', 'a')}",
                {"documents", "query"},
                {"join", "replace"},
            ),
        ],
    )
    def test_prompt_template_syntax_parser(
        self, prompt_text: str, expected_prompt_params: Set[str], expected_used_functions: Set[str]
    ):
        """Parsing collects the referenced prompt parameters and the helper functions/methods that are called."""
        prompt_template = PromptTemplate(prompt_text)
        assert set(prompt_template.prompt_params) == expected_prompt_params
        assert set(prompt_template._used_functions) == expected_used_functions

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, query, expected_prompts",
        [
            ("{documents}", [Document("doc1"), Document("doc2")], None, ["doc1", "doc2"]),
            (
                "context: {documents} question: how?",
                [Document("doc1"), Document("doc2")],
                None,
                ["context: doc1 question: how?", "context: doc2 question: how?"],
            ),
            (
                "context: {' '.join([d.content for d in documents])} question: how?",
                [Document("doc1"), Document("doc2")],
                None,
                ["context: doc1 doc2 question: how?"],
            ),
            (
                "context: {documents} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["context: doc1 question: how?", "context: doc2 question: how?"],
            ),
            (
                "context: {documents} {{question}}: {query}",
                [Document("doc1")],
                "how?",
                ["context: doc1 {question}: how?"],
            ),
            (
                "context: {join(documents)} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["context: doc1 doc2 question: how?"],
            ),
            (
                "Please answer the question: {join(documents, ' delim ', '[$idx] $content', {'{': '('})} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
            ),
            (
                "Please answer the question: {join(documents=documents, delimiter=' delim ', pattern='[$idx] $content', str_replace={'{': '('})} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
            ),
            (
                "Please answer the question: {' delim '.join(['['+str(idx+1)+'] '+d.content.replace('{', '(') for idx, d in enumerate(documents)])} question: {query}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
            ),
            (
                'Please answer the question: {join(documents, " delim ", "[$idx] $content", {"{": "("})} question: {query}',
                [Document("doc1"), Document("doc2")],
                "how?",
                ["Please answer the question: [1] doc1 delim [2] doc2 question: how?"],
            ),
            (
                "context: {join(documents)} question: {query.replace('how', 'what')}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["context: doc1 doc2 question: what?"],
            ),
            (
                "context: {join(documents)[:6]} question: {query.replace('how', 'what').replace('?', '!')}",
                [Document("doc1"), Document("doc2")],
                "how?",
                ["context: doc1 d question: what!"],
            ),
            ("context:", None, None, ["context:"]),
        ],
    )
    def test_prompt_template_syntax_fill(
        self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
    ):
        """Check how ``fill`` expands ``documents`` into prompts.

        Referencing ``documents`` directly yields one prompt per document. Aggregating them
        (``join`` or a comprehension) yields a single prompt.
        """
        prompt_template = PromptTemplate(prompt_text)
        prompts = list(prompt_template.fill(documents=documents, query=query))
        assert prompts == expected_prompts

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, expected_prompts",
        [
            ("{join(documents)}", [Document("doc1"), Document("doc2")], ["doc1 doc2"]),
            (
                "{join(documents, ' delim ', '[$idx] $content', {'c': 'C'})}",
                [Document("doc1"), Document("doc2")],
                ["[1] doC1 delim [2] doC2"],
            ),
            (
                "{join(documents, ' delim ', '[$id] $content', {'c': 'C'})}",
                [Document("doc1", id="123"), Document("doc2", id="456")],
                ["[123] doC1 delim [456] doC2"],
            ),
            (
                "{join(documents, ' delim ', '[$file_id] $content', {'c': 'C'})}",
                [Document("doc1", meta={"file_id": "123.txt"}), Document("doc2", meta={"file_id": "456.txt"})],
                ["[123.txt] doC1 delim [456.txt] doC2"],
            ),
        ],
    )
    def test_join(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
        """Check that ``join`` merges documents into one string.

        It applies an optional delimiter, a ``$``-placeholder pattern (``$idx``, ``$id``, meta keys
        such as ``$file_id``) and character replacements.
        """
        prompt_template = PromptTemplate(prompt_text)
        prompts = list(prompt_template.fill(documents=documents))
        assert prompts == expected_prompts

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, expected_prompts",
        [
            ("{to_strings(documents)}", [Document("doc1"), Document("doc2")], ["doc1", "doc2"]),
            (
                "{to_strings(documents, '[$idx] $content', {'c': 'C'})}",
                [Document("doc1"), Document("doc2")],
                ["[1] doC1", "[2] doC2"],
            ),
            (
                "{to_strings(documents, '[$id] $content', {'c': 'C'})}",
                [Document("doc1", id="123"), Document("doc2", id="456")],
                ["[123] doC1", "[456] doC2"],
            ),
            (
                "{to_strings(documents, '[$file_id] $content', {'c': 'C'})}",
                [Document("doc1", meta={"file_id": "123.txt"}), Document("doc2", meta={"file_id": "456.txt"})],
                ["[123.txt] doC1", "[456.txt] doC2"],
            ),
            ("{to_strings(documents, '[$file_id] $content', {'c': 'C'})}", ["doc1", "doc2"], ["doC1", "doC2"]),
            (
                "{to_strings(documents, '[$idx] $answer', {'c': 'C'})}",
                [Answer("doc1"), Answer("doc2")],
                ["[1] doC1", "[2] doC2"],
            ),
        ],
    )
    def test_to_strings(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
        """Check that ``to_strings`` renders each item into its own prompt.

        For plain strings only the character replacements are applied, not the pattern. Answers are
        addressed via ``$answer``.
        """
        prompt_template = PromptTemplate(prompt_text)
        prompts = list(prompt_template.fill(documents=documents))
        assert prompts == expected_prompts

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, exc_type, expected_exc_match",
        [
            ("{__import__('os').listdir('.')}", PromptTemplateValidationError, "Invalid function in prompt text"),
            ("{__import__('os')}", PromptTemplateValidationError, "Invalid function in prompt text"),
            (
                "{requests.get('https://haystack.deepset.ai/')}",
                PromptTemplateValidationError,
                "Invalid function in prompt text",
            ),
            ("{join(__import__('os').listdir('.'))}", PromptTemplateValidationError, "Invalid function in prompt text"),
            ("{for}", SyntaxError, "invalid syntax"),
            ("This is an invalid {variable .", SyntaxError, "f-string: expecting '}'"),
        ],
    )
    def test_prompt_template_syntax_init_raises(
        self, prompt_text: str, exc_type: Type[BaseException], expected_exc_match: str
    ):
        """Construction rejects calls to functions outside the allowed set, as well as malformed expressions."""
        with pytest.raises(exc_type, match=expected_exc_match):
            PromptTemplate(prompt_text)

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, query, exc_type, expected_exc_match",
        [("{join}", None, None, ValueError, "Expected prompt parameters")],
    )
    def test_prompt_template_syntax_fill_raises(
        self,
        prompt_text: str,
        documents: List[Document],
        query: str,
        exc_type: Type[BaseException],
        expected_exc_match: str,
    ):
        """Check that filling fails when a template parameter (here the bare name ``join``) is not supplied.

        The template is built outside ``pytest.raises``. Otherwise an error raised while constructing it
        could be mistaken for the expected error raised while filling it.
        """
        prompt_template = PromptTemplate(prompt_text)
        with pytest.raises(exc_type, match=expected_exc_match):
            next(prompt_template.fill(documents=documents, query=query))

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "prompt_text, documents, query, expected_prompts",
        [
            ("__import__('os').listdir('.')", None, None, ["__import__('os').listdir('.')"]),
            (
                "requests.get('https://haystack.deepset.ai/')",
                None,
                None,
                ["requests.get('https://haystack.deepset.ai/')"],
            ),
            ("{query}", None, print, ["<built-in function print>"]),
            ("\b\b__import__('os').listdir('.')", None, None, ["\x08\x08__import__('os').listdir('.')"]),
        ],
    )
    def test_prompt_template_syntax_fill_ignores_dangerous_input(
        self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
    ):
        """Check that dangerous-looking input is never executed.

        Text outside ``{...}`` is emitted literally, and a filled value is inserted via its string form
        rather than being called.
        """
        prompt_template = PromptTemplate(prompt_text)
        prompts = list(prompt_template.fill(documents=documents, query=query))
        assert prompts == expected_prompts