diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py index 5954134c5..818931a77 100644 --- a/haystack/nodes/answer_generator/openai.py +++ b/haystack/nodes/answer_generator/openai.py @@ -36,7 +36,7 @@ class OpenAIAnswerGenerator(BaseGenerator): ): """ - :param api_key: Your API key from OpenAI + :param api_key: Your API key from OpenAI. It is required for this node to work. :param model: ID of the engine to use for generating the answer. You can select one of `"text-ada-001"`, `"text-babbage-001"`, `"text-curie-001"`, or `"text-davinci-002"` (from worst to best + cheapest to most expensive). Please refer to the @@ -71,6 +71,10 @@ class OpenAIAnswerGenerator(BaseGenerator): if not stop_words: stop_words = ["\n", "<|endoftext|>"] + if not api_key: + raise ValueError("OpenAIAnswerGenerator requires an API key.") + + self.api_key = api_key self.model = model self.max_tokens = max_tokens self.top_k = top_k @@ -79,7 +83,6 @@ class OpenAIAnswerGenerator(BaseGenerator): self.frequency_penalty = frequency_penalty self.examples_context = examples_context self.examples = examples - self.api_key = api_key self.stop_words = stop_words self._tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py index f76a111ed..23c1a9aba 100644 --- a/test/nodes/test_generator.py +++ b/test/nodes/test_generator.py @@ -128,10 +128,11 @@ def test_lfqa_pipeline_invalid_converter(document_store, retriever, docs_with_tr @pytest.mark.integration +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) def test_openai_answer_generator(openai_generator, docs): - if "OPENAI_API_KEY" in os.environ: - prediction = openai_generator.predict(query="Who lives in Berlin?", documents=docs, top_k=1) - assert len(prediction["answers"]) == 1 - assert "Carla" in prediction["answers"][0].answer - else: - pytest.skip("No API key provided in environment variables.") + prediction = openai_generator.predict(query="Who lives in Berlin?", documents=docs, top_k=1) + assert len(prediction["answers"]) == 1 + assert "Carla" in prediction["answers"][0].answer