import os
import logging
from typing import Optional, Set, Type, Union, List, Dict, Any, Tuple

import pytest
import torch

from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.errors import OpenAIError
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt import PromptModelInvocationLayer
from haystack.nodes.prompt.prompt_node import PromptTemplateValidationError
from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler
from haystack.schema import Answer


def skip_test_for_invalid_key(prompt_model):
    if prompt_model.api_key is not None and prompt_model.api_key == "KEY_NOT_FOUND":
        pytest.skip("No API key found, skipping test")


class TestTokenStreamingHandler(TokenStreamingHandler):
    stream_handler_invoked = False

    def __call__(self, token_received, *args, **kwargs) -> str:
        """
        This callback method is called when a new token is received from the stream.

        :param token_received: The token received from the stream.
        :param kwargs: Additional keyword arguments passed to the underlying model.
        :return: The token to be sent to the stream.
        """
        self.stream_handler_invoked = True
        return token_received


class CustomInvocationLayer(PromptModelInvocationLayer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def invoke(self, *args, **kwargs):
        return ["fake_response"]

    def _ensure_token_limit(self, prompt: str) -> str:
        return prompt

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        return model_name_or_path == "fake_model"


@pytest.fixture
def get_api_key(request):
    if request.param == "openai":
        return os.environ.get("OPENAI_API_KEY", None)
    elif request.param == "azure":
        return os.environ.get("AZURE_OPENAI_API_KEY", None)


@pytest.mark.unit
def test_prompt_templates():
    p = PromptTemplate("t1", "Here is some fake template with variable {foo}")
    assert set(p.prompt_params) == {"foo"}

    p = PromptTemplate("t3", "Here is some fake template with variable {foo} and {bar}")
    assert set(p.prompt_params) == {"foo", "bar"}

    p = PromptTemplate("t4", "Here is some fake template with variable {foo1} and {bar2}")
    assert set(p.prompt_params) == {"foo1", "bar2"}

    p = PromptTemplate("t4", "Here is some fake template with variable {foo_1} and {bar_2}")
    assert set(p.prompt_params) == {"foo_1", "bar_2"}

    p = PromptTemplate("t4", "Here is some fake template with variable {Foo_1} and {Bar_2}")
    assert set(p.prompt_params) == {"Foo_1", "Bar_2"}

    p = PromptTemplate("t4", "'Here is some fake template with variable {baz}'")
    assert set(p.prompt_params) == {"baz"}
    # strip single quotes, happens in YAML as we need to use single quotes for the template string
    assert p.prompt_text == "Here is some fake template with variable {baz}"

    p = PromptTemplate("t4", '"Here is some fake template with variable {baz}"')
    assert set(p.prompt_params) == {"baz"}
    # strip double quotes, happens in YAML as we need to use single quotes for the template string
    assert p.prompt_text == "Here is some fake template with variable {baz}"


@pytest.mark.unit
def test_prompt_template_repr():
    p = PromptTemplate("t", "Here is variable {baz}")
    desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
    assert repr(p) == desired_repr
    assert str(p) == desired_repr


@pytest.mark.unit
def test_prompt_node_with_custom_invocation_layer():
    model = PromptModel("fake_model")
    pn = PromptNode(model_name_or_path=model)
    output = pn("Some fake invocation")

    assert output == ["fake_response"]


@pytest.mark.integration
def test_create_prompt_model():
    model = PromptModel("google/flan-t5-small")
    assert model.model_name_or_path == "google/flan-t5-small"

    model = PromptModel()
    assert model.model_name_or_path == "google/flan-t5-base"

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        model = PromptModel("text-davinci-003")

    model = PromptModel("text-davinci-003", api_key="no need to provide a real key")
    assert model.model_name_or_path == "text-davinci-003"

    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptModel("some-random-model")

    # we can also pass model kwargs to the PromptModel
    model = PromptModel("google/flan-t5-small", model_kwargs={"model_kwargs": {"torch_dtype": torch.bfloat16}})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can also pass kwargs directly, see HF Pipeline constructor
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": torch.bfloat16})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can't use device_map auto without accelerate library installed
    with pytest.raises(ImportError, match="requires Accelerate: `pip install accelerate`"):
        model = PromptModel("google/flan-t5-small", model_kwargs={"device_map": "auto"})
        assert model.model_name_or_path == "google/flan-t5-small"


def test_create_prompt_model_dtype():
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": "auto"})
    assert model.model_name_or_path == "google/flan-t5-small"

    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": "torch.bfloat16"})
    assert model.model_name_or_path == "google/flan-t5-small"


@pytest.mark.integration
def test_create_prompt_node():
    prompt_node = PromptNode()
    assert prompt_node is not None
    assert prompt_node.prompt_model is not None

    prompt_node = PromptNode("google/flan-t5-small")
    assert prompt_node is not None
    assert prompt_node.model_name_or_path == "google/flan-t5-small"
    assert prompt_node.prompt_model is not None

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        prompt_node = PromptNode("text-davinci-003")

    prompt_node = PromptNode("text-davinci-003", api_key="no need to provide a real key")
    assert prompt_node is not None
    assert prompt_node.model_name_or_path == "text-davinci-003"
    assert prompt_node.prompt_model is not None

    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptNode("some-random-model")


@pytest.mark.integration
def test_add_and_remove_template(prompt_node):
    num_default_tasks = len(prompt_node.get_prompt_template_names())
    custom_task = PromptTemplate(name="custom-task", prompt_text="Custom task: {param1}, {param2}")
    prompt_node.add_prompt_template(custom_task)
    assert len(prompt_node.get_prompt_template_names()) == num_default_tasks + 1
    assert "custom-task" in prompt_node.get_prompt_template_names()

    assert prompt_node.remove_prompt_template("custom-task") is not None
    assert "custom-task" not in prompt_node.get_prompt_template_names()


@pytest.mark.integration
def test_add_template_and_invoke(prompt_node):
    tt = PromptTemplate(
        name="sentiment-analysis-new",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: {documents}; Answer:",
    )
    prompt_node.add_prompt_template(tt)

    r = prompt_node.prompt("sentiment-analysis-new", documents=["Berlin is an amazing city."])
    assert r[0].casefold() == "positive"


@pytest.mark.integration
def test_on_the_fly_prompt(prompt_node):
    prompt_template = PromptTemplate(
        name="sentiment-analysis-temp",
        prompt_text="Please give a sentiment for this context. Answer with positive, "
        "negative or neutral. Context: {documents}; Answer:",
    )
    r = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
    assert r[0].casefold() == "positive"


@pytest.mark.integration
def test_direct_prompting(prompt_node):
    r = prompt_node("What is the capital of Germany?")
    assert r[0].casefold() == "berlin"

    r = prompt_node("What is the capital of Germany?", "What is the secret of universe?")
    assert r[0].casefold() == "berlin"
    assert len(r[1]) > 0

    r = prompt_node("Capital of Germany is Berlin", task="question-generation")
    assert len(r[0]) > 10 and "Germany" in r[0]

    r = prompt_node(["Capital of Germany is Berlin", "Capital of France is Paris"], task="question-generation")
    assert len(r) == 2


@pytest.mark.integration
def test_question_generation(prompt_node):
    r = prompt_node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
    assert len(r) == 1 and len(r[0]) > 0


@pytest.mark.integration
def test_template_selection(prompt_node):
    qa = prompt_node.set_default_prompt_template("question-answering-per-document")
    r = qa(
        ["Berlin is the capital of Germany.", "Paris is the capital of France."],
        ["What is the capital of Germany?", "What is the capital of France"],
    )

    assert r[0].answer.casefold() == "berlin" and r[1].answer.casefold() == "paris"


@pytest.mark.integration
def test_has_supported_template_names(prompt_node):
    assert len(prompt_node.get_prompt_template_names()) > 0


@pytest.mark.integration
def test_invalid_template_params(prompt_node):
    with pytest.raises(ValueError, match="Expected prompt parameters"):
        prompt_node.prompt("question-answering-per-document", {"some_crazy_key": "Berlin is the capital of Germany."})


@pytest.mark.integration
def test_wrong_template_params(prompt_node):
    with pytest.raises(ValueError, match="Expected prompt parameters"):
        # this template doesn't have an options param; multiple-choice QA does
        prompt_node.prompt("question-answering-per-document", options=["Berlin is the capital of Germany."])
2023-02-27 15:00:19 +01:00
@pytest.mark.integration
2022-12-20 11:21:26 +01:00
def test_run_invalid_template ( prompt_node ) :
with pytest . raises ( ValueError , match = " invalid-task not supported " ) :
prompt_node . prompt ( " invalid-task " , { } )
2023-02-27 15:00:19 +01:00


@pytest.mark.integration
def test_invalid_prompting(prompt_node):
    with pytest.raises(ValueError, match="Hey there, what is the best city in the"):
        prompt_node.prompt(["Hey there, what is the best city in the world?", "Hey, answer me!"])


@pytest.mark.integration
def test_prompt_at_query_time(prompt_node: PromptNode):
    results = prompt_node.prompt("Hey there, what is the best city in the world?")
    assert len(results) == 1
    assert isinstance(results[0], str)


@pytest.mark.integration
def test_invalid_state_ops(prompt_node):
    with pytest.raises(ValueError, match="Prompt template no_such_task_exists"):
        prompt_node.remove_prompt_template("no_such_task_exists")

    # remove default task
    prompt_node.remove_prompt_template("question-answering-per-document")


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_open_ai_prompt_with_params(prompt_model):
    skip_test_for_invalid_key(prompt_model)
    pn = PromptNode(prompt_model)
    optional_davinci_params = {"temperature": 0.5, "max_tokens": 10, "top_p": 1, "frequency_penalty": 0.5}
    r = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params)
    assert len(r) == 1 and len(r[0]) > 0


@pytest.mark.integration
def test_open_ai_prompt_with_default_params(azure_conf):
    if not azure_conf:
        pytest.skip("No Azure API key found, skipping test")

    model_kwargs = {"temperature": 0.5, "max_tokens": 2, "top_p": 1, "frequency_penalty": 0.5}
    model_kwargs.update(azure_conf)
    pn = PromptNode(model_name_or_path="text-davinci-003", api_key=azure_conf["api_key"], model_kwargs=model_kwargs)
    result = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."])
    assert len(result) == 1 and len(result[0]) > 0


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_open_ai_warn_if_max_tokens_is_too_short(prompt_model, caplog):
    skip_test_for_invalid_key(prompt_model)
    pn = PromptNode(prompt_model)
    optional_davinci_params = {"temperature": 0.5, "max_tokens": 2, "top_p": 1, "frequency_penalty": 0.5}
    with caplog.at_level(logging.WARNING):
        _ = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params)
        assert "Increase the max_tokens parameter to allow for longer completions." in caplog.text


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_stop_words(prompt_model):
    skip_test_for_invalid_key(prompt_model)

    # test stop words for both HF and OpenAI
    # set stop words in PromptNode
    node = PromptNode(prompt_model, stop_words=["capital", "Germany"])

    # with default prompt template and stop words set in PN
    r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
    assert r[0] == "What is the" or r[0] == "What city is the"

    # with default prompt template and stop words set in kwargs (overrides PN stop words)
    r = node.prompt("question-generation", documents=["Berlin is the capital of Germany."], stop_words=None)
    assert "capital" in r[0] or "Germany" in r[0]

    # simple prompting
    r = node("Given the context please generate a question. Context: Berlin is the capital of Germany.; Question:")
    assert len(r[0]) > 0
    assert "capital" not in r[0]
    assert "Germany" not in r[0]

    # simple prompting with stop words set in kwargs (overrides PN stop words)
    r = node(
        "Given the context please generate a question. Context: Berlin is the capital of Germany.; Question:",
        stop_words=None,
    )
    assert "capital" in r[0] or "Germany" in r[0]

    tt = PromptTemplate(
        name="question-generation-copy",
        prompt_text="Given the context please generate a question. Context: {documents}; Question:",
    )
    # with custom prompt template
    r = node.prompt(tt, documents=["Berlin is the capital of Germany."])
    assert r[0] == "What is the" or r[0] == "What city is the"

    # with custom prompt template and stop words set in kwargs (overrides PN stop words)
    r = node.prompt(tt, documents=["Berlin is the capital of Germany."], stop_words=None)
    assert "capital" in r[0] or "Germany" in r[0]


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_streaming_prompt_node_with_params(prompt_model):
    skip_test_for_invalid_key(prompt_model)

    # test streaming of calls to OpenAI by passing a stream handler to the prompt method
    ttsh = TestTokenStreamingHandler()
    node = PromptNode(prompt_model)
    response = node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=ttsh)

    assert len(response[0]) > 0, "Response should not be empty"
    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"


@pytest.mark.integration
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key.",
)
def test_streaming_prompt_node():
    ttsh = TestTokenStreamingHandler()

    # test streaming of all calls to OpenAI by registering a stream handler as a model kwarg
    node = PromptNode(
        "text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), model_kwargs={"stream_handler": ttsh}
    )
    response = node("What are some of the best cities in the world to live?")

    assert len(response[0]) > 0, "Response should not be empty"
    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"


def test_prompt_node_with_text_generation_model():
    # test simple prompting with a text-generation model
    # by default, we force the model not to return the prompt text,
    # so text-generation models can be used with PromptNode
    # just like text2text-generation models
    node = PromptNode("bigscience/bigscience-small-testing")
    r = node("Hello big science!")
    assert len(r[0]) > 0

    # test prompting with parameter to return prompt text as well
    # users can use this param to get the prompt text and the generated text
    r = node("Hello big science!", return_full_text=True)
    assert len(r[0]) > 0 and r[0].startswith("Hello big science!")


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_simple_pipeline(prompt_model):
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])

    assert "positive" in result["out"][0].casefold()


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_complex_pipeline(prompt_model):
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query")
    node2 = PromptNode(prompt_model, default_prompt_template="question-answering-per-document")

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"])
    result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])

    assert "berlin" in result["answers"][0].answer.casefold()


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_simple_pipeline_with_topk(prompt_model):
    skip_test_for_invalid_key(prompt_model)

    node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="query", top_k=2)

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])

    assert len(result["query"]) == 2


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_pipeline_with_standard_qa(prompt_model):
    skip_test_for_invalid_key(prompt_model)
    node = PromptNode(prompt_model, default_prompt_template="question-answering", top_k=1)

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    result = pipe.run(
        query="Who lives in Berlin?",  # this being a string instead of a list is what is being tested
        documents=[
            Document("My name is Carla and I live in Berlin", id="1"),
            Document("My name is Christelle and I live in Paris", id="2"),
        ],
    )

    assert len(result["answers"]) == 1
    assert "carla" in result["answers"][0].answer.casefold()

    assert result["answers"][0].document_ids == ["1", "2"]
    assert (
        result["answers"][0].meta["prompt"]
        == "Given the context please answer the question. Context: My name is Carla and I live in Berlin My name is Christelle and I live in Paris; "
        "Question: Who lives in Berlin?; Answer:"
    )


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_qa_with_references(prompt_model):
    skip_test_for_invalid_key(prompt_model)
    node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    result = pipe.run(
        query="Who lives in Berlin?",  # this being a string instead of a list is what is being tested
        documents=[
            Document("My name is Carla and I live in Berlin", id="1"),
            Document("My name is Christelle and I live in Paris", id="2"),
        ],
    )

    assert len(result["answers"]) == 1
    assert "carla, as stated in document[1]" in result["answers"][0].answer.casefold()
    assert result["answers"][0].document_ids == ["1"]
    assert (
        result["answers"][0].meta["prompt"]
        == "Create a concise and informative answer (no more than 50 words) for a given question based solely on the given documents. "
        "You must only use information from the given documents. Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
        "If multiple documents contain the answer, cite those documents like ‘as stated in Document[number], Document[number], etc.’. If the documents do not contain the answer to the question, "
        "say that ‘answering is not possible given the available information.’\n\nDocument[1]: My name is Carla and I live in Berlin\n\nDocument[2]: My name is Christelle and I live in Paris \n"
        " Question: Who lives in Berlin?; Answer:"
    )


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_prompt_text_at_query_time(prompt_model):
    skip_test_for_invalid_key(prompt_model)
    node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    result = pipe.run(
        query="Who lives in Berlin?",  # this being a string instead of a list is what is being tested
        documents=[
            Document("My name is Carla and I live in Berlin", id="1"),
            Document("My name is Christelle and I live in Paris", id="2"),
        ],
        params={
            "prompt_template": "Create a concise and informative answer (no more than 50 words) for a given question based solely on the given documents. Cite the documents using Document[number] notation.\n\n{join(documents, delimiter=new_line+new_line, pattern='Document[$idx]: $content')}\n\nQuestion: {query}\n\nAnswer:"
        },
    )

    assert len(result["answers"]) == 1
    assert "carla" in result["answers"][0].answer.casefold()
    assert result["answers"][0].document_ids == ["1"]
    assert (
        result["answers"][0].meta["prompt"]
        == "Create a concise and informative answer (no more than 50 words) for a given question based solely on the given documents. Cite the documents using Document[number] notation.\n\n"
        "Document[1]: My name is Carla and I live in Berlin\n\nDocument[2]: My name is Christelle and I live in Paris\n\n"
        "Question: Who lives in Berlin?\n\nAnswer:"
    )


@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_pipeline_with_prompt_template_at_query_time(prompt_model):
    skip_test_for_invalid_key(prompt_model)
    node = PromptNode(prompt_model, default_prompt_template="question-answering-with-references", top_k=1)

    prompt_template_yaml = """
            name: "question-answering-with-references-custom"
            prompt_text: 'Create a concise and informative answer (no more than 50 words) for
                a given question based solely on the given documents. Cite the documents using Doc[number] notation.
                {join(documents, delimiter=new_line+new_line, pattern=''Doc[$idx]: $content'')}
                Question: {query}
                Answer: '
            output_parser:
                type: AnswerParser
                params:
                    reference_pattern: Doc\\[([^\\]]+)\\]
        """

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    result = pipe.run(
        query="Who lives in Berlin?",  # this being a string instead of a list is what is being tested
        documents=[
            Document("My name is Carla and I live in Berlin", id="doc-1"),
            Document("My name is Christelle and I live in Paris", id="doc-2"),
        ],
        params={"prompt_template": prompt_template_yaml},
    )

    assert len(result["answers"]) == 1
    assert "carla" in result["answers"][0].answer.casefold()
    assert result["answers"][0].document_ids == ["doc-1"]
    assert (
        result["answers"][0].meta["prompt"]
        == "Create a concise and informative answer (no more than 50 words) for a given question based solely on the given documents. Cite the documents using Doc[number] notation.\n\n"
        "Doc[1]: My name is Carla and I live in Berlin\n\nDoc[2]: My name is Christelle and I live in Paris\n\n"
        "Question: Who lives in Berlin?\n\nAnswer:"
    )


@pytest.mark.integration
def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path):
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: template_with_nested_shaper
              type: PromptTemplate
              params:
                name: custom-template-with-nested-shaper
                prompt_text: "Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer: "
                output_parser:
                  type: AnswerParser
            - name: p1
              params:
                model_name_or_path: google/flan-t5-small
                default_prompt_template: template_with_nested_shaper
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
    result = pipeline.run(query="What is an amazing city?", documents=[Document("Berlin is an amazing city.")])
    answer = result["answers"][0].answer
    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in answer.casefold())
    assert (
        result["answers"][0].meta["prompt"]
        == "Given the context please answer the question. Context: Berlin is an amazing city.; Question: What is an amazing city?; Answer: "
    )


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf"], indirect=True)
def test_prompt_node_no_debug(prompt_model):
    """Pipeline with PromptNode should not generate debug info if debug is false."""

    node = PromptNode(prompt_model, default_prompt_template="question-generation", top_k=2)
    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])

    # debug explicitly False
    result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")], debug=False)
    assert result.get("_debug", "No debug info") == "No debug info"

    # debug None
    result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")], debug=None)
    assert result.get("_debug", "No debug info") == "No debug info"

    # debug True
    result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")], debug=True)
    assert (
        result["_debug"]["prompt_node"]["runtime"]["prompts_used"][0]
        == "Given the context please generate a question. Context: Berlin is the capital of Germany; Question:"
    )


@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_complex_pipeline_with_qa(prompt_model):
    """Test the PromptNode where the `query` is a string instead of the list the PromptNode would expect,
    because in a question-answering pipeline the retrievers need `query` as a string, so the PromptNode
    needs to be able to handle the `query` being a string instead of a list."""
    skip_test_for_invalid_key(prompt_model)

    prompt_template = PromptTemplate(
        name="question-answering-new",
        prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
    )
    node = PromptNode(prompt_model, default_prompt_template=prompt_template)

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    result = pipe.run(
        query="Who lives in Berlin?",  # this being a string instead of a list is what is being tested
        documents=[
            Document("My name is Carla and I live in Berlin"),
            Document("My name is Christelle and I live in Paris"),
        ],
        debug=True,  # so we can verify that the constructed prompt is returned in debug
    )

    assert len(result["results"]) == 2
    assert "carla" in result["results"][0].casefold()

    # also verify that the PromptNode has included its constructed prompt LLM model input in the returned debug
    assert (
        result["_debug"]["prompt_node"]["runtime"]["prompts_used"][0]
        == "Given the context please answer the question. Context: My name is Carla and I live in Berlin; "
        "Question: Who lives in Berlin?; Answer:"
    )


@pytest.mark.integration
def test_complex_pipeline_with_shared_model():
    model = PromptModel()
    node = PromptNode(model_name_or_path=model, default_prompt_template="question-generation", output_variable="query")
    node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering-per-document")

    pipe = Pipeline()
    pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
    pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"])
    result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])

    assert result["answers"][0].answer == "Berlin"


@pytest.mark.integration
def test_simple_pipeline_yaml(tmp_path):
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: p1
              params:
                default_prompt_template: sentiment-analysis
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    assert result["results"][0] == "positive"


@pytest.mark.integration
def test_simple_pipeline_yaml_with_default_params(tmp_path):
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: p1
              type: PromptNode
              params:
                default_prompt_template: sentiment-analysis
                model_kwargs:
                  torch_dtype: torch.bfloat16
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
    assert pipeline.graph.nodes["p1"]["component"].prompt_model.model_kwargs == {"torch_dtype": "torch.bfloat16"}

    result = pipeline.run(query=None, documents=[Document("Berlin is an amazing city.")])
    assert result["results"][0] == "positive"


@pytest.mark.integration
def test_complex_pipeline_yaml(tmp_path):
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: p1
              params:
                default_prompt_template: question-generation
                output_variable: query
              type: PromptNode
            - name: p2
              params:
                default_prompt_template: question-answering-per-document
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    response = result["answers"][0].answer
    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
    assert len(result["invocation_context"]) > 0
    assert len(result["query"]) > 0
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.integration
def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: pmodel
              type: PromptModel
            - name: p1
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-generation
                output_variable: query
              type: PromptNode
            - name: p2
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-answering-per-document
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    response = result["answers"][0].answer
    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
    assert len(result["invocation_context"]) > 0
    assert len(result["query"]) > 0
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.integration
def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: pmodel
              type: PromptModel
              params:
                model_name_or_path: google/flan-t5-small
                model_kwargs:
                  torch_dtype: auto
            - name: question_generation_template
              type: PromptTemplate
              params:
                name: question-generation-new
                prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
            - name: p1
              params:
                model_name_or_path: pmodel
                default_prompt_template: question_generation_template
                output_variable: query
              type: PromptNode
            - name: p2
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-answering-per-document
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    response = result["answers"][0].answer
    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
    assert len(result["invocation_context"]) > 0
    assert len(result["query"]) > 0
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.integration
def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
    # test that we can stick some random node in between prompt nodes and that everything still works
    # most specifically, we want to ensure that invocation_context is still populated correctly and propagated
    class InBetweenNode(BaseComponent):
        outgoing_edges = 1

        def run(
            self,
            query: Optional[str] = None,
            file_paths: Optional[List[str]] = None,
            labels: Optional[MultiLabel] = None,
            documents: Optional[List[Document]] = None,
            meta: Optional[dict] = None,
        ) -> Tuple[Dict, str]:
            return {}, "output_1"

        def run_batch(
            self,
            queries: Optional[Union[str, List[str]]] = None,
            file_paths: Optional[List[str]] = None,
            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
            params: Optional[dict] = None,
            debug: Optional[bool] = None,
        ):
            return {}, "output_1"

    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: in_between
              type: InBetweenNode
            - name: pmodel
              type: PromptModel
              params:
                model_name_or_path: google/flan-t5-small
                model_kwargs:
                  torch_dtype: torch.bfloat16
            - name: question_generation_template
              type: PromptTemplate
              params:
                name: question-generation-new
                prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
            - name: p1
              params:
                model_name_or_path: pmodel
                default_prompt_template: question_generation_template
                output_variable: query
              type: PromptNode
            - name: p2
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-answering-per-document
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: in_between
                inputs:
                - p1
              - name: p2
                inputs:
                - in_between
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
    response = result["answers"][0].answer
    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
    assert len(result["invocation_context"]) > 0
    assert len(result["query"]) > 0
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.parametrize("haystack_openai_config", ["openai", "azure"], indirect=True)
def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
    if not haystack_openai_config:
        pytest.skip("No API key found, skipping test")

    if "azure_base_url" in haystack_openai_config:
        # don't change this indentation, it's important for the yaml to be valid
        azure_conf_yaml_snippet = f"""
                azure_base_url: {haystack_openai_config['azure_base_url']}
                azure_deployment_name: {haystack_openai_config['azure_deployment_name']}
        """
    else:
        azure_conf_yaml_snippet = ""
    with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: pmodel
              type: PromptModel
              params:
                model_name_or_path: google/flan-t5-small
                model_kwargs:
                  torch_dtype: torch.bfloat16
            - name: pmodel_openai
              type: PromptModel
              params:
                model_name_or_path: text-davinci-003
                model_kwargs:
                  temperature: 0.9
                  max_tokens: 64
                {azure_conf_yaml_snippet}
                api_key: {haystack_openai_config["api_key"]}
            - name: question_generation_template
              type: PromptTemplate
              params:
                name: question-generation-new
                prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
            - name: p1
              params:
                model_name_or_path: pmodel_openai
                default_prompt_template: question_generation_template
                output_variable: query
              type: PromptNode
            - name: p2
              params:
                model_name_or_path: pmodel
                default_prompt_template: question-answering-per-document
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml")
    result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")])
    response = result["answers"][0].answer
    assert any(word for word in ["berlin", "germany", "population", "city", "amazing"] if word in response.casefold())
    assert len(result["invocation_context"]) > 0
    assert len(result["query"]) > 0
    assert "query" in result["invocation_context"] and len(result["invocation_context"]["query"]) > 0


@pytest.mark.integration
def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):
    # p2 and p3 are essentially the same PromptNode component, make sure we can use them both as is in the pipeline
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
        tmp_file.write(
            f"""
            version: ignore
            components:
            - name: p1
              params:
                default_prompt_template: question-generation
              type: PromptNode
            - name: p2
              params:
                default_prompt_template: question-answering-per-document
              type: PromptNode
            - name: p3
              params:
                default_prompt_template: question-answering-per-document
              type: PromptNode
            pipelines:
            - name: query
              nodes:
              - name: p1
                inputs:
                - Query
              - name: p2
                inputs:
                - p1
              - name: p3
                inputs:
                - p2
        """
        )
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
    assert pipeline is not None


class TestTokenLimit:
    @pytest.mark.integration
    def test_hf_token_limit_warning(self, prompt_node, caplog):
        prompt_template = PromptTemplate(
            name="too-long-temp", prompt_text="Repeating text " * 200 + "Docs: {documents}; Answer:"
        )
        with caplog.at_level(logging.WARNING):
            _ = prompt_node.prompt(prompt_template, documents=["Berlin is an amazing city."])
            assert "The prompt has been truncated from 812 tokens to 412 tokens" in caplog.text
            assert "and answer length (100 tokens) fit within the max token limit (512 tokens)." in caplog.text

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    def test_openai_token_limit_warning(self, caplog):
        tt = PromptTemplate(name="too-long-temp", prompt_text="Repeating text " * 200 + "Docs: {documents}; Answer:")
        prompt_node = PromptNode("text-ada-001", max_length=2000, api_key=os.environ.get("OPENAI_API_KEY", ""))
        with caplog.at_level(logging.WARNING):
            _ = prompt_node.prompt(tt, documents=["Berlin is an amazing city."])
            assert "The prompt has been truncated from" in caplog.text
            assert "and answer length (2000 tokens) fit within the max token limit (2049 tokens)." in caplog.text


class TestRunBatch:
    @pytest.mark.integration
    @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
    def test_simple_pipeline_batch_no_query_single_doc_list(self, prompt_model):
        skip_test_for_invalid_key(prompt_model)

        node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")

        pipe = Pipeline()
        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
        result = pipe.run_batch(
            queries=None, documents=[Document("Berlin is an amazing city."), Document("I am not feeling well.")]
        )
        assert isinstance(result["results"], list)
        assert isinstance(result["results"][0], list)
        assert isinstance(result["results"][0][0], str)
        assert "positive" in result["results"][0][0].casefold()
        assert "negative" in result["results"][1][0].casefold()

    @pytest.mark.integration
    @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
    def test_simple_pipeline_batch_no_query_multiple_doc_list(self, prompt_model):
        skip_test_for_invalid_key(prompt_model)

        node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")

        pipe = Pipeline()
        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
        result = pipe.run_batch(
            queries=None,
            documents=[
                [Document("Berlin is an amazing city."), Document("Paris is an amazing city.")],
                [Document("I am not feeling well.")],
            ],
        )
        assert isinstance(result["out"], list)
        assert isinstance(result["out"][0], list)
        assert isinstance(result["out"][0][0], str)
        assert all("positive" in x.casefold() for x in result["out"][0])
        assert "negative" in result["out"][1][0].casefold()

    @pytest.mark.integration
    @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
    def test_simple_pipeline_batch_query_multiple_doc_list(self, prompt_model):
        skip_test_for_invalid_key(prompt_model)

        prompt_template = PromptTemplate(
            name="question-answering-new",
            prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
        )
        node = PromptNode(prompt_model, default_prompt_template=prompt_template)

        pipe = Pipeline()
        pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
        result = pipe.run_batch(
            queries=["Who lives in Berlin?"],
            documents=[
                [Document("My name is Carla and I live in Berlin"), Document("My name is James and I live in London")],
                [Document("My name is Christelle and I live in Paris")],
            ],
            debug=True,
        )
        assert isinstance(result["results"], list)
        assert isinstance(result["results"][0], list)
        assert isinstance(result["results"][0][0], str)


@pytest.mark.unit
def test_HFLocalInvocationLayer_supports():
    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
class TestPromptTemplateSyntax :
@pytest.mark.unit
@pytest.mark.parametrize (
" prompt_text, expected_prompt_params, expected_used_functions " ,
[
( " {documents} " , { " documents " } , set ( ) ) ,
( " Please answer the question: {documents} Question: how? " , { " documents " } , set ( ) ) ,
( " Please answer the question: {documents} Question: {query} " , { " documents " , " query " } , set ( ) ) ,
( " Please answer the question: {documents} {{ Question}}: {query} " , { " documents " , " query " } , set ( ) ) ,
(
" Please answer the question: { join(documents)} Question: { query.replace( ' A ' , ' a ' )} " ,
{ " documents " , " query " } ,
{ " join " , " replace " } ,
) ,
(
" Please answer the question: { join(documents, ' delim ' , { ' { ' : ' ( ' })} Question: { query.replace( ' A ' , ' a ' )} " ,
{ " documents " , " query " } ,
{ " join " , " replace " } ,
) ,
(
' Please answer the question: { join(documents, " delim " , { " { " : " ( " })} Question: { query.replace( " A " , " a " )} ' ,
{ " documents " , " query " } ,
{ " join " , " replace " } ,
) ,
(
" Please answer the question: { join(documents, ' delim ' , { ' a ' : { ' b ' : ' c ' }})} Question: { query.replace( ' A ' , ' a ' )} " ,
{ " documents " , " query " } ,
{ " join " , " replace " } ,
) ,
(
" Please answer the question: { join(document=documents, delimiter= ' delim ' , str_replace= { ' { ' : ' ( ' })} Question: { query.replace( ' A ' , ' a ' )} " ,
{ " documents " , " query " } ,
{ " join " , " replace " } ,
) ,
] ,
)
def test_prompt_template_syntax_parser (
self , prompt_text : str , expected_prompt_params : Set [ str ] , expected_used_functions : Set [ str ]
) :
prompt_template = PromptTemplate ( name = " test " , prompt_text = prompt_text )
assert set ( prompt_template . prompt_params ) == expected_prompt_params
assert set ( prompt_template . _used_functions ) == expected_used_functions
@pytest.mark.unit
@pytest.mark.parametrize (
" prompt_text, documents, query, expected_prompts " ,
[
( " {documents} " , [ Document ( " doc1 " ) , Document ( " doc2 " ) ] , None , [ " doc1 " , " doc2 " ] ) ,
(
" context: {documents} question: how? " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
None ,
[ " context: doc1 question: how? " , " context: doc2 question: how? " ] ,
) ,
(
" context: { ' ' .join([d.content for d in documents])} question: how? " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
None ,
[ " context: doc1 doc2 question: how? " ] ,
) ,
(
" context: {documents} question: {query} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
" how? " ,
[ " context: doc1 question: how? " , " context: doc2 question: how? " ] ,
) ,
(
" context: {documents} {{ question}}: {query} " ,
[ Document ( " doc1 " ) ] ,
" how? " ,
[ " context: doc1 {question} : how? " ] ,
) ,
(
" context: { join(documents)} question: {query} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
" how? " ,
[ " context: doc1 doc2 question: how? " ] ,
) ,
(
" Please answer the question: { join(documents, ' delim ' , ' [$idx] $content ' , { ' { ' : ' ( ' })} question: {query} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
" how? " ,
[ " Please answer the question: [1] doc1 delim [2] doc2 question: how? " ] ,
) ,
(
" Please answer the question: { join(documents=documents, delimiter= ' delim ' , pattern= ' [$idx] $content ' , str_replace= { ' { ' : ' ( ' })} question: {query} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
" how? " ,
[ " Please answer the question: [1] doc1 delim [2] doc2 question: how? " ] ,
) ,
(
" Please answer the question: { ' delim ' .join([ ' [ ' +str(idx+1)+ ' ] ' +d.content.replace( ' { ' , ' ( ' ) for idx, d in enumerate(documents)])} question: {query} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
" how? " ,
[ " Please answer the question: [1] doc1 delim [2] doc2 question: how? " ] ,
) ,
(
' Please answer the question: { join(documents, " delim " , " [$idx] $content " , { " { " : " ( " })} question: {query} ' ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
" how? " ,
[ " Please answer the question: [1] doc1 delim [2] doc2 question: how? " ] ,
) ,
(
" context: { join(documents)} question: { query.replace( ' how ' , ' what ' )} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
" how? " ,
[ " context: doc1 doc2 question: what? " ] ,
) ,
(
" context: { join(documents)[:6]} question: { query.replace( ' how ' , ' what ' ).replace( ' ? ' , ' ! ' )} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
" how? " ,
[ " context: doc1 d question: what! " ] ,
) ,
( " context " , None , None , [ " context " ] ) ,
] ,
)
def test_prompt_template_syntax_fill (
self , prompt_text : str , documents : List [ Document ] , query : str , expected_prompts : List [ str ]
) :
prompt_template = PromptTemplate ( name = " test " , prompt_text = prompt_text )
prompts = [ prompt for prompt in prompt_template . fill ( documents = documents , query = query ) ]
assert prompts == expected_prompts
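
# `join` merges all documents into one string. Judging by the cases below, its arguments
# are (documents, delimiter, pattern, str_replace): `pattern` can reference `$idx`
# (1-based position), `$id`, `$content`, or a meta field such as `$file_id`, and
# `str_replace` maps substrings of the content to replacements. A minimal usage sketch
# (illustrative only, not executed here):
#   template = PromptTemplate(name="example", prompt_text="Summarize: {join(documents, ', ')}")
#   # filling with Document("doc1") and Document("doc2") would yield "Summarize: doc1, doc2"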
@pytest.mark.unit
@pytest.mark.parametrize (
" prompt_text, documents, expected_prompts " ,
[
( " { join(documents)} " , [ Document ( " doc1 " ) , Document ( " doc2 " ) ] , [ " doc1 doc2 " ] ) ,
(
" { join(documents, ' delim ' , ' [$idx] $content ' , { ' c ' : ' C ' })} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
[ " [1] doC1 delim [2] doC2 " ] ,
) ,
(
" { join(documents, ' delim ' , ' [$id] $content ' , { ' c ' : ' C ' })} " ,
[ Document ( " doc1 " , id = " 123 " ) , Document ( " doc2 " , id = " 456 " ) ] ,
[ " [123] doC1 delim [456] doC2 " ] ,
) ,
(
" { join(documents, ' delim ' , ' [$file_id] $content ' , { ' c ' : ' C ' })} " ,
[ Document ( " doc1 " , meta = { " file_id " : " 123.txt " } ) , Document ( " doc2 " , meta = { " file_id " : " 456.txt " } ) ] ,
[ " [123.txt] doC1 delim [456.txt] doC2 " ] ,
) ,
] ,
)
def test_join ( self , prompt_text : str , documents : List [ Document ] , expected_prompts : List [ str ] ) :
prompt_template = PromptTemplate ( name = " test " , prompt_text = prompt_text )
prompts = [ prompt for prompt in prompt_template . fill ( documents = documents ) ]
assert prompts == expected_prompts
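
# `to_strings` renders each document to its own string (one prompt per document) and
# accepts the same `pattern` and `str_replace` arguments as `join`; the cases below show
# that plain strings and Answer objects are handled in addition to Documents.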
@pytest.mark.unit
@pytest.mark.parametrize (
" prompt_text, documents, expected_prompts " ,
[
( " { to_strings(documents)} " , [ Document ( " doc1 " ) , Document ( " doc2 " ) ] , [ " doc1 " , " doc2 " ] ) ,
(
" { to_strings(documents, ' [$idx] $content ' , { ' c ' : ' C ' })} " ,
[ Document ( " doc1 " ) , Document ( " doc2 " ) ] ,
[ " [1] doC1 " , " [2] doC2 " ] ,
) ,
(
" { to_strings(documents, ' [$id] $content ' , { ' c ' : ' C ' })} " ,
[ Document ( " doc1 " , id = " 123 " ) , Document ( " doc2 " , id = " 456 " ) ] ,
[ " [123] doC1 " , " [456] doC2 " ] ,
) ,
(
" { to_strings(documents, ' [$file_id] $content ' , { ' c ' : ' C ' })} " ,
[ Document ( " doc1 " , meta = { " file_id " : " 123.txt " } ) , Document ( " doc2 " , meta = { " file_id " : " 456.txt " } ) ] ,
[ " [123.txt] doC1 " , " [456.txt] doC2 " ] ,
) ,
( " { to_strings(documents, ' [$file_id] $content ' , { ' c ' : ' C ' })} " , [ " doc1 " , " doc2 " ] , [ " doC1 " , " doC2 " ] ) ,
(
" { to_strings(documents, ' [$idx] $answer ' , { ' c ' : ' C ' })} " ,
[ Answer ( " doc1 " ) , Answer ( " doc2 " ) ] ,
[ " [1] doC1 " , " [2] doC2 " ] ,
) ,
] ,
)
def test_to_strings ( self , prompt_text : str , documents : List [ Document ] , expected_prompts : List [ str ] ) :
prompt_template = PromptTemplate ( name = " test " , prompt_text = prompt_text )
prompts = [ prompt for prompt in prompt_template . fill ( documents = documents ) ]
assert prompts == expected_prompts
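
# Template validation: expressions that call anything other than the allowed helper
# functions (for example `__import__` or `requests.get`) are rejected with
# PromptTemplateValidationError, and malformed f-string syntax raises a SyntaxError.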
@pytest.mark.unit
@pytest.mark.parametrize (
" prompt_text, exc_type, expected_exc_match " ,
[
( " { __import__( ' os ' ).listdir( ' . ' )} " , PromptTemplateValidationError , " Invalid function in prompt text " ) ,
( " { __import__( ' os ' )} " , PromptTemplateValidationError , " Invalid function in prompt text " ) ,
(
" { requests.get( ' https://haystack.deepset.ai/ ' )} " ,
PromptTemplateValidationError ,
" Invalid function in prompt text " ,
) ,
( " { join(__import__( ' os ' ).listdir( ' . ' ))} " , PromptTemplateValidationError , " Invalid function in prompt text " ) ,
( " {for} " , SyntaxError , " invalid syntax " ) ,
( " This is an invalid { variable . " , SyntaxError , " f-string: expecting ' } ' " ) ,
] ,
)
def test_prompt_template_syntax_init_raises (
self , prompt_text : str , exc_type : Type [ BaseException ] , expected_exc_match : str
) :
with pytest . raises ( exc_type , match = expected_exc_match ) :
PromptTemplate ( name = " test " , prompt_text = prompt_text )
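
# Calling `fill` without the inputs the template expects raises a ValueError.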
@pytest.mark.unit
@pytest.mark.parametrize (
" prompt_text, documents, query, exc_type, expected_exc_match " ,
[ ( " {join} " , None , None , ValueError , " Expected prompt parameters " ) ] ,
)
def test_prompt_template_syntax_fill_raises (
self ,
prompt_text : str ,
documents : List [ Document ] ,
query : str ,
exc_type : Type [ BaseException ] ,
expected_exc_match : str ,
) :
with pytest . raises ( exc_type , match = expected_exc_match ) :
prompt_template = PromptTemplate ( name = " test " , prompt_text = prompt_text )
next ( prompt_template . fill ( documents = documents , query = query ) )
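
# Potentially dangerous input that is not wrapped in a template expression is treated as
# plain text and reproduced verbatim rather than evaluated; non-string prompt parameters
# (such as the `print` builtin below) are simply stringified.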
@pytest.mark.unit
@pytest.mark.parametrize (
" prompt_text, documents, query, expected_prompts " ,
[
( " __import__( ' os ' ).listdir( ' . ' ) " , None , None , [ " __import__( ' os ' ).listdir( ' . ' ) " ] ) ,
(
" requests.get( ' https://haystack.deepset.ai/ ' ) " ,
None ,
None ,
[ " requests.get( ' https://haystack.deepset.ai/ ' ) " ] ,
) ,
( " {query} " , None , print , [ " <built-in function print> " ] ) ,
( " \b \b __import__( ' os ' ).listdir( ' . ' ) " , None , None , [ " \x08 \x08 __import__( ' os ' ).listdir( ' . ' ) " ] ) ,
] ,
)
def test_prompt_template_syntax_fill_ignores_dangerous_input (
self , prompt_text : str , documents : List [ Document ] , query : str , expected_prompts : List [ str ]
) :
prompt_template = PromptTemplate ( name = " test " , prompt_text = prompt_text )
prompts = [ prompt for prompt in prompt_template . fill ( documents = documents , query = query ) ]
assert prompts == expected_prompts
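
# The ChatGPT integration tests below accept either a plain string prompt or a list of
# chat messages in the OpenAI format. A minimal sketch of both call styles (assuming a
# valid chatgpt_prompt_model fixture):
#   pn = PromptNode(chatgpt_prompt_model)
#   pn("a single string prompt")
#   pn([{"role": "user", "content": "..."}])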
@pytest.mark.integration
def test_chatgpt_direct_prompting ( chatgpt_prompt_model ) :
skip_test_for_invalid_key ( chatgpt_prompt_model )
pn = PromptNode ( chatgpt_prompt_model )
result = pn ( " Hey, I need some Python help. When should I use list comprehension? " )
assert len ( result ) == 1 and all ( w in result [ 0 ] for w in [ " comprehension " , " list " ] )

@pytest.mark.integration
def test_chatgpt_direct_prompting_w_messages ( chatgpt_prompt_model ) :
skip_test_for_invalid_key ( chatgpt_prompt_model )
pn = PromptNode ( chatgpt_prompt_model )
messages = [
{ " role " : " system " , " content " : " You are a helpful assistant. " } ,
{ " role " : " user " , " content " : " Who won the world series in 2020? " } ,
{ " role " : " assistant " , " content " : " The Los Angeles Dodgers won the World Series in 2020. " } ,
{ " role " : " user " , " content " : " Where was it played? " } ,
]
result = pn ( messages )
assert len ( result ) == 1 and all ( w in result [ 0 ] . casefold ( ) for w in [ " arlington " , " texas " ] )
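
# Same checks as above, but the PromptNode is constructed directly from the model name
# with an API key taken from the environment.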
@pytest.mark.integration
@pytest.mark.skipif (
not os . environ . get ( " OPENAI_API_KEY " , None ) ,
reason = " No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test. " ,
)
def test_chatgpt_promptnode ( ) :
pn = PromptNode ( model_name_or_path = " gpt-3.5-turbo " , api_key = os . environ . get ( " OPENAI_API_KEY " , None ) )
result = pn ( " Hey, I need some Python help. When should I use list comprehension? " )
assert len ( result ) == 1 and all ( w in result [ 0 ] for w in [ " comprehension " , " list " ] )
messages = [
{ " role " : " system " , " content " : " You are a helpful assistant. " } ,
{ " role " : " user " , " content " : " Who won the world series in 2020? " } ,
{ " role " : " assistant " , " content " : " The Los Angeles Dodgers won the World Series in 2020. " } ,
{ " role " : " user " , " content " : " Where was it played? " } ,
]
result = pn ( messages )
assert len ( result ) == 1 and all ( w in result [ 0 ] . casefold ( ) for w in [ " arlington " , " texas " ] )