import os import pytest import torch from haystack import Document, Pipeline from haystack.errors import OpenAIError from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel def is_openai_api_key_set(api_key: str): return len(api_key) > 0 and api_key != "KEY_NOT_FOUND" def test_prompt_templates(): p = PromptTemplate("t1", "Here is some fake template with variable $foo", ["foo"]) with pytest.raises(ValueError, match="Number of parameters in"): PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo"]) with pytest.raises(ValueError, match="Invalid parameter"): PromptTemplate("t2", "Here is some fake template with variable $footur", ["foo"]) with pytest.raises(ValueError, match="Number of parameters in"): PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo", "bar", "baz"]) p = PromptTemplate("t3", "Here is some fake template with variable $for and $bar", ["for", "bar"]) # last parameter: "prompt_params" can be omitted p = PromptTemplate("t4", "Here is some fake template with variable $foo and $bar") assert p.prompt_params == ["foo", "bar"] p = PromptTemplate("t4", "Here is some fake template with variable $foo1 and $bar2") assert p.prompt_params == ["foo1", "bar2"] p = PromptTemplate("t4", "Here is some fake template with variable $foo_1 and $bar_2") assert p.prompt_params == ["foo_1", "bar_2"] p = PromptTemplate("t4", "Here is some fake template with variable $Foo_1 and $Bar_2") assert p.prompt_params == ["Foo_1", "Bar_2"] p = PromptTemplate("t4", "'Here is some fake template with variable $baz'") assert p.prompt_params == ["baz"] # strip single quotes, happens in YAML as we need to use single quotes for the template string assert p.prompt_text == "Here is some fake template with variable $baz" p = PromptTemplate("t4", '"Here is some fake template with variable $baz"') assert p.prompt_params == ["baz"] # strip double quotes, happens in YAML as we need to use single quotes for the template string assert p.prompt_text == "Here is some fake template with variable $baz" def test_create_prompt_model(): model = PromptModel("google/flan-t5-small") assert model.model_name_or_path == "google/flan-t5-small" model = PromptModel() assert model.model_name_or_path == "google/flan-t5-base" with pytest.raises(OpenAIError): # davinci selected but no API key provided model = PromptModel("text-davinci-003") model = PromptModel("text-davinci-003", api_key="no need to provide a real key") assert model.model_name_or_path == "text-davinci-003" with pytest.raises(ValueError, match="Model some-random-model is not supported"): PromptModel("some-random-model") # we can also pass model kwargs to the PromptModel model = PromptModel("google/flan-t5-small", model_kwargs={"model_kwargs": {"torch_dtype": torch.bfloat16}}) assert model.model_name_or_path == "google/flan-t5-small" # we can also pass kwargs directly, see HF Pipeline constructor model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": torch.bfloat16}) assert model.model_name_or_path == "google/flan-t5-small" # we can't use device_map auto without accelerate library installed with pytest.raises(ImportError, match="requires Accelerate: `pip install accelerate`"): model = PromptModel("google/flan-t5-small", model_kwargs={"device_map": "auto"}) assert model.model_name_or_path == "google/flan-t5-small" def test_create_prompt_node(): prompt_node = PromptNode() assert prompt_node is not None assert prompt_node.prompt_model is not None prompt_node = PromptNode("google/flan-t5-small") assert prompt_node is not None assert prompt_node.model_name_or_path == "google/flan-t5-small" assert prompt_node.prompt_model is not None with pytest.raises(OpenAIError): # davinci selected but no API key provided prompt_node = PromptNode("text-davinci-003") prompt_node = PromptNode("text-davinci-003", api_key="no need to provide a real key") assert prompt_node is not None assert prompt_node.model_name_or_path == "text-davinci-003" assert prompt_node.prompt_model is not None with pytest.raises(ValueError, match="Model vblagoje/bart_lfqa is not supported"): # yes vblagoje/bart_lfqa is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting # currently support only T5-Flan models prompt_node = PromptNode("vblagoje/bart_lfqa") with pytest.raises(ValueError, match="Model valhalla/t5-base-e2e-qg is not supported"): # yes valhalla/t5-base-e2e-qg is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting # currently support only T5-Flan models prompt_node = PromptNode("valhalla/t5-base-e2e-qg") with pytest.raises(ValueError, match="Model some-random-model is not supported"): PromptNode("some-random-model") def test_add_and_remove_template(prompt_node): num_default_tasks = len(prompt_node.get_prompt_template_names()) custom_task = PromptTemplate( name="custom-task", prompt_text="Custom task: $param1, $param2", prompt_params=["param1", "param2"] ) prompt_node.add_prompt_template(custom_task) assert len(prompt_node.get_prompt_template_names()) == num_default_tasks + 1 assert "custom-task" in prompt_node.get_prompt_template_names() assert prompt_node.remove_prompt_template("custom-task") is not None assert "custom-task" not in prompt_node.get_prompt_template_names() def test_invalid_template(prompt_node): with pytest.raises(ValueError, match="Invalid parameter"): PromptTemplate( name="custom-task", prompt_text="Custom task: $pram1 $param2", prompt_params=["param1", "param2"] ) with pytest.raises(ValueError, match="Number of parameters"): PromptTemplate(name="custom-task", prompt_text="Custom task: $param1", prompt_params=["param1", "param2"]) def test_add_template_and_invoke(prompt_node): tt = PromptTemplate( name="sentiment-analysis-new", prompt_text="Please give a sentiment for this context. Answer with positive, " "negative or neutral. Context: $documents; Answer:", prompt_params=["documents"], ) prompt_node.add_prompt_template(tt) r = prompt_node.prompt("sentiment-analysis-new", documents=["Berlin is an amazing city."]) assert r[0].casefold() == "positive" def test_on_the_fly_prompt(prompt_node): tt = PromptTemplate( name="sentiment-analysis-temp", prompt_text="Please give a sentiment for this context. Answer with positive, " "negative or neutral. Context: $documents; Answer:", prompt_params=["documents"], ) r = prompt_node.prompt(tt, documents=["Berlin is an amazing city."]) assert r[0].casefold() == "positive" def test_direct_prompting(prompt_node): r = prompt_node("What is the capital of Germany?") assert r[0].casefold() == "berlin" r = prompt_node("What is the capital of Germany?", "What is the secret of universe?") assert r[0].casefold() == "berlin" assert len(r[1]) > 0 r = prompt_node("Capital of Germany is Berlin", task="question-generation") assert len(r[0]) > 10 and "Germany" in r[0] r = prompt_node(["Capital of Germany is Berlin", "Capital of France is Paris"], task="question-generation") assert len(r) == 2 def test_question_generation(prompt_node): r = prompt_node.prompt("question-generation", documents=["Berlin is the capital of Germany."]) assert len(r) == 1 and len(r[0]) > 0 def test_template_selection(prompt_node): qa = prompt_node.set_default_prompt_template("question-answering") r = qa( ["Berlin is the capital of Germany.", "Paris is the capital of France."], ["What is the capital of Germany?", "What is the capital of France"], ) assert r[0].casefold() == "berlin" and r[1].casefold() == "paris" def test_has_supported_template_names(prompt_node): assert len(prompt_node.get_prompt_template_names()) > 0 def test_invalid_template_params(prompt_node): with pytest.raises(ValueError, match="Expected prompt params"): prompt_node.prompt("question-answering", {"some_crazy_key": "Berlin is the capital of Germany."}) def test_wrong_template_params(prompt_node): with pytest.raises(ValueError, match="Expected prompt params"): # with don't have options param, multiple choice QA has prompt_node.prompt("question-answering", options=["Berlin is the capital of Germany."]) def test_run_invalid_template(prompt_node): with pytest.raises(ValueError, match="invalid-task not supported"): prompt_node.prompt("invalid-task", {}) def test_invalid_prompting(prompt_node): with pytest.raises(ValueError, match="Hey there, what is the best city in the worl"): prompt_node.prompt( "Hey there, what is the best city in the world?" "Hey there, what is the best city in the world?" ) with pytest.raises(ValueError, match="Hey there, what is the best city in the"): prompt_node.prompt(["Hey there, what is the best city in the world?", "Hey, answer me!"]) def test_invalid_state_ops(prompt_node): with pytest.raises(ValueError, match="Prompt template no_such_task_exists"): prompt_node.remove_prompt_template("no_such_task_exists") # remove default task prompt_node.remove_prompt_template("question-answering") @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", ) def test_open_ai_prompt_with_params(): pm = PromptModel("text-davinci-003", api_key=os.environ["OPENAI_API_KEY"]) pn = PromptNode(pm) optional_davinci_params = {"temperature": 0.5, "max_tokens": 10, "top_p": 1, "frequency_penalty": 0.5} r = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params) assert len(r) == 1 and len(r[0]) > 0 @pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True) def test_simple_pipeline(prompt_model): if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key): pytest.skip("No API key found for OpenAI, skipping test") node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis") pipe = Pipeline() pipe.add_node(component=node, name="prompt_node", inputs=["Query"]) result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) assert result["results"][0].casefold() == "positive" @pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True) def test_complex_pipeline(prompt_model): if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key): pytest.skip("No API key found for OpenAI, skipping test") node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="questions") node2 = PromptNode(prompt_model, default_prompt_template="question-answering") pipe = Pipeline() pipe.add_node(component=node, name="prompt_node", inputs=["Query"]) pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"]) result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")]) assert "berlin" in result["results"][0].casefold() def test_complex_pipeline_with_shared_model(): model = PromptModel() node = PromptNode( model_name_or_path=model, default_prompt_template="question-generation", output_variable="questions" ) node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering") pipe = Pipeline() pipe.add_node(component=node, name="prompt_node", inputs=["Query"]) pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"]) result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")]) assert result["results"][0] == "Berlin" def test_simple_pipeline_yaml(tmp_path): with open(tmp_path / "tmp_config.yml", "w") as tmp_file: tmp_file.write( f""" version: ignore components: - name: p1 params: default_prompt_template: sentiment-analysis type: PromptNode pipelines: - name: query nodes: - name: p1 inputs: - Query """ ) pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) assert result["results"][0] == "positive" def test_complex_pipeline_yaml(tmp_path): with open(tmp_path / "tmp_config.yml", "w") as tmp_file: tmp_file.write( f""" version: ignore components: - name: p1 params: default_prompt_template: question-generation output_variable: questions type: PromptNode - name: p2 params: default_prompt_template: question-answering type: PromptNode pipelines: - name: query nodes: - name: p1 inputs: - Query - name: p2 inputs: - p1 """ ) pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) assert result["results"][0] == "Berlin" assert len(result["meta"]["invocation_context"]) > 0 def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path): with open(tmp_path / "tmp_config.yml", "w") as tmp_file: tmp_file.write( f""" version: ignore components: - name: pmodel type: PromptModel - name: p1 params: model_name_or_path: pmodel default_prompt_template: question-generation output_variable: questions type: PromptNode - name: p2 params: model_name_or_path: pmodel default_prompt_template: question-answering type: PromptNode pipelines: - name: query nodes: - name: p1 inputs: - Query - name: p2 inputs: - p1 """ ) pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) assert "Berlin" in result["results"][0] assert len(result["meta"]["invocation_context"]) > 0 def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path): with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file: tmp_file.write( f""" version: ignore components: - name: pmodel type: PromptModel params: model_name_or_path: google/flan-t5-small model_kwargs: torch_dtype: torch.bfloat16 - name: question_generation_template type: PromptTemplate params: name: question-generation-new prompt_text: "Given the context please generate a question. Context: $documents; Question:" - name: p1 params: model_name_or_path: pmodel default_prompt_template: question_generation_template output_variable: questions type: PromptNode - name: p2 params: model_name_or_path: pmodel default_prompt_template: question-answering type: PromptNode pipelines: - name: query nodes: - name: p1 inputs: - Query - name: p2 inputs: - p1 """ ) pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml") result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) assert "Berlin" in result["results"][0] assert len(result["meta"]["invocation_context"]) > 0 @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", ) def test_complex_pipeline_with_all_features(tmp_path): api_key = os.environ.get("OPENAI_API_KEY", None) with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file: tmp_file.write( f""" version: ignore components: - name: pmodel type: PromptModel params: model_name_or_path: google/flan-t5-small model_kwargs: torch_dtype: torch.bfloat16 - name: pmodel_openai type: PromptModel params: model_name_or_path: text-davinci-003 model_kwargs: temperature: 0.9 max_tokens: 64 api_key: {api_key} - name: question_generation_template type: PromptTemplate params: name: question-generation-new prompt_text: "Given the context please generate a question. Context: $documents; Question:" - name: p1 params: model_name_or_path: pmodel_openai default_prompt_template: question_generation_template output_variable: questions type: PromptNode - name: p2 params: model_name_or_path: pmodel default_prompt_template: question-answering type: PromptNode pipelines: - name: query nodes: - name: p1 inputs: - Query - name: p2 inputs: - p1 """ ) pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml") result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")]) assert "Berlin" in result["results"][0] or "Germany" in result["results"][0] assert len(result["meta"]["invocation_context"]) > 0