From 90ff3817e75a98cc1715daf5e083d7edabca93af Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Date: Fri, 7 Jul 2023 12:02:21 +0200
Subject: [PATCH] feat: support `OpenAI-Organization` for authentication (#5292)

* add openai_organization to invocation layer, generator and retriever

* added tests
---
 haystack/nodes/answer_generator/openai.py    |  6 +++++
 .../nodes/prompt/invocation_layer/open_ai.py |  9 ++++++-
 haystack/nodes/retriever/_openai_encoder.py  |  3 +++
 haystack/nodes/retriever/dense.py            |  6 ++++-
 test/nodes/test_generator.py                 | 22 ++++++++++++++++
 test/nodes/test_retriever.py                 | 24 +++++++++++++++++
 test/prompt/invocation_layer/test_openai.py  | 26 +++++++++++++++++++
 7 files changed, 94 insertions(+), 2 deletions(-)

diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py
index 4185f4c06..ae1ac362e 100644
--- a/haystack/nodes/answer_generator/openai.py
+++ b/haystack/nodes/answer_generator/openai.py
@@ -48,6 +48,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
         context_join_str: str = " ",
         moderate_content: bool = False,
         api_base: str = "https://api.openai.com/v1",
+        openai_organization: Optional[str] = None,
     ):
         """
         :param api_key: Your API key from OpenAI. It is required for this node to work.
@@ -105,6 +106,8 @@ class OpenAIAnswerGenerator(BaseGenerator):
             using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation).
             If the input or answers are flagged, an empty list is returned in place of the answers.
         :param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
+        :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see the OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
         """
         super().__init__(progress_bar=progress_bar)
         if (examples is None and examples_context is not None) or (examples is not None and examples_context is None):
@@ -165,6 +168,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
         self.context_join_str = context_join_str
         self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None
         self.moderate_content = moderate_content
+        self.openai_organization = openai_organization
 
         tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)
@@ -233,6 +237,8 @@ class OpenAIAnswerGenerator(BaseGenerator):
             headers["api-key"] = self.api_key
         else:
             headers["Authorization"] = f"Bearer {self.api_key}"
+        if self.openai_organization:
+            headers["OpenAI-Organization"] = self.openai_organization
 
         if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers):
             logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
diff --git a/haystack/nodes/prompt/invocation_layer/open_ai.py b/haystack/nodes/prompt/invocation_layer/open_ai.py
index b530a3aa3..7388d51f7 100644
--- a/haystack/nodes/prompt/invocation_layer/open_ai.py
+++ b/haystack/nodes/prompt/invocation_layer/open_ai.py
@@ -34,6 +34,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
         model_name_or_path: str = "text-davinci-003",
         max_length: Optional[int] = 100,
         api_base: str = "https://api.openai.com/v1",
+        openai_organization: Optional[str] = None,
         **kwargs,
     ):
         """
@@ -43,6 +44,8 @@
         :param max_length: The maximum number of tokens the output text can have.
         :param api_key: The OpenAI API key.
         :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
+        :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see the OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
         :param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
             all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some
             unrelated kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant
@@ -60,6 +63,7 @@
         )
         self.api_key = api_key
         self.api_base = api_base
+        self.openai_organization = openai_organization
 
         # 16 is the default length for answers from OpenAI shown in the docs
         # here, https://platform.openai.com/docs/api-reference/completions/create.
@@ -103,7 +107,10 @@
 
     @property
     def headers(self) -> Dict[str, str]:
-        return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        if self.openai_organization:
+            headers["OpenAI-Organization"] = self.openai_organization
+        return headers
 
     def invoke(self, *args, **kwargs):
         """
diff --git a/haystack/nodes/retriever/_openai_encoder.py b/haystack/nodes/retriever/_openai_encoder.py
index 03c914da3..6079a188a 100644
--- a/haystack/nodes/retriever/_openai_encoder.py
+++ b/haystack/nodes/retriever/_openai_encoder.py
@@ -40,6 +40,7 @@
         self.url = f"{retriever.api_base}/embeddings"
         self.api_key = retriever.api_key
+        self.openai_organization = retriever.openai_organization
         self.batch_size = min(64, retriever.batch_size)
         self.progress_bar = retriever.progress_bar
         model_class: str = next(
@@ -113,6 +114,8 @@
         else:
             payload: Dict[str, Union[List[str], str]] = {"model": model, "input": text}
             headers["Authorization"] = f"Bearer {self.api_key}"
+        if self.openai_organization:
+            headers["OpenAI-Organization"] = self.openai_organization
 
         res = openai_request(url=self.url, headers=headers, payload=payload, timeout=OPENAI_TIMEOUT)
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index 29883a564..4bd5c8b74 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -1462,6 +1462,7 @@ class EmbeddingRetriever(DenseRetriever):
         azure_base_url: Optional[str] = None,
         azure_deployment_name: Optional[str] = None,
         api_base: str = "https://api.openai.com/v1",
+        openai_organization: Optional[str] = None,
     ):
         """
         :param document_store: An instance of DocumentStore from which to retrieve documents.
@@ -1521,7 +1522,9 @@ class EmbeddingRetriever(DenseRetriever):
             This parameter is an OpenAI Azure endpoint, usually in the form `https://<your-endpoint>.openai.azure.com'
         :param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied,
             Azure OpenAI API will not be used.
-        :param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`
+        :param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`.
+        :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see the OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
""" torch_and_transformers_import.check() @@ -1551,6 +1554,7 @@ class EmbeddingRetriever(DenseRetriever): self.api_version = azure_api_version self.azure_base_url = azure_base_url self.azure_deployment_name = azure_deployment_name + self.openai_organization = openai_organization self.model_format = ( self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token) if model_format is None diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py index 262bc830e..989524b63 100644 --- a/test/nodes/test_generator.py +++ b/test/nodes/test_generator.py @@ -9,6 +9,28 @@ from haystack.nodes import PromptTemplate import logging +@pytest.mark.unit +@patch("haystack.nodes.answer_generator.openai.openai_request") +def test_no_openai_organization(mock_request): + with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): + generator = OpenAIAnswerGenerator(api_key="fake_api_key") + assert generator.openai_organization is None + + generator.predict(query="test query", documents=[Document(content="test document")]) + assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"] + + +@pytest.mark.unit +@patch("haystack.nodes.answer_generator.openai.openai_request") +def test_openai_organization(mock_request): + with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): + generator = OpenAIAnswerGenerator(api_key="fake_api_key", openai_organization="fake_organization") + assert generator.openai_organization == "fake_organization" + + generator.predict(query="test query", documents=[Document(content="test document")]) + assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization" + + @pytest.mark.unit @patch("haystack.nodes.answer_generator.openai.openai_request") def test_openai_answer_generator_default_api_base(mock_request): diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 656ed64c6..07b86adc0 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -1154,3 +1154,27 @@ def test_openai_custom_api_base(mock_request): retriever.embed_documents(documents=[Document(content="test document")]) assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings" + + +@pytest.mark.unit +@patch("haystack.nodes.retriever._openai_encoder.openai_request") +def test_openai_no_openai_organization(mock_request): + with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"): + retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key") + assert retriever.openai_organization is None + + retriever.embed_queries(queries=["test query"]) + assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"] + + +@pytest.mark.unit +@patch("haystack.nodes.retriever._openai_encoder.openai_request") +def test_openai_openai_organization(mock_request): + with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"): + retriever = EmbeddingRetriever( + embedding_model="text-embedding-ada-002", api_key="fake_api_key", openai_organization="fake_organization" + ) + assert retriever.openai_organization == "fake_organization" + + retriever.embed_queries(queries=["test query"]) + assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization" diff --git a/test/prompt/invocation_layer/test_openai.py b/test/prompt/invocation_layer/test_openai.py index c4960bbf6..31fe7066d 100644 --- a/test/prompt/invocation_layer/test_openai.py +++ 
@@ -39,3 +39,29 @@ def test_openai_token_limit_warning(mock_openai_tokenizer, caplog):
         _ = invocation_layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.")
     assert "The prompt has been truncated from" in caplog.text
     assert "and answer length (2045 tokens) fit within the max token limit (2049 tokens)." in caplog.text
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
+def test_no_openai_organization(mock_request):
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
+        invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key")
+
+    assert invocation_layer.openai_organization is None
+    assert "OpenAI-Organization" not in invocation_layer.headers
+
+    invocation_layer.invoke(prompt="dummy_prompt")
+    assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"]
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
+def test_openai_organization(mock_request):
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
+        invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", openai_organization="fake_organization")
+
+    assert invocation_layer.openai_organization == "fake_organization"
+    assert invocation_layer.headers["OpenAI-Organization"] == "fake_organization"
+
+    invocation_layer.invoke(prompt="dummy_prompt")
+    assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization"
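
A minimal usage sketch (not part of the patch itself): with this change applied, the new `openai_organization` parameter can be passed to the touched components, and every request they send to OpenAI then carries a matching `OpenAI-Organization` header. Import paths follow the Haystack v1.x layout used in the diff; the API key and organization ID values below are placeholders.

    from haystack.document_stores import InMemoryDocumentStore
    from haystack.nodes import EmbeddingRetriever, OpenAIAnswerGenerator

    # Generator: the organization ID is stored on the node and added to the
    # request headers built by OpenAIAnswerGenerator (see the diff above).
    generator = OpenAIAnswerGenerator(
        api_key="sk-...",  # placeholder API key
        openai_organization="org-...",  # placeholder organization ID
    )

    # Retriever: _OpenAIEmbeddingEncoder reads `retriever.openai_organization`
    # and attaches the same header to every embedding request.
    document_store = InMemoryDocumentStore(embedding_dim=1536)  # ada-002 embeddings have 1536 dimensions
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="text-embedding-ada-002",
        api_key="sk-...",  # placeholder API key
        openai_organization="org-...",  # placeholder organization ID
    )

For PromptNode, the same value should reach OpenAIInvocationLayer through PromptModel's model_kwargs (e.g. model_kwargs={"openai_organization": "org-..."}), given the reflective kwargs handling described in the invocation layer's docstring. Omitting the parameter keeps the previous behavior: no `OpenAI-Organization` header is sent.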