From 6ea8ae01a231c6fd4a1cecf7269b6eedf78813a4 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Mon, 5 Jun 2023 11:32:06 +0200 Subject: [PATCH] feat: Allow setting custom api_base for OpenAI nodes (#5033) * add changes for api_base * format retriever * Update haystack/nodes/retriever/dense.py Co-authored-by: bogdankostic * Update haystack/nodes/audio/whisper_transcriber.py Co-authored-by: bogdankostic * Update haystack/preview/components/audio/whisper_remote.py Co-authored-by: bogdankostic * Update haystack/nodes/answer_generator/openai.py Co-authored-by: bogdankostic * Update test_retriever.py * Update test_whisper_remote.py * Update test_generator.py * Update test_retriever.py * reformat with black * Update haystack/nodes/prompt/invocation_layer/chatgpt.py Co-authored-by: Daria Fokina * Add unit tests * apply docstring suggestions --------- Co-authored-by: bogdankostic Co-authored-by: michaelfeil Co-authored-by: Daria Fokina --- haystack/nodes/answer_generator/openai.py | 5 ++- haystack/nodes/audio/whisper_transcriber.py | 8 ++--- .../nodes/prompt/invocation_layer/chatgpt.py | 21 ++++++++++-- .../nodes/prompt/invocation_layer/open_ai.py | 11 ++++-- haystack/nodes/retriever/_openai_encoder.py | 2 +- haystack/nodes/retriever/dense.py | 3 ++ .../components/audio/whisper_remote.py | 10 ++++-- test/nodes/test_generator.py | 20 +++++++++++ test/nodes/test_retriever.py | 34 +++++++++++++++++++ .../components/audio/test_whisper_remote.py | 29 ++++++++++++++++ test/prompt/invocation_layer/test_chatgpt.py | 29 ++++++++++++++++ test/prompt/invocation_layer/test_openai.py | 29 ++++++++++++++++ 12 files changed, 187 insertions(+), 14 deletions(-) create mode 100644 test/prompt/invocation_layer/test_chatgpt.py create mode 100644 test/prompt/invocation_layer/test_openai.py diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py index 4c7dbb683..4358094fe 100644 --- 
a/haystack/nodes/answer_generator/openai.py +++ b/haystack/nodes/answer_generator/openai.py @@ -45,6 +45,7 @@ class OpenAIAnswerGenerator(BaseGenerator): progress_bar: bool = True, prompt_template: Optional[PromptTemplate] = None, context_join_str: str = " ", + api_base: str = "https://api.openai.com/v1", ): """ :param api_key: Your API key from OpenAI. It is required for this node to work. @@ -98,6 +99,7 @@ class OpenAIAnswerGenerator(BaseGenerator): [PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure). :param context_join_str: The separation string used to join the input documents to create the context used by the PromptTemplate. + :param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`. """ super().__init__(progress_bar=progress_bar) if (examples is None and examples_context is not None) or (examples is not None and examples_context is None): @@ -144,6 +146,7 @@ class OpenAIAnswerGenerator(BaseGenerator): self.azure_base_url = azure_base_url self.azure_deployment_name = azure_deployment_name self.api_version = api_version + self.api_base = api_base self.model = model self.max_tokens = max_tokens self.top_k = top_k @@ -217,7 +220,7 @@ class OpenAIAnswerGenerator(BaseGenerator): if self.using_azure: url = f"{self.azure_base_url}/openai/deployments/{self.azure_deployment_name}/completions?api-version={self.api_version}" else: - url = "https://api.openai.com/v1/completions" + url = f"{self.api_base}/completions" headers = {"Content-Type": "application/json"} if self.using_azure: diff --git a/haystack/nodes/audio/whisper_transcriber.py b/haystack/nodes/audio/whisper_transcriber.py index 943e260b4..b2f7b67b4 100644 --- a/haystack/nodes/audio/whisper_transcriber.py +++ b/haystack/nodes/audio/whisper_transcriber.py @@ -41,6 +41,7 @@ class WhisperTranscriber(BaseComponent): api_key: Optional[str] = None, model_name_or_path: WhisperModel = "medium", device: Optional[Union[str, torch.device]] = None, + 
api_base: str = "https://api.openai.com/v1", ) -> None: """ Creates a WhisperTranscriber instance. @@ -50,10 +51,11 @@ class WhisperTranscriber(BaseComponent): the API, set thsi value to: "whisper-1" (default). :param device: Device to use for inference. Only used if you're using a local installation of Whisper. If None, the device is automatically selected. + :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. """ super().__init__() self.api_key = api_key - + self.api_base = api_base self.use_local_whisper = is_whisper_available() and self.api_key is None if self.use_local_whisper: @@ -108,9 +110,7 @@ class WhisperTranscriber(BaseComponent): headers = {"Authorization": f"Bearer {self.api_key}"} request = PreparedRequest() url: str = ( - "https://api.openai.com/v1/audio/transcriptions" - if not translate - else "https://api.openai.com/v1/audio/translations" + f"{self.api_base}/audio/transcriptions" if not translate else f"{self.api_base}/audio/translations" ) request.prepare( diff --git a/haystack/nodes/prompt/invocation_layer/chatgpt.py b/haystack/nodes/prompt/invocation_layer/chatgpt.py index c47bfeedf..cd49cfcf6 100644 --- a/haystack/nodes/prompt/invocation_layer/chatgpt.py +++ b/haystack/nodes/prompt/invocation_layer/chatgpt.py @@ -20,9 +20,24 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer): """ def __init__( - self, api_key: str, model_name_or_path: str = "gpt-3.5-turbo", max_length: Optional[int] = 500, **kwargs + self, + api_key: str, + model_name_or_path: str = "gpt-3.5-turbo", + max_length: Optional[int] = 500, + api_base: str = "https://api.openai.com/v1", + **kwargs, ): - super().__init__(api_key, model_name_or_path, max_length, **kwargs) + """ + Creates an instance of ChatGPTInvocationLayer for OpenAI's GPT-3.5 and GPT-4 models. + + :param model_name_or_path: The name or path of the underlying model. + :param max_length: The maximum number of tokens the output text can have. + :param api_key: The OpenAI API key. 
+ :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. + :param kwargs: Additional keyword arguments passed to the underlying model. + [See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat). + """ + super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs) def invoke(self, *args, **kwargs): """ @@ -125,7 +140,7 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer): @property def url(self) -> str: - return "https://api.openai.com/v1/chat/completions" + return f"{self.api_base}/chat/completions" @classmethod def supports(cls, model_name_or_path: str, **kwargs) -> bool: diff --git a/haystack/nodes/prompt/invocation_layer/open_ai.py b/haystack/nodes/prompt/invocation_layer/open_ai.py index e48513774..32c6a801b 100644 --- a/haystack/nodes/prompt/invocation_layer/open_ai.py +++ b/haystack/nodes/prompt/invocation_layer/open_ai.py @@ -27,7 +27,12 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer): """ def __init__( - self, api_key: str, model_name_or_path: str = "text-davinci-003", max_length: Optional[int] = 100, **kwargs + self, + api_key: str, + model_name_or_path: str = "text-davinci-003", + max_length: Optional[int] = 100, + api_base: str = "https://api.openai.com/v1", + **kwargs, ): """ Creates an instance of OpenAIInvocationLayer for OpenAI's GPT-3 InstructGPT models. @@ -35,6 +40,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer): :param model_name_or_path: The name or path of the underlying model. :param max_length: The maximum number of tokens the output text can have. :param api_key: The OpenAI API key. + :param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. :param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated kwargs. 
Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant @@ -48,6 +54,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer): f"api_key {api_key} must be a valid OpenAI key. Visit https://openai.com/api/ to get one." ) self.api_key = api_key + self.api_base = api_base # 16 is the default length for answers from OpenAI shown in the docs # here, https://platform.openai.com/docs/api-reference/completions/create. @@ -86,7 +93,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer): @property def url(self) -> str: - return "https://api.openai.com/v1/completions" + return f"{self.api_base}/completions" @property def headers(self) -> Dict[str, str]: diff --git a/haystack/nodes/retriever/_openai_encoder.py b/haystack/nodes/retriever/_openai_encoder.py index 2b101853e..f476fdbd9 100644 --- a/haystack/nodes/retriever/_openai_encoder.py +++ b/haystack/nodes/retriever/_openai_encoder.py @@ -37,7 +37,7 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder): if self.using_azure: self.url = f"{retriever.azure_base_url}/openai/deployments/{retriever.azure_deployment_name}/embeddings?api-version={retriever.api_version}" else: - self.url = "https://api.openai.com/v1/embeddings" + self.url = f"{retriever.api_base}/embeddings" self.api_key = retriever.api_key self.batch_size = min(64, retriever.batch_size) diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index e2bba4f46..a5a503f1e 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -1453,6 +1453,7 @@ class EmbeddingRetriever(DenseRetriever): azure_api_version: str = "2022-12-01", azure_base_url: Optional[str] = None, azure_deployment_name: Optional[str] = None, + api_base: str = "https://api.openai.com/v1", ): """ :param document_store: An instance of DocumentStore from which to retrieve documents. 
@@ -1512,6 +1513,7 @@ class EmbeddingRetriever(DenseRetriever): This parameter is an OpenAI Azure endpoint, usually in the form `https://.openai.azure.com' :param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API will not be used. + :param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"` """ if embed_meta_fields is None: embed_meta_fields = [] @@ -1535,6 +1537,7 @@ class EmbeddingRetriever(DenseRetriever): self.use_auth_token = use_auth_token self.scale_score = scale_score self.api_key = api_key + self.api_base = api_base self.api_version = azure_api_version self.azure_base_url = azure_base_url self.azure_deployment_name = azure_deployment_name diff --git a/haystack/preview/components/audio/whisper_remote.py b/haystack/preview/components/audio/whisper_remote.py index 3d5bdd44b..9cb721e39 100644 --- a/haystack/preview/components/audio/whisper_remote.py +++ b/haystack/preview/components/audio/whisper_remote.py @@ -38,12 +38,15 @@ class RemoteWhisperTranscriber: class Output(ComponentOutput): documents: List[Document] - def __init__(self, api_key: str, model_name: WhisperRemoteModel = "whisper-1"): + def __init__( + self, api_key: str, model_name: WhisperRemoteModel = "whisper-1", api_base: str = "https://api.openai.com/v1" + ): """ Transcribes a list of audio files into a list of Documents. :param api_key: OpenAI API key. - :param model_name_or_path: Name of the model to use. It now accepts only `whisper-1`. + :param model_name: Name of the model to use. It now accepts only `whisper-1`. + :param api_base: OpenAI base URL, defaults to `"https://api.openai.com/v1"`. 
""" if model_name not in get_args(WhisperRemoteModel): raise ValueError( @@ -53,6 +56,7 @@ class RemoteWhisperTranscriber: raise ValueError("API key is None.") self.api_key = api_key + self.api_base = api_base self.model_name = model_name @@ -109,7 +113,7 @@ class RemoteWhisperTranscriber: :returns: a list of transcriptions as they are produced by the Whisper API (JSON). """ translate = kwargs.pop("translate", False) - url = f"https://api.openai.com/v1/audio/{'translations' if translate else 'transcriptions'}" + url = f"{self.api_base}/audio/{'translations' if translate else 'transcriptions'}" data = {"model": self.model_name, **kwargs} headers = {"Authorization": f"Bearer {self.api_key}"} diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py index aabc19151..5c19fe2c8 100644 --- a/test/nodes/test_generator.py +++ b/test/nodes/test_generator.py @@ -33,6 +33,26 @@ def test_rag_deprecation(): pass +@pytest.mark.unit +@patch("haystack.nodes.answer_generator.openai.openai_request") +def test_openai_answer_generator_default_api_base(mock_request): + with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): + generator = OpenAIAnswerGenerator(api_key="fake_api_key") + assert generator.api_base == "https://api.openai.com/v1" + generator.predict(query="test query", documents=[Document(content="test document")]) + assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions" + + +@pytest.mark.unit +@patch("haystack.nodes.answer_generator.openai.openai_request") +def test_openai_answer_generator_custom_api_base(mock_request): + with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): + generator = OpenAIAnswerGenerator(api_key="fake_api_key", api_base="https://fake_api_base.com") + assert generator.api_base == "https://fake_api_base.com" + generator.predict(query="test query", documents=[Document(content="test document")]) + assert mock_request.call_args.kwargs["url"] == 
"https://fake_api_base.com/completions" + + @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Causes OOM on windows github runner") @pytest.mark.integration @pytest.mark.generator diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 8c5d9899d..ea49521c8 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -325,12 +325,14 @@ def test_openai_embedding_retriever_selection(): assert er.model_format == "openai" assert er.embedding_encoder.query_encoder_model == "text-embedding-ada-002" assert er.embedding_encoder.doc_encoder_model == "text-embedding-ada-002" + assert er.api_base == "https://api.openai.com/v1" # but also support old ada and other text-search--*-001 models er = EmbeddingRetriever(embedding_model="ada", document_store=None) assert er.model_format == "openai" assert er.embedding_encoder.query_encoder_model == "text-search-ada-query-001" assert er.embedding_encoder.doc_encoder_model == "text-search-ada-doc-001" + assert er.api_base == "https://api.openai.com/v1" # but also support old babbage and other text-search--*-001 models er = EmbeddingRetriever(embedding_model="babbage", document_store=None) @@ -1270,3 +1272,35 @@ def test_web_retriever_mode_snippets(monkeypatch): web_retriever = WebRetriever(api_key="", top_search_results=2) result = web_retriever.retrieve(query="Who is the father of Arya Stark?") assert result == expected_search_results["documents"] + + +@pytest.mark.unit +@patch("haystack.nodes.retriever._openai_encoder.openai_request") +def test_openai_default_api_base(mock_request): + with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"): + retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key") + assert retriever.api_base == "https://api.openai.com/v1" + + retriever.embed_queries(queries=["test query"]) + assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/embeddings" + 
mock_request.reset_mock() + + retriever.embed_documents(documents=[Document(content="test document")]) + assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/embeddings" + + +@pytest.mark.unit +@patch("haystack.nodes.retriever._openai_encoder.openai_request") +def test_openai_custom_api_base(mock_request): + with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"): + retriever = EmbeddingRetriever( + embedding_model="text-embedding-ada-002", api_key="fake_api_key", api_base="https://fake_api_base.com" + ) + assert retriever.api_base == "https://fake_api_base.com" + + retriever.embed_queries(queries=["test query"]) + assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings" + mock_request.reset_mock() + + retriever.embed_documents(documents=[Document(content="test document")]) + assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings" diff --git a/test/preview/components/audio/test_whisper_remote.py b/test/preview/components/audio/test_whisper_remote.py index e1f239aec..465e6a47c 100644 --- a/test/preview/components/audio/test_whisper_remote.py +++ b/test/preview/components/audio/test_whisper_remote.py @@ -29,6 +29,7 @@ class TestRemoteWhisperTranscriber(BaseTestComponent): transcriber = RemoteWhisperTranscriber(api_key="just a test") assert transcriber.model_name == "whisper-1" assert transcriber.api_key == "just a test" + assert transcriber.api_base == "https://api.openai.com/v1" @pytest.mark.unit def test_init_no_key(self): @@ -155,3 +156,31 @@ class TestRemoteWhisperTranscriber(BaseTestComponent): "headers": {"Authorization": f"Bearer whatever"}, "timeout": OPENAI_TIMEOUT, } + + @pytest.mark.unit + @patch("haystack.preview.components.audio.whisper_remote.request_with_retry") + def test_default_api_base(self, mock_request, preview_samples_path): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = '{"text": "test transcription", 
"other_metadata": ["other", "meta", "data"]}' + mock_request.return_value = mock_response + + transcriber = RemoteWhisperTranscriber(api_key="just a test") + assert transcriber.api_base == "https://api.openai.com/v1" + + transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"]) + assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/audio/transcriptions" + + @pytest.mark.unit + @patch("haystack.preview.components.audio.whisper_remote.request_with_retry") + def test_custom_api_base(self, mock_request, preview_samples_path): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}' + mock_request.return_value = mock_response + + transcriber = RemoteWhisperTranscriber(api_key="just a test", api_base="https://fake_api_base.com") + assert transcriber.api_base == "https://fake_api_base.com" + + transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"]) + assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/audio/transcriptions" diff --git a/test/prompt/invocation_layer/test_chatgpt.py b/test/prompt/invocation_layer/test_chatgpt.py new file mode 100644 index 000000000..17799c655 --- /dev/null +++ b/test/prompt/invocation_layer/test_chatgpt.py @@ -0,0 +1,29 @@ +from unittest.mock import patch + +import pytest + +from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer + + +@pytest.mark.unit +@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request") +def test_default_api_base(mock_request): + with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"): + invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key") + assert invocation_layer.api_base == "https://api.openai.com/v1" + assert invocation_layer.url == "https://api.openai.com/v1/chat/completions" + + 
invocation_layer.invoke(prompt="dummy_prompt") + assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/chat/completions" + + +@pytest.mark.unit +@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request") +def test_custom_api_base(mock_request): + with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"): + invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com") + assert invocation_layer.api_base == "https://fake_api_base.com" + assert invocation_layer.url == "https://fake_api_base.com/chat/completions" + + invocation_layer.invoke(prompt="dummy_prompt") + assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/chat/completions" diff --git a/test/prompt/invocation_layer/test_openai.py b/test/prompt/invocation_layer/test_openai.py new file mode 100644 index 000000000..98de293b9 --- /dev/null +++ b/test/prompt/invocation_layer/test_openai.py @@ -0,0 +1,29 @@ +from unittest.mock import patch + +import pytest + +from haystack.nodes.prompt.invocation_layer import OpenAIInvocationLayer + + +@pytest.mark.unit +@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request") +def test_default_api_base(mock_request): + with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"): + invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key") + assert invocation_layer.api_base == "https://api.openai.com/v1" + assert invocation_layer.url == "https://api.openai.com/v1/completions" + + invocation_layer.invoke(prompt="dummy_prompt") + assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions" + + +@pytest.mark.unit +@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request") +def test_custom_api_base(mock_request): + with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"): + invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", 
api_base="https://fake_api_base.com") + assert invocation_layer.api_base == "https://fake_api_base.com" + assert invocation_layer.url == "https://fake_api_base.com/completions" + + invocation_layer.invoke(prompt="dummy_prompt") + assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/completions"