mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-28 10:26:27 +00:00
feat: Allow setting custom api_base for OpenAI nodes (#5033)
* add changes for api_base * format retriever * Update haystack/nodes/retriever/dense.py Co-authored-by: bogdankostic <bogdankostic@web.de> * Update haystack/nodes/audio/whisper_transcriber.py Co-authored-by: bogdankostic <bogdankostic@web.de> * Update haystack/preview/components/audio/whisper_remote.py Co-authored-by: bogdankostic <bogdankostic@web.de> * Update haystack/nodes/answer_generator/openai.py Co-authored-by: bogdankostic <bogdankostic@web.de> * Update test_retriever.py * Update test_whisper_remote.py * Update test_generator.py * Update test_retriever.py * reformat with black * Update haystack/nodes/prompt/invocation_layer/chatgpt.py Co-authored-by: Daria Fokina <daria.f93@gmail.com> * Add unit tests * apply docstring suggestions --------- Co-authored-by: bogdankostic <bogdankostic@web.de> Co-authored-by: michaelfeil <me@michaelfeil.eu> Co-authored-by: Daria Fokina <daria.f93@gmail.com>
This commit is contained in:
parent
5f6d161cfe
commit
6ea8ae01a2
@ -45,6 +45,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
progress_bar: bool = True,
|
||||
prompt_template: Optional[PromptTemplate] = None,
|
||||
context_join_str: str = " ",
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
):
|
||||
"""
|
||||
:param api_key: Your API key from OpenAI. It is required for this node to work.
|
||||
@ -98,6 +99,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
[PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure).
|
||||
:param context_join_str: The separation string used to join the input documents to create the context
|
||||
used by the PromptTemplate.
|
||||
:param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`.
|
||||
"""
|
||||
super().__init__(progress_bar=progress_bar)
|
||||
if (examples is None and examples_context is not None) or (examples is not None and examples_context is None):
|
||||
@ -144,6 +146,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
self.azure_base_url = azure_base_url
|
||||
self.azure_deployment_name = azure_deployment_name
|
||||
self.api_version = api_version
|
||||
self.api_base = api_base
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
self.top_k = top_k
|
||||
@ -217,7 +220,7 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
if self.using_azure:
|
||||
url = f"{self.azure_base_url}/openai/deployments/{self.azure_deployment_name}/completions?api-version={self.api_version}"
|
||||
else:
|
||||
url = "https://api.openai.com/v1/completions"
|
||||
url = f"{self.api_base}/completions"
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.using_azure:
|
||||
|
@ -41,6 +41,7 @@ class WhisperTranscriber(BaseComponent):
|
||||
api_key: Optional[str] = None,
|
||||
model_name_or_path: WhisperModel = "medium",
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
) -> None:
|
||||
"""
|
||||
Creates a WhisperTranscriber instance.
|
||||
@ -50,10 +51,11 @@ class WhisperTranscriber(BaseComponent):
|
||||
the API, set this value to: "whisper-1" (default).
|
||||
:param device: Device to use for inference. Only used if you're using a local
|
||||
installation of Whisper. If None, the device is automatically selected.
|
||||
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.api_key = api_key
|
||||
|
||||
self.api_base = api_base
|
||||
self.use_local_whisper = is_whisper_available() and self.api_key is None
|
||||
|
||||
if self.use_local_whisper:
|
||||
@ -108,9 +110,7 @@ class WhisperTranscriber(BaseComponent):
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
request = PreparedRequest()
|
||||
url: str = (
|
||||
"https://api.openai.com/v1/audio/transcriptions"
|
||||
if not translate
|
||||
else "https://api.openai.com/v1/audio/translations"
|
||||
f"{self.api_base}/audio/transcriptions" if not translate else f"{self.api_base}/audio/translations"
|
||||
)
|
||||
|
||||
request.prepare(
|
||||
|
@ -20,9 +20,24 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, model_name_or_path: str = "gpt-3.5-turbo", max_length: Optional[int] = 500, **kwargs
|
||||
self,
|
||||
api_key: str,
|
||||
model_name_or_path: str = "gpt-3.5-turbo",
|
||||
max_length: Optional[int] = 500,
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(api_key, model_name_or_path, max_length, **kwargs)
|
||||
"""
|
||||
Creates an instance of ChatGPTInvocationLayer for OpenAI's GPT-3.5 and GPT-4 models.
|
||||
|
||||
:param model_name_or_path: The name or path of the underlying model.
|
||||
:param max_length: The maximum number of tokens the output text can have.
|
||||
:param api_key: The OpenAI API key.
|
||||
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
|
||||
:param kwargs: Additional keyword arguments passed to the underlying model.
|
||||
[See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
|
||||
"""
|
||||
super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)
|
||||
|
||||
def invoke(self, *args, **kwargs):
|
||||
"""
|
||||
@ -125,7 +140,7 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return "https://api.openai.com/v1/chat/completions"
|
||||
return f"{self.api_base}/chat/completions"
|
||||
|
||||
@classmethod
|
||||
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
|
||||
|
@ -27,7 +27,12 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, model_name_or_path: str = "text-davinci-003", max_length: Optional[int] = 100, **kwargs
|
||||
self,
|
||||
api_key: str,
|
||||
model_name_or_path: str = "text-davinci-003",
|
||||
max_length: Optional[int] = 100,
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates an instance of OpenAIInvocationLayer for OpenAI's GPT-3 InstructGPT models.
|
||||
@ -35,6 +40,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
:param model_name_or_path: The name or path of the underlying model.
|
||||
:param max_length: The maximum number of tokens the output text can have.
|
||||
:param api_key: The OpenAI API key.
|
||||
:param api_base: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
|
||||
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
|
||||
all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated
|
||||
kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant
|
||||
@ -48,6 +54,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
f"api_key {api_key} must be a valid OpenAI key. Visit https://openai.com/api/ to get one."
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
|
||||
# 16 is the default length for answers from OpenAI shown in the docs
|
||||
# here, https://platform.openai.com/docs/api-reference/completions/create.
|
||||
@ -86,7 +93,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return "https://api.openai.com/v1/completions"
|
||||
return f"{self.api_base}/completions"
|
||||
|
||||
@property
|
||||
def headers(self) -> Dict[str, str]:
|
||||
|
@ -37,7 +37,7 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
if self.using_azure:
|
||||
self.url = f"{retriever.azure_base_url}/openai/deployments/{retriever.azure_deployment_name}/embeddings?api-version={retriever.api_version}"
|
||||
else:
|
||||
self.url = "https://api.openai.com/v1/embeddings"
|
||||
self.url = f"{retriever.api_base}/embeddings"
|
||||
|
||||
self.api_key = retriever.api_key
|
||||
self.batch_size = min(64, retriever.batch_size)
|
||||
|
@ -1453,6 +1453,7 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
azure_api_version: str = "2022-12-01",
|
||||
azure_base_url: Optional[str] = None,
|
||||
azure_deployment_name: Optional[str] = None,
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
):
|
||||
"""
|
||||
:param document_store: An instance of DocumentStore from which to retrieve documents.
|
||||
@ -1512,6 +1513,7 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
This parameter is an OpenAI Azure endpoint, usually in the form `https://<your-endpoint>.openai.azure.com`
|
||||
:param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API
|
||||
will not be used.
|
||||
:param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`
|
||||
"""
|
||||
if embed_meta_fields is None:
|
||||
embed_meta_fields = []
|
||||
@ -1535,6 +1537,7 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
self.use_auth_token = use_auth_token
|
||||
self.scale_score = scale_score
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.api_version = azure_api_version
|
||||
self.azure_base_url = azure_base_url
|
||||
self.azure_deployment_name = azure_deployment_name
|
||||
|
@ -38,12 +38,15 @@ class RemoteWhisperTranscriber:
|
||||
class Output(ComponentOutput):
|
||||
documents: List[Document]
|
||||
|
||||
def __init__(self, api_key: str, model_name: WhisperRemoteModel = "whisper-1"):
|
||||
def __init__(
|
||||
self, api_key: str, model_name: WhisperRemoteModel = "whisper-1", api_base: str = "https://api.openai.com/v1"
|
||||
):
|
||||
"""
|
||||
Transcribes a list of audio files into a list of Documents.
|
||||
|
||||
:param api_key: OpenAI API key.
|
||||
:param model_name_or_path: Name of the model to use. It now accepts only `whisper-1`.
|
||||
:param model_name: Name of the model to use. It now accepts only `whisper-1`.
|
||||
:param api_base: OpenAI base URL, defaults to `"https://api.openai.com/v1"`.
|
||||
"""
|
||||
if model_name not in get_args(WhisperRemoteModel):
|
||||
raise ValueError(
|
||||
@ -53,6 +56,7 @@ class RemoteWhisperTranscriber:
|
||||
raise ValueError("API key is None.")
|
||||
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
@ -109,7 +113,7 @@ class RemoteWhisperTranscriber:
|
||||
:returns: a list of transcriptions as they are produced by the Whisper API (JSON).
|
||||
"""
|
||||
translate = kwargs.pop("translate", False)
|
||||
url = f"https://api.openai.com/v1/audio/{'translations' if translate else 'transcriptions'}"
|
||||
url = f"{self.api_base}/audio/{'translations' if translate else 'transcriptions'}"
|
||||
data = {"model": self.model_name, **kwargs}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
|
@ -33,6 +33,26 @@ def test_rag_deprecation():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.answer_generator.openai.openai_request")
|
||||
def test_openai_answer_generator_default_api_base(mock_request):
|
||||
with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
|
||||
generator = OpenAIAnswerGenerator(api_key="fake_api_key")
|
||||
assert generator.api_base == "https://api.openai.com/v1"
|
||||
generator.predict(query="test query", documents=[Document(content="test document")])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.answer_generator.openai.openai_request")
|
||||
def test_openai_answer_generator_custom_api_base(mock_request):
|
||||
with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
|
||||
generator = OpenAIAnswerGenerator(api_key="fake_api_key", api_base="https://fake_api_base.com")
|
||||
assert generator.api_base == "https://fake_api_base.com"
|
||||
generator.predict(query="test query", documents=[Document(content="test document")])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/completions"
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Causes OOM on windows github runner")
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.generator
|
||||
|
@ -325,12 +325,14 @@ def test_openai_embedding_retriever_selection():
|
||||
assert er.model_format == "openai"
|
||||
assert er.embedding_encoder.query_encoder_model == "text-embedding-ada-002"
|
||||
assert er.embedding_encoder.doc_encoder_model == "text-embedding-ada-002"
|
||||
assert er.api_base == "https://api.openai.com/v1"
|
||||
|
||||
# but also support old ada and other text-search-<modelname>-*-001 models
|
||||
er = EmbeddingRetriever(embedding_model="ada", document_store=None)
|
||||
assert er.model_format == "openai"
|
||||
assert er.embedding_encoder.query_encoder_model == "text-search-ada-query-001"
|
||||
assert er.embedding_encoder.doc_encoder_model == "text-search-ada-doc-001"
|
||||
assert er.api_base == "https://api.openai.com/v1"
|
||||
|
||||
# but also support old babbage and other text-search-<modelname>-*-001 models
|
||||
er = EmbeddingRetriever(embedding_model="babbage", document_store=None)
|
||||
@ -1270,3 +1272,35 @@ def test_web_retriever_mode_snippets(monkeypatch):
|
||||
web_retriever = WebRetriever(api_key="", top_search_results=2)
|
||||
result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
|
||||
assert result == expected_search_results["documents"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
|
||||
def test_openai_default_api_base(mock_request):
|
||||
with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
|
||||
retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key")
|
||||
assert retriever.api_base == "https://api.openai.com/v1"
|
||||
|
||||
retriever.embed_queries(queries=["test query"])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/embeddings"
|
||||
mock_request.reset_mock()
|
||||
|
||||
retriever.embed_documents(documents=[Document(content="test document")])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/embeddings"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
|
||||
def test_openai_custom_api_base(mock_request):
|
||||
with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
|
||||
retriever = EmbeddingRetriever(
|
||||
embedding_model="text-embedding-ada-002", api_key="fake_api_key", api_base="https://fake_api_base.com"
|
||||
)
|
||||
assert retriever.api_base == "https://fake_api_base.com"
|
||||
|
||||
retriever.embed_queries(queries=["test query"])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings"
|
||||
mock_request.reset_mock()
|
||||
|
||||
retriever.embed_documents(documents=[Document(content="test document")])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/embeddings"
|
||||
|
@ -29,6 +29,7 @@ class TestRemoteWhisperTranscriber(BaseTestComponent):
|
||||
transcriber = RemoteWhisperTranscriber(api_key="just a test")
|
||||
assert transcriber.model_name == "whisper-1"
|
||||
assert transcriber.api_key == "just a test"
|
||||
assert transcriber.api_base == "https://api.openai.com/v1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_no_key(self):
|
||||
@ -155,3 +156,31 @@ class TestRemoteWhisperTranscriber(BaseTestComponent):
|
||||
"headers": {"Authorization": f"Bearer whatever"},
|
||||
"timeout": OPENAI_TIMEOUT,
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
|
||||
def test_default_api_base(self, mock_request, preview_samples_path):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
transcriber = RemoteWhisperTranscriber(api_key="just a test")
|
||||
assert transcriber.api_base == "https://api.openai.com/v1"
|
||||
|
||||
transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/audio/transcriptions"
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
|
||||
def test_custom_api_base(self, mock_request, preview_samples_path):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
transcriber = RemoteWhisperTranscriber(api_key="just a test", api_base="https://fake_api_base.com")
|
||||
assert transcriber.api_base == "https://fake_api_base.com"
|
||||
|
||||
transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
|
||||
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/audio/transcriptions"
|
||||
|
29
test/prompt/invocation_layer/test_chatgpt.py
Normal file
29
test/prompt/invocation_layer/test_chatgpt.py
Normal file
@ -0,0 +1,29 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
|
||||
def test_default_api_base(mock_request):
|
||||
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
|
||||
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
|
||||
assert invocation_layer.api_base == "https://api.openai.com/v1"
|
||||
assert invocation_layer.url == "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
invocation_layer.invoke(prompt="dummy_prompt")
|
||||
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
|
||||
def test_custom_api_base(mock_request):
|
||||
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
|
||||
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
|
||||
assert invocation_layer.api_base == "https://fake_api_base.com"
|
||||
assert invocation_layer.url == "https://fake_api_base.com/chat/completions"
|
||||
|
||||
invocation_layer.invoke(prompt="dummy_prompt")
|
||||
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/chat/completions"
|
29
test/prompt/invocation_layer/test_openai.py
Normal file
29
test/prompt/invocation_layer/test_openai.py
Normal file
@ -0,0 +1,29 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.nodes.prompt.invocation_layer import OpenAIInvocationLayer
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
|
||||
def test_default_api_base(mock_request):
|
||||
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
|
||||
invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key")
|
||||
assert invocation_layer.api_base == "https://api.openai.com/v1"
|
||||
assert invocation_layer.url == "https://api.openai.com/v1/completions"
|
||||
|
||||
invocation_layer.invoke(prompt="dummy_prompt")
|
||||
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
|
||||
def test_custom_api_base(mock_request):
|
||||
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
|
||||
invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
|
||||
assert invocation_layer.api_base == "https://fake_api_base.com"
|
||||
assert invocation_layer.url == "https://fake_api_base.com/completions"
|
||||
|
||||
invocation_layer.invoke(prompt="dummy_prompt")
|
||||
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/completions"
|
Loading…
x
Reference in New Issue
Block a user