diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py index 6cad70994..276cf526c 100644 --- a/haystack/nodes/answer_generator/openai.py +++ b/haystack/nodes/answer_generator/openai.py @@ -138,7 +138,7 @@ class OpenAIAnswerGenerator(BaseGenerator): else: # Check for required prompts required_params = ["context", "query"] - if all(p in prompt_template.prompt_params for p in required_params): + if not all(p in prompt_template.prompt_params for p in required_params): raise ValueError( "The OpenAIAnswerGenerator requires a PromptTemplate that has `context` and " "`query` in its `prompt_params`. Supply a different `prompt_template` or " diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py index 8af0bc7e4..22f4605a6 100644 --- a/test/nodes/test_generator.py +++ b/test/nodes/test_generator.py @@ -7,6 +7,7 @@ import pytest from haystack.schema import Document from haystack.nodes.answer_generator import Seq2SeqGenerator, OpenAIAnswerGenerator from haystack.pipelines import TranslationWrapperPipeline, GenerativeQAPipeline +from haystack.nodes import PromptTemplate import logging @@ -142,6 +143,26 @@ def test_openai_answer_generator(openai_generator, docs): assert "Carla" in prediction["answers"][0].answer +@pytest.mark.integration +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_openai_answer_generator_custom_template(docs): + lfqa_prompt = PromptTemplate( + name="lfqa", + prompt_text=""" + Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and the given question. + \n===\Paragraphs: $context\n===\n$query""", + prompt_params=["context", "query"], + ) + node = OpenAIAnswerGenerator( + api_key=os.environ.get("OPENAI_API_KEY", ""), model="text-babbage-001", top_k=1, prompt_template=lfqa_prompt + ) + prediction = node.predict(query="Who lives in Berlin?", documents=docs, top_k=1) + assert len(prediction["answers"]) == 1 + + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None),