API key check in OpenAIAnswerGenerator (#2791)

* api key check in node and tests

* Clarify skip message

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Sara Zan 2022-07-12 14:05:47 +02:00 committed by GitHub
parent 4d2a06989d
commit d8e7aaeacc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 8 deletions

View File

@ -36,7 +36,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
):
"""
:param api_key: Your API key from OpenAI
:param api_key: Your API key from OpenAI. It is required for this node to work.
:param model: ID of the engine to use for generating the answer. You can select one of `"text-ada-001"`,
`"text-babbage-001"`, `"text-curie-001"`, or `"text-davinci-002"`
(from worst to best + cheapest to most expensive). Please refer to the
@ -71,6 +71,10 @@ class OpenAIAnswerGenerator(BaseGenerator):
if not stop_words:
stop_words = ["\n", "<|endoftext|>"]
if not api_key:
raise ValueError("OpenAIAnswerGenerator requires an API key.")
self.api_key = api_key
self.model = model
self.max_tokens = max_tokens
self.top_k = top_k
@ -79,7 +83,6 @@ class OpenAIAnswerGenerator(BaseGenerator):
self.frequency_penalty = frequency_penalty
self.examples_context = examples_context
self.examples = examples
self.api_key = api_key
self.stop_words = stop_words
self._tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

View File

@ -128,10 +128,11 @@ def test_lfqa_pipeline_invalid_converter(document_store, retriever, docs_with_tr
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_answer_generator(openai_generator, docs):
if "OPENAI_API_KEY" in os.environ:
prediction = openai_generator.predict(query="Who lives in Berlin?", documents=docs, top_k=1)
assert len(prediction["answers"]) == 1
assert "Carla" in prediction["answers"][0].answer
else:
pytest.skip("No API key provided in environment variables.")
prediction = openai_generator.predict(query="Who lives in Berlin?", documents=docs, top_k=1)
assert len(prediction["answers"]) == 1
assert "Carla" in prediction["answers"][0].answer