mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 10:19:23 +00:00
feat: support OpenAI-Organization for authentication (#5292)
* add openai_organization to invocation layer, generator and retriever * added tests
This commit is contained in:
parent
0697f5c63e
commit
90ff3817e7
@ -48,6 +48,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
context_join_str: str = " ",
|
||||
moderate_content: bool = False,
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
openai_organization: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param api_key: Your API key from OpenAI. It is required for this node to work.
|
||||
@ -105,6 +106,8 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or
|
||||
answers are flagged, an empty list is returned in place of the answers.
|
||||
:param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
|
||||
:param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
|
||||
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
|
||||
"""
|
||||
super().__init__(progress_bar=progress_bar)
|
||||
if (examples is None and examples_context is not None) or (examples is not None and examples_context is None):
|
||||
@ -165,6 +168,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
self.context_join_str = context_join_str
|
||||
self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None
|
||||
self.moderate_content = moderate_content
|
||||
self.openai_organization = openai_organization
|
||||
|
||||
tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)
|
||||
|
||||
@ -233,6 +237,8 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
headers["api-key"] = self.api_key
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
if self.openai_organization:
|
||||
headers["OpenAI-Organization"] = self.openai_organization
|
||||
|
||||
if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers):
|
||||
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
|
||||
|
||||
@ -34,6 +34,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
model_name_or_path: str = "text-davinci-003",
|
||||
max_length: Optional[int] = 100,
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
openai_organization: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -43,6 +44,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
:param max_length: The maximum number of tokens the output text can have.
|
||||
:param api_key: The OpenAI API key.
|
||||
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
|
||||
:param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
|
||||
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
|
||||
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
|
||||
all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated
|
||||
kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant
|
||||
@ -60,6 +63,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.openai_organization = openai_organization
|
||||
|
||||
# 16 is the default length for answers from OpenAI shown in the docs
|
||||
# here, https://platform.openai.com/docs/api-reference/completions/create.
|
||||
@ -103,7 +107,10 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
|
||||
@property
def headers(self) -> Dict[str, str]:
    """Default HTTP headers sent with every OpenAI request.

    Always carries the bearer token and JSON content type; adds the
    ``OpenAI-Organization`` header only when an organization ID was configured.
    """
    request_headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
    if self.openai_organization:
        request_headers["OpenAI-Organization"] = self.openai_organization
    return request_headers
|
||||
|
||||
def invoke(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@ -40,6 +40,7 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
self.url = f"{retriever.api_base}/embeddings"
|
||||
|
||||
self.api_key = retriever.api_key
|
||||
self.openai_organization = retriever.openai_organization
|
||||
self.batch_size = min(64, retriever.batch_size)
|
||||
self.progress_bar = retriever.progress_bar
|
||||
model_class: str = next(
|
||||
@ -113,6 +114,8 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
else:
|
||||
payload: Dict[str, Union[List[str], str]] = {"model": model, "input": text}
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
if self.openai_organization:
|
||||
headers["OpenAI-Organization"] = self.openai_organization
|
||||
|
||||
res = openai_request(url=self.url, headers=headers, payload=payload, timeout=OPENAI_TIMEOUT)
|
||||
|
||||
|
||||
@ -1462,6 +1462,7 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
azure_base_url: Optional[str] = None,
|
||||
azure_deployment_name: Optional[str] = None,
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
openai_organization: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param document_store: An instance of DocumentStore from which to retrieve documents.
|
||||
@ -1521,7 +1522,9 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
This parameter is an OpenAI Azure endpoint, usually in the form `https://<your-endpoint>.openai.azure.com'
|
||||
:param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API
|
||||
will not be used.
|
||||
:param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`
|
||||
:param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`.
|
||||
:param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
|
||||
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
|
||||
"""
|
||||
torch_and_transformers_import.check()
|
||||
|
||||
@ -1551,6 +1554,7 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
self.api_version = azure_api_version
|
||||
self.azure_base_url = azure_base_url
|
||||
self.azure_deployment_name = azure_deployment_name
|
||||
self.openai_organization = openai_organization
|
||||
self.model_format = (
|
||||
self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token)
|
||||
if model_format is None
|
||||
|
||||
@ -9,6 +9,28 @@ from haystack.nodes import PromptTemplate
|
||||
import logging
|
||||
|
||||
|
||||
@pytest.mark.unit
@patch("haystack.nodes.answer_generator.openai.openai_request")
def test_no_openai_organization(mock_request):
    """Without an organization, the OpenAI-Organization header must not be sent."""
    with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
        answer_generator = OpenAIAnswerGenerator(api_key="fake_api_key")
    assert answer_generator.openai_organization is None

    answer_generator.predict(query="test query", documents=[Document(content="test document")])
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert "OpenAI-Organization" not in sent_headers
|
||||
|
||||
|
||||
@pytest.mark.unit
@patch("haystack.nodes.answer_generator.openai.openai_request")
def test_openai_organization(mock_request):
    """A configured organization is stored and propagated as the OpenAI-Organization header."""
    with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
        answer_generator = OpenAIAnswerGenerator(api_key="fake_api_key", openai_organization="fake_organization")
    assert answer_generator.openai_organization == "fake_organization"

    answer_generator.predict(query="test query", documents=[Document(content="test document")])
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert sent_headers["OpenAI-Organization"] == "fake_organization"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.answer_generator.openai.openai_request")
|
||||
def test_openai_answer_generator_default_api_base(mock_request):
|
||||
|
||||
@ -1154,3 +1154,27 @@ def test_openai_custom_api_base(mock_request):
|
||||
|
||||
retriever.embed_documents(documents=[Document(content="test document")])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings"
|
||||
|
||||
|
||||
@pytest.mark.unit
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
def test_openai_no_openai_organization(mock_request):
    """Retriever without an organization must not send the OpenAI-Organization header."""
    with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
        retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key")
    assert retriever.openai_organization is None

    retriever.embed_queries(queries=["test query"])
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert "OpenAI-Organization" not in sent_headers
|
||||
|
||||
|
||||
@pytest.mark.unit
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
def test_openai_openai_organization(mock_request):
    """Retriever with an organization stores it and sends it as the OpenAI-Organization header."""
    with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
        retriever = EmbeddingRetriever(
            embedding_model="text-embedding-ada-002", api_key="fake_api_key", openai_organization="fake_organization"
        )
    assert retriever.openai_organization == "fake_organization"

    retriever.embed_queries(queries=["test query"])
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert sent_headers["OpenAI-Organization"] == "fake_organization"
|
||||
|
||||
@ -39,3 +39,29 @@ def test_openai_token_limit_warning(mock_openai_tokenizer, caplog):
|
||||
_ = invocation_layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.")
|
||||
assert "The prompt has been truncated from" in caplog.text
|
||||
assert "and answer length (2045 tokens) fit within the max token limit (2049 tokens)." in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
def test_no_openai_organization(mock_request):
    """Invocation layer without an organization neither exposes nor sends the header."""
    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
        layer = OpenAIInvocationLayer(api_key="fake_api_key")

    assert layer.openai_organization is None
    assert "OpenAI-Organization" not in layer.headers

    layer.invoke(prompt="dummy_prompt")
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert "OpenAI-Organization" not in sent_headers
|
||||
|
||||
|
||||
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
def test_openai_organization(mock_request):
    """Invocation layer with an organization exposes it in headers and sends it on invoke."""
    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
        layer = OpenAIInvocationLayer(api_key="fake_api_key", openai_organization="fake_organization")

    assert layer.openai_organization == "fake_organization"
    assert layer.headers["OpenAI-Organization"] == "fake_organization"

    layer.invoke(prompt="dummy_prompt")
    sent_headers = mock_request.call_args.kwargs["headers"]
    assert sent_headers["OpenAI-Organization"] == "fake_organization"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user