feat: Allow setting custom api_base for OpenAI nodes (#5033)

* add changes for api_base

* format retriever

* Update haystack/nodes/retriever/dense.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/audio/whisper_transcriber.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/preview/components/audio/whisper_remote.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/answer_generator/openai.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update test_retriever.py

* Update test_whisper_remote.py

* Update test_generator.py

* Update test_retriever.py

* reformat with black

* Update haystack/nodes/prompt/invocation_layer/chatgpt.py

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* Add unit tests

* apply docstring suggestions

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
Co-authored-by: michaelfeil <me@michaelfeil.eu>
Co-authored-by: Daria Fokina <daria.f93@gmail.com>
This commit is contained in:
Michael Feil 2023-06-05 11:32:06 +02:00 committed by GitHub
parent 5f6d161cfe
commit 6ea8ae01a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 187 additions and 14 deletions

View File

@ -45,6 +45,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
progress_bar: bool = True,
prompt_template: Optional[PromptTemplate] = None,
context_join_str: str = " ",
api_base: str = "https://api.openai.com/v1",
):
"""
:param api_key: Your API key from OpenAI. It is required for this node to work.
@ -98,6 +99,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
[PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
:param context_join_str: The separation string used to join the input documents to create the context
used by the PromptTemplate.
:param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
"""
super().__init__(progress_bar=progress_bar)
if (examples is None and examples_context is not None) or (examples is not None and examples_context is None):
@ -144,6 +146,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
self.azure_base_url = azure_base_url
self.azure_deployment_name = azure_deployment_name
self.api_version = api_version
self.api_base = api_base
self.model = model
self.max_tokens = max_tokens
self.top_k = top_k
@ -217,7 +220,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
if self.using_azure:
url = f"{self.azure_base_url}/openai/deployments/{self.azure_deployment_name}/completions?api-version={self.api_version}"
else:
url = "https://api.openai.com/v1/completions"
url = f"{self.api_base}/completions"
headers = {"Content-Type": "application/json"}
if self.using_azure:

View File

@ -41,6 +41,7 @@ class WhisperTranscriber(BaseComponent):
api_key: Optional[str] = None,
model_name_or_path: WhisperModel = "medium",
device: Optional[Union[str, torch.device]] = None,
api_base: str = "https://api.openai.com/v1",
) -> None:
"""
Creates a WhisperTranscriber instance.
@ -50,10 +51,11 @@ class WhisperTranscriber(BaseComponent):
the API, set thsi value to: "whisper-1" (default).
:param device: Device to use for inference. Only used if you're using a local
installation of Whisper. If None, the device is automatically selected.
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
"""
super().__init__()
self.api_key = api_key
self.api_base = api_base
self.use_local_whisper = is_whisper_available() and self.api_key is None
if self.use_local_whisper:
@ -108,9 +110,7 @@ class WhisperTranscriber(BaseComponent):
headers = {"Authorization": f"Bearer {self.api_key}"}
request = PreparedRequest()
url: str = (
"https://api.openai.com/v1/audio/transcriptions"
if not translate
else "https://api.openai.com/v1/audio/translations"
f"{self.api_base}/audio/transcriptions" if not translate else f"{self.api_base}/audio/translations"
)
request.prepare(

View File

@ -20,9 +20,24 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
"""
def __init__(
self, api_key: str, model_name_or_path: str = "gpt-3.5-turbo", max_length: Optional[int] = 500, **kwargs
self,
api_key: str,
model_name_or_path: str = "gpt-3.5-turbo",
max_length: Optional[int] = 500,
api_base: str = "https://api.openai.com/v1",
**kwargs,
):
super().__init__(api_key, model_name_or_path, max_length, **kwargs)
"""
Creates an instance of ChatGPTInvocationLayer for OpenAI's GPT-3.5 GPT-4 models.
:param model_name_or_path: The name or path of the underlying model.
:param max_length: The maximum number of tokens the output text can have.
:param api_key: The OpenAI API key.
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param kwargs: Additional keyword arguments passed to the underlying model.
[See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
"""
super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)
def invoke(self, *args, **kwargs):
"""
@ -125,7 +140,7 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
@property
def url(self) -> str:
return "https://api.openai.com/v1/chat/completions"
return f"{self.api_base}/chat/completions"
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:

View File

@ -27,7 +27,12 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
"""
def __init__(
self, api_key: str, model_name_or_path: str = "text-davinci-003", max_length: Optional[int] = 100, **kwargs
self,
api_key: str,
model_name_or_path: str = "text-davinci-003",
max_length: Optional[int] = 100,
api_base: str = "https://api.openai.com/v1",
**kwargs,
):
"""
Creates an instance of OpenAIInvocationLayer for OpenAI's GPT-3 InstructGPT models.
@ -35,6 +40,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
:param model_name_or_path: The name or path of the underlying model.
:param max_length: The maximum number of tokens the output text can have.
:param api_key: The OpenAI API key.
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated
kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant
@ -48,6 +54,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
f"api_key {api_key} must be a valid OpenAI key. Visit https://openai.com/api/ to get one."
)
self.api_key = api_key
self.api_base = api_base
# 16 is the default length for answers from OpenAI shown in the docs
# here, https://platform.openai.com/docs/api-reference/completions/create.
@ -86,7 +93,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
@property
def url(self) -> str:
return "https://api.openai.com/v1/completions"
return f"{self.api_base}/completions"
@property
def headers(self) -> Dict[str, str]:

View File

@ -37,7 +37,7 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
if self.using_azure:
self.url = f"{retriever.azure_base_url}/openai/deployments/{retriever.azure_deployment_name}/embeddings?api-version={retriever.api_version}"
else:
self.url = "https://api.openai.com/v1/embeddings"
self.url = f"{retriever.api_base}/embeddings"
self.api_key = retriever.api_key
self.batch_size = min(64, retriever.batch_size)

View File

@ -1453,6 +1453,7 @@ class EmbeddingRetriever(DenseRetriever):
azure_api_version: str = "2022-12-01",
azure_base_url: Optional[str] = None,
azure_deployment_name: Optional[str] = None,
api_base: str = "https://api.openai.com/v1",
):
"""
:param document_store: An instance of DocumentStore from which to retrieve documents.
@ -1512,6 +1513,7 @@ class EmbeddingRetriever(DenseRetriever):
This parameter is an OpenAI Azure endpoint, usually in the form `https://<your-endpoint>.openai.azure.com'
:param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API
will not be used.
:param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`
"""
if embed_meta_fields is None:
embed_meta_fields = []
@ -1535,6 +1537,7 @@ class EmbeddingRetriever(DenseRetriever):
self.use_auth_token = use_auth_token
self.scale_score = scale_score
self.api_key = api_key
self.api_base = api_base
self.api_version = azure_api_version
self.azure_base_url = azure_base_url
self.azure_deployment_name = azure_deployment_name

View File

@ -38,12 +38,15 @@ class RemoteWhisperTranscriber:
class Output(ComponentOutput):
documents: List[Document]
def __init__(self, api_key: str, model_name: WhisperRemoteModel = "whisper-1"):
def __init__(
self, api_key: str, model_name: WhisperRemoteModel = "whisper-1", api_base: str = "https://api.openai.com/v1"
):
"""
Transcribes a list of audio files into a list of Documents.
:param api_key: OpenAI API key.
:param model_name_or_path: Name of the model to use. It now accepts only `whisper-1`.
:param model_name: Name of the model to use. It now accepts only `whisper-1`.
:param api_base: OpenAI base URL, defaults to `"https://api.openai.com/v1"`.
"""
if model_name not in get_args(WhisperRemoteModel):
raise ValueError(
@ -53,6 +56,7 @@ class RemoteWhisperTranscriber:
raise ValueError("API key is None.")
self.api_key = api_key
self.api_base = api_base
self.model_name = model_name
@ -109,7 +113,7 @@ class RemoteWhisperTranscriber:
:returns: a list of transcriptions as they are produced by the Whisper API (JSON).
"""
translate = kwargs.pop("translate", False)
url = f"https://api.openai.com/v1/audio/{'translations' if translate else 'transcriptions'}"
url = f"{self.api_base}/audio/{'translations' if translate else 'transcriptions'}"
data = {"model": self.model_name, **kwargs}
headers = {"Authorization": f"Bearer {self.api_key}"}

View File

@ -33,6 +33,26 @@ def test_rag_deprecation():
pass
@pytest.mark.unit
@patch("haystack.nodes.answer_generator.openai.openai_request")
def test_openai_answer_generator_default_api_base(mock_request):
with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
generator = OpenAIAnswerGenerator(api_key="fake_api_key")
assert generator.api_base == "https://api.openai.com/v1"
generator.predict(query="test query", documents=[Document(content="test document")])
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions"
@pytest.mark.unit
@patch("haystack.nodes.answer_generator.openai.openai_request")
def test_openai_answer_generator_custom_api_base(mock_request):
with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
generator = OpenAIAnswerGenerator(api_key="fake_api_key", api_base="https://fake_api_base.com")
assert generator.api_base == "https://fake_api_base.com"
generator.predict(query="test query", documents=[Document(content="test document")])
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/completions"
@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Causes OOM on windows github runner")
@pytest.mark.integration
@pytest.mark.generator

View File

@ -325,12 +325,14 @@ def test_openai_embedding_retriever_selection():
assert er.model_format == "openai"
assert er.embedding_encoder.query_encoder_model == "text-embedding-ada-002"
assert er.embedding_encoder.doc_encoder_model == "text-embedding-ada-002"
assert er.api_base == "https://api.openai.com/v1"
# but also support old ada and other text-search-<modelname>-*-001 models
er = EmbeddingRetriever(embedding_model="ada", document_store=None)
assert er.model_format == "openai"
assert er.embedding_encoder.query_encoder_model == "text-search-ada-query-001"
assert er.embedding_encoder.doc_encoder_model == "text-search-ada-doc-001"
assert er.api_base == "https://api.openai.com/v1"
# but also support old babbage and other text-search-<modelname>-*-001 models
er = EmbeddingRetriever(embedding_model="babbage", document_store=None)
@ -1270,3 +1272,35 @@ def test_web_retriever_mode_snippets(monkeypatch):
web_retriever = WebRetriever(api_key="", top_search_results=2)
result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
assert result == expected_search_results["documents"]
@pytest.mark.unit
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
def test_openai_default_api_base(mock_request):
with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key")
assert retriever.api_base == "https://api.openai.com/v1"
retriever.embed_queries(queries=["test query"])
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/embeddings"
mock_request.reset_mock()
retriever.embed_documents(documents=[Document(content="test document")])
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/embeddings"
@pytest.mark.unit
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
def test_openai_custom_api_base(mock_request):
with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
retriever = EmbeddingRetriever(
embedding_model="text-embedding-ada-002", api_key="fake_api_key", api_base="https://fake_api_base.com"
)
assert retriever.api_base == "https://fake_api_base.com"
retriever.embed_queries(queries=["test query"])
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings"
mock_request.reset_mock()
retriever.embed_documents(documents=[Document(content="test document")])
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings"

View File

@ -29,6 +29,7 @@ class TestRemoteWhisperTranscriber(BaseTestComponent):
transcriber = RemoteWhisperTranscriber(api_key="just a test")
assert transcriber.model_name == "whisper-1"
assert transcriber.api_key == "just a test"
assert transcriber.api_base == "https://api.openai.com/v1"
@pytest.mark.unit
def test_init_no_key(self):
@ -155,3 +156,31 @@ class TestRemoteWhisperTranscriber(BaseTestComponent):
"headers": {"Authorization": f"Bearer whatever"},
"timeout": OPENAI_TIMEOUT,
}
@pytest.mark.unit
@patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
def test_default_api_base(self, mock_request, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mock_request.return_value = mock_response
transcriber = RemoteWhisperTranscriber(api_key="just a test")
assert transcriber.api_base == "https://api.openai.com/v1"
transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/audio/transcriptions"
@pytest.mark.unit
@patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
def test_custom_api_base(self, mock_request, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mock_request.return_value = mock_response
transcriber = RemoteWhisperTranscriber(api_key="just a test", api_base="https://fake_api_base.com")
assert transcriber.api_base == "https://fake_api_base.com"
transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/audio/transcriptions"

View File

@ -0,0 +1,29 @@
from unittest.mock import patch
import pytest
from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
def test_default_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
assert invocation_layer.api_base == "https://api.openai.com/v1"
assert invocation_layer.url == "https://api.openai.com/v1/chat/completions"
invocation_layer.invoke(prompt="dummy_prompt")
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/chat/completions"
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
def test_custom_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
assert invocation_layer.api_base == "https://fake_api_base.com"
assert invocation_layer.url == "https://fake_api_base.com/chat/completions"
invocation_layer.invoke(prompt="dummy_prompt")
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/chat/completions"

View File

@ -0,0 +1,29 @@
from unittest.mock import patch
import pytest
from haystack.nodes.prompt.invocation_layer import OpenAIInvocationLayer
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
def test_default_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key")
assert invocation_layer.api_base == "https://api.openai.com/v1"
assert invocation_layer.url == "https://api.openai.com/v1/completions"
invocation_layer.invoke(prompt="dummy_prompt")
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions"
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
def test_custom_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
assert invocation_layer.api_base == "https://fake_api_base.com"
assert invocation_layer.url == "https://fake_api_base.com/completions"
invocation_layer.invoke(prompt="dummy_prompt")
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/completions"