feat: support OpenAI-Organization for authentication (#5292)

* add openai_organization to invocation layer, generator and retriever

* added tests
Stefano Fiorucci 2023-07-07 12:02:21 +02:00 committed by GitHub
parent 0697f5c63e
commit 90ff3817e7
7 changed files with 94 additions and 2 deletions
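
For context: OpenAI attributes API usage and billing to an organization, and a client can pin a request to a specific organization by sending an OpenAI-Organization HTTP header next to the Authorization header. A minimal sketch of the kind of request the patched components produce (the key and organization ID below are placeholders, not real values):

    import requests

    # Placeholder credentials for illustration only.
    api_key = "sk-..."
    openai_organization = "org-myOrgId123"

    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    if openai_organization:
        # The header is only sent when an organization is configured.
        headers["OpenAI-Organization"] = openai_organization

    response = requests.post(
        "https://api.openai.com/v1/completions",
        headers=headers,
        json={"model": "text-davinci-003", "prompt": "Hello", "max_tokens": 16},
        timeout=30,
    )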

View File

@@ -48,6 +48,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
         context_join_str: str = " ",
         moderate_content: bool = False,
         api_base: str = "https://api.openai.com/v1",
+        openai_organization: Optional[str] = None,
     ):
         """
         :param api_key: Your API key from OpenAI. It is required for this node to work.
@@ -105,6 +106,8 @@ class OpenAIAnswerGenerator(BaseGenerator):
             using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or
             answers are flagged, an empty list is returned in place of the answers.
         :param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
+        :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see the OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
         """
         super().__init__(progress_bar=progress_bar)
         if (examples is None and examples_context is not None) or (examples is not None and examples_context is None):
@@ -165,6 +168,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
         self.context_join_str = context_join_str
         self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None
         self.moderate_content = moderate_content
+        self.openai_organization = openai_organization

         tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)
@@ -233,6 +237,8 @@ class OpenAIAnswerGenerator(BaseGenerator):
             headers["api-key"] = self.api_key
         else:
             headers["Authorization"] = f"Bearer {self.api_key}"
+        if self.openai_organization:
+            headers["OpenAI-Organization"] = self.openai_organization

         if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers):
             logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
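
A usage sketch for the generator change above, with placeholder credentials; once openai_organization is set, every completion request from this node carries the OpenAI-Organization header:

    from haystack.nodes import OpenAIAnswerGenerator

    # Placeholder credentials; substitute real values.
    generator = OpenAIAnswerGenerator(
        api_key="sk-...",
        openai_organization="org-myOrgId123",
    )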

View File

@@ -34,6 +34,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
         model_name_or_path: str = "text-davinci-003",
         max_length: Optional[int] = 100,
         api_base: str = "https://api.openai.com/v1",
+        openai_organization: Optional[str] = None,
         **kwargs,
     ):
         """
@@ -43,6 +44,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
         :param max_length: The maximum number of tokens the output text can have.
         :param api_key: The OpenAI API key.
         :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
+        :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see the OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
         :param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
             all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated
             kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant
@@ -60,6 +63,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
         )
         self.api_key = api_key
         self.api_base = api_base
+        self.openai_organization = openai_organization

         # 16 is the default length for answers from OpenAI shown in the docs
         # here, https://platform.openai.com/docs/api-reference/completions/create.
@@ -103,7 +107,10 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
     @property
     def headers(self) -> Dict[str, str]:
-        return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        if self.openai_organization:
+            headers["OpenAI-Organization"] = self.openai_organization
+        return headers

     def invoke(self, *args, **kwargs):
         """

View File

@@ -40,6 +40,7 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
         self.url = f"{retriever.api_base}/embeddings"
         self.api_key = retriever.api_key
+        self.openai_organization = retriever.openai_organization
         self.batch_size = min(64, retriever.batch_size)
         self.progress_bar = retriever.progress_bar
         model_class: str = next(
@@ -113,6 +114,8 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
         else:
             payload: Dict[str, Union[List[str], str]] = {"model": model, "input": text}
             headers["Authorization"] = f"Bearer {self.api_key}"
+        if self.openai_organization:
+            headers["OpenAI-Organization"] = self.openai_organization

         res = openai_request(url=self.url, headers=headers, payload=payload, timeout=OPENAI_TIMEOUT)

View File

@@ -1462,6 +1462,7 @@ class EmbeddingRetriever(DenseRetriever):
         azure_base_url: Optional[str] = None,
         azure_deployment_name: Optional[str] = None,
         api_base: str = "https://api.openai.com/v1",
+        openai_organization: Optional[str] = None,
     ):
         """
         :param document_store: An instance of DocumentStore from which to retrieve documents.
@@ -1521,7 +1522,9 @@ class EmbeddingRetriever(DenseRetriever):
             This parameter is an OpenAI Azure endpoint, usually in the form `https://<your-endpoint>.openai.azure.com`.
         :param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API
             will not be used.
-        :param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`
+        :param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`.
+        :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see the OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
         """

         torch_and_transformers_import.check()
@@ -1551,6 +1554,7 @@ class EmbeddingRetriever(DenseRetriever):
         self.api_version = azure_api_version
         self.azure_base_url = azure_base_url
         self.azure_deployment_name = azure_deployment_name
+        self.openai_organization = openai_organization
         self.model_format = (
             self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token)
             if model_format is None
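
A usage sketch for the retriever change above; the retriever stores openai_organization, and the _OpenAIEmbeddingEncoder shown earlier reads it back when building the embeddings request (placeholder credentials):

    from haystack.nodes import EmbeddingRetriever

    # Placeholder credentials; substitute real values.
    retriever = EmbeddingRetriever(
        embedding_model="text-embedding-ada-002",
        api_key="sk-...",
        openai_organization="org-myOrgId123",
    )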

View File

@@ -9,6 +9,28 @@ from haystack.nodes import PromptTemplate
 import logging


+@pytest.mark.unit
+@patch("haystack.nodes.answer_generator.openai.openai_request")
+def test_no_openai_organization(mock_request):
+    with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
+        generator = OpenAIAnswerGenerator(api_key="fake_api_key")
+    assert generator.openai_organization is None
+
+    generator.predict(query="test query", documents=[Document(content="test document")])
+    assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"]
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.answer_generator.openai.openai_request")
+def test_openai_organization(mock_request):
+    with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
+        generator = OpenAIAnswerGenerator(api_key="fake_api_key", openai_organization="fake_organization")
+    assert generator.openai_organization == "fake_organization"
+
+    generator.predict(query="test query", documents=[Document(content="test document")])
+    assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization"
+
+
 @pytest.mark.unit
 @patch("haystack.nodes.answer_generator.openai.openai_request")
 def test_openai_answer_generator_default_api_base(mock_request):

View File

@@ -1154,3 +1154,27 @@ def test_openai_custom_api_base(mock_request):
     retriever.embed_documents(documents=[Document(content="test document")])
     assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings"
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.retriever._openai_encoder.openai_request")
+def test_openai_no_openai_organization(mock_request):
+    with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
+        retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key")
+    assert retriever.openai_organization is None
+
+    retriever.embed_queries(queries=["test query"])
+    assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"]
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.retriever._openai_encoder.openai_request")
+def test_openai_openai_organization(mock_request):
+    with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
+        retriever = EmbeddingRetriever(
+            embedding_model="text-embedding-ada-002", api_key="fake_api_key", openai_organization="fake_organization"
+        )
+    assert retriever.openai_organization == "fake_organization"
+
+    retriever.embed_queries(queries=["test query"])
+    assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization"

View File

@@ -39,3 +39,29 @@ def test_openai_token_limit_warning(mock_openai_tokenizer, caplog):
     _ = invocation_layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.")
     assert "The prompt has been truncated from" in caplog.text
     assert "and answer length (2045 tokens) fit within the max token limit (2049 tokens)." in caplog.text
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
+def test_no_openai_organization(mock_request):
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
+        invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key")
+    assert invocation_layer.openai_organization is None
+    assert "OpenAI-Organization" not in invocation_layer.headers
+
+    invocation_layer.invoke(prompt="dummy_prompt")
+    assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"]
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
+def test_openai_organization(mock_request):
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
+        invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", openai_organization="fake_organization")
+    assert invocation_layer.openai_organization == "fake_organization"
+    assert invocation_layer.headers["OpenAI-Organization"] == "fake_organization"
+
+    invocation_layer.invoke(prompt="dummy_prompt")
+    assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization"