mirror of https://github.com/deepset-ai/haystack.git
synced 2026-01-06 20:17:14 +00:00
refactor: Migrate RemoteWhisperTranscriber to OpenAI SDK. (#6149)
* Migrate RemoteWhisperTranscriber to OpenAI SDK
* Remove unnecessary imports
* Add release notes
* Fix api_key serialization
* Fix linting
* Apply suggestions from code review
* Add additional tests for api_key
* Adapt .run() to take ByteStream inputs
* Update docstrings
* Rework implementation to use io.BytesIO
* Update error message
* Add default file name

Co-authored-by: ZanSara <sarazanzo94@gmail.com>
This commit is contained in:
parent 26a22045e4
commit 5f35e7d04a
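Before the diff, a minimal usage sketch of the migrated component. This is not part of the commit; it assumes a local `audio.wav` file and an `OPENAI_API_KEY` environment variable, and uses only the API introduced below.

from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.preview.dataclasses import ByteStream

# The key can also be passed explicitly; by default it is read from
# openai.api_key or the OPENAI_API_KEY environment variable.
transcriber = RemoteWhisperTranscriber()

with open("audio.wav", "rb") as f:  # hypothetical input file
    # "file_path" metadata becomes the uploaded file's name; without it the
    # component falls back to the default name "audio_input.wav".
    stream = ByteStream(f.read(), metadata={"file_path": "audio.wav"})

result = transcriber.run(streams=[stream])
print(result["documents"][0].text)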
@@ -1,20 +1,17 @@
-from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence
-import os
-import json
+import io
 import logging
-from pathlib import Path
+import os
+from typing import Any, Dict, List, Optional

-from haystack.preview.utils import request_with_retry
-from haystack.preview import component, Document, default_to_dict
+import openai
+
+from haystack.preview import Document, component, default_from_dict, default_to_dict
+from haystack.preview.dataclasses import ByteStream

 logger = logging.getLogger(__name__)

-OPENAI_TIMEOUT = float(os.environ.get("HAYSTACK_OPENAI_TIMEOUT_SEC", 600))
-
-WhisperRemoteModel = Literal["whisper-1"]
+API_BASE_URL = "https://api.openai.com/v1"


 @component
@@ -30,108 +27,112 @@ class RemoteWhisperTranscriber:

     def __init__(
         self,
-        api_key: str,
-        model_name: WhisperRemoteModel = "whisper-1",
-        api_base: str = "https://api.openai.com/v1",
-        whisper_params: Optional[Dict[str, Any]] = None,
+        api_key: Optional[str] = None,
+        model_name: str = "whisper-1",
+        organization: Optional[str] = None,
+        api_base_url: str = API_BASE_URL,
+        **kwargs,
     ):
         """
         Transcribes a list of audio files into a list of Documents.

         :param api_key: OpenAI API key.
+        :param model_name: Name of the model to use. It now accepts only `whisper-1`.
+        :param organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
+        :param api_base_url: OpenAI base URL, defaults to `"https://api.openai.com/v1"`.
+        :param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
+            endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/audio) for more
+            details. Some of the supported parameters:
+            - `language`: The language of the input audio. Supplying the input language in ISO-639-1 format
+              will improve accuracy and latency.
+            - `prompt`: An optional text to guide the model's style or continue a previous audio segment.
+              The prompt should match the audio language.
+            - `response_format`: The format of the transcript output, in one of these options: json, text, srt,
+              verbose_json, or vtt. Defaults to "json". Currently only "json" is supported.
+            - `temperature`: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
+              output more random, while lower values like 0.2 will make it more focused and deterministic.
+              If set to 0, the model will use log probability to automatically increase the temperature until
+              certain thresholds are hit.
         """
-        if model_name not in get_args(WhisperRemoteModel):
-            raise ValueError(
-                f"Model name not recognized. Choose one among: " f"{', '.join(get_args(WhisperRemoteModel))}."
-            )
-        if not api_key:
-            raise ValueError("API key is None.")
+        # if the user does not provide the API key, check if it is set in the module client
+        api_key = api_key or openai.api_key
+        if api_key is None:
+            try:
+                api_key = os.environ["OPENAI_API_KEY"]
+            except KeyError as e:
+                raise ValueError(
+                    "RemoteWhisperTranscriber expects an OpenAI API key. "
+                    "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
+                ) from e
+        openai.api_key = api_key

+        self.organization = organization
         self.model_name = model_name
-        self.api_key = api_key
-        self.api_base = api_base
-        self.whisper_params = whisper_params or {}
+        self.api_base_url = api_base_url
+
+        # Only response_format = "json" is supported
+        whisper_params = kwargs
+        if whisper_params.get("response_format") != "json":
+            logger.warning(
+                "RemoteWhisperTranscriber only supports 'response_format: json'. This parameter will be overwritten."
+            )
+        whisper_params["response_format"] = "json"
+        self.whisper_params = whisper_params
+
+        if organization is not None:
+            openai.organization = organization

-    @component.output_types(documents=List[Document])
-    def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
-        """
-        Transcribe the audio files into a list of Documents, one for each input file.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official
-        Whisper [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe
-        :returns: a list of Documents, one for each file. The content of the document is the transcription text,
-            while the document's metadata contains all the other values returned by the Whisper model, such as
-            the alignment data. Another key called `audio_file` contains the path to the audio file used for
-            the transcription.
-        """
-        if whisper_params is None:
-            whisper_params = self.whisper_params
-
-        documents = self.transcribe(audio_files, **whisper_params)
-        return {"documents": documents}
-
-    def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
-        """
-        Transcribe the audio files into a list of Documents, one for each input file.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official
-        Whisper [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe
-        :returns: a list of transcriptions.
-        """
-        transcriptions = self._raw_transcribe(audio_files=audio_files, **kwargs)
-        documents = []
-        for audio, transcript in zip(audio_files, transcriptions):
-            content = transcript.pop("text")
-            if not isinstance(audio, (str, Path)):
-                audio = "<<binary stream>>"
-            doc = Document(text=content, metadata={"audio_file": audio, **transcript})
-            documents.append(doc)
-        return documents
-
-    def _raw_transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]:
-        """
-        Transcribe the given audio files. Returns a list of strings.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official
-        Whisper [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe.
-        :param kwargs: any other parameters that Whisper API can understand.
-        :returns: a list of transcriptions as they are produced by the Whisper API (JSON).
-        """
-        translate = kwargs.pop("translate", False)
-        url = f"{self.api_base}/audio/{'translations' if translate else 'transcriptions'}"
-        data = {"model": self.model_name, **kwargs}
-        headers = {"Authorization": f"Bearer {self.api_key}"}
-
-        transcriptions = []
-        for audio_file in audio_files:
-            if isinstance(audio_file, (str, Path)):
-                audio_file = open(audio_file, "rb")
-
-            request_files = ("file", (audio_file.name, audio_file, "application/octet-stream"))
-            response = request_with_retry(
-                method="post", url=url, data=data, headers=headers, files=[request_files], timeout=OPENAI_TIMEOUT
-            )
-            transcription = json.loads(response.content)
-            transcriptions.append(transcription)
-        return transcriptions
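A note on the key resolution above: an explicit `api_key` argument wins, then the module-level `openai.api_key`, and only then the `OPENAI_API_KEY` environment variable. A small sketch of the non-obvious case (the module global shadows the environment variable, as the tests below also assert):

import openai

from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber

# Hypothetical setup: both a module-level key and OPENAI_API_KEY are set.
openai.api_key = "module_level_key"

RemoteWhisperTranscriber(api_key=None)  # no explicit key passed

# The module global takes precedence; the env var is never consulted.
assert openai.api_key == "module_level_key"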
     def to_dict(self) -> Dict[str, Any]:
         """
-        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
-        to the constructor.
+        Serialize this component to a dictionary.
+        This method overrides the default serializer in order to
+        avoid leaking the `api_key` value passed to the constructor.
         """
         return default_to_dict(
-            self, model_name=self.model_name, api_base=self.api_base, whisper_params=self.whisper_params
+            self,
+            model_name=self.model_name,
+            organization=self.organization,
+            api_base_url=self.api_base_url,
+            **self.whisper_params,
         )

+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
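Note that `to_dict` flattens `whisper_params` into `init_parameters` and deliberately omits the key, so a deserialized component must be able to re-resolve it. A hedged round-trip sketch (values mirror the tests further down):

transcriber = RemoteWhisperTranscriber(api_key="test_api_key", language="en")

data = transcriber.to_dict()
assert "api_key" not in data["init_parameters"]  # the key is never serialized
assert data["init_parameters"]["language"] == "en"  # whisper params are inlined

# Re-resolves the key from openai.api_key or OPENAI_API_KEY; raises ValueError
# if neither is available.
restored = RemoteWhisperTranscriber.from_dict(data)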
+    @component.output_types(documents=List[Document])
+    def run(self, streams: List[ByteStream]):
+        """
+        Transcribe the audio files into a list of Documents, one for each input file.
+
+        For the supported audio formats, languages, and other parameters, see the
+        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official
+        Whisper [github repo](https://github.com/openai/whisper).
+
+        :param streams: a list of ByteStream objects to transcribe.
+        :returns: a list of Documents, one for each file. The content of the document is the transcription text.
+        """
+        documents = []
+
+        for stream in streams:
+            file = io.BytesIO(stream.data)
+            try:
+                file.name = stream.metadata["file_path"]
+            except KeyError:
+                file.name = "audio_input.wav"
+
+            content = openai.Audio.transcribe(file=file, model=self.model_name, **self.whisper_params)
+            doc = Document(text=content["text"], metadata=stream.metadata)
+            documents.append(doc)
+
+        return {"documents": documents}
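One implementation detail worth flagging in the new `run()`: a bare `io.BytesIO` has no `name` attribute, and the Whisper endpoint needs a file name to infer the audio format, which is presumably why the commit assigns one. A minimal sketch of that fallback, with a hypothetical empty metadata dict:

import io

# Mirrors the try/except in run(): prefer the "file_path" metadata entry,
# otherwise fall back to the default name "audio_input.wav".
metadata = {}  # hypothetical: no "file_path" key present
file = io.BytesIO(b"RIFF....WAVE")  # stand-in for real audio bytes
file.name = metadata.get("file_path", "audio_input.wav")
assert file.name == "audio_input.wav"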
@@ -0,0 +1,4 @@
+---
+preview:
+  - |
+    Migrate RemoteWhisperTranscriber to OpenAI SDK.
@@ -1,197 +1,202 @@
 import os
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch

+import openai
 import pytest
+from openai.util import convert_to_openai_object

-from haystack.preview.dataclasses import Document
-from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber, OPENAI_TIMEOUT
+from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
+from haystack.preview.dataclasses import ByteStream
+
+
+def mock_openai_response(response_format="json", **kwargs) -> openai.openai_object.OpenAIObject:
+    if response_format == "json":
+        dict_response = {"text": "test transcription"}
+    # Currently only "json" is supported.
+    else:
+        dict_response = {}
+
+    return convert_to_openai_object(dict_response)
 class TestRemoteWhisperTranscriber:
     @pytest.mark.unit
-    def test_init_unknown_model(self):
-        with pytest.raises(ValueError, match="not recognized"):
-            RemoteWhisperTranscriber(model_name="anything", api_key="something")
+    def test_init_no_key(self, monkeypatch):
+        openai.api_key = None
+        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+        error_msg = "RemoteWhisperTranscriber expects an OpenAI API key."
+        with pytest.raises(ValueError, match=error_msg):
+            RemoteWhisperTranscriber(api_key=None)
+
+    def test_init_key_env_var(self, monkeypatch):
+        openai.api_key = None
+        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+        RemoteWhisperTranscriber(api_key=None)
+        assert openai.api_key == "test_api_key"
+
+    def test_init_key_module_env_and_global_var(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key_2")
+        openai.api_key = "test_api_key_1"
+        RemoteWhisperTranscriber(api_key=None)
+        # The module global variable takes preference
+        assert openai.api_key == "test_api_key_1"

     @pytest.mark.unit
     def test_init_default(self):
-        transcriber = RemoteWhisperTranscriber(api_key="just a test")
+        transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
+
+        assert openai.api_key == "test_api_key"
         assert transcriber.model_name == "whisper-1"
-        assert transcriber.api_key == "just a test"
-        assert transcriber.api_base == "https://api.openai.com/v1"
+        assert transcriber.organization is None
+        assert transcriber.api_base_url == "https://api.openai.com/v1"
+        assert transcriber.whisper_params == {"response_format": "json"}

     @pytest.mark.unit
-    def test_init_no_key(self):
-        with pytest.raises(ValueError, match="API key is None"):
-            RemoteWhisperTranscriber(api_key=None)
+    def test_init_custom_parameters(self):
+        transcriber = RemoteWhisperTranscriber(
+            api_key="test_api_key",
+            model_name="whisper-1",
+            organization="test-org",
+            api_base_url="test_api_url",
+            language="en",
+            prompt="test-prompt",
+            response_format="json",
+            temperature="0.5",
+        )
+
+        assert openai.api_key == "test_api_key"
+        assert transcriber.model_name == "whisper-1"
+        assert transcriber.organization == "test-org"
+        assert transcriber.api_base_url == "test_api_url"
+        assert transcriber.whisper_params == {
+            "language": "en",
+            "prompt": "test-prompt",
+            "response_format": "json",
+            "temperature": "0.5",
+        }
     @pytest.mark.unit
-    def test_to_dict(self):
-        transcriber = RemoteWhisperTranscriber(api_key="test")
+    def test_to_dict_default_parameters(self):
+        transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
         data = transcriber.to_dict()
         assert data == {
             "type": "RemoteWhisperTranscriber",
             "init_parameters": {
                 "model_name": "whisper-1",
-                "api_base": "https://api.openai.com/v1",
-                "whisper_params": {},
+                "api_base_url": "https://api.openai.com/v1",
+                "organization": None,
+                "response_format": "json",
             },
         }

     @pytest.mark.unit
     def test_to_dict_with_custom_init_parameters(self):
         transcriber = RemoteWhisperTranscriber(
-            api_key="test",
+            api_key="test_api_key",
             model_name="whisper-1",
-            api_base="https://my.api.base/something_else/v3",
-            whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
+            organization="test-org",
+            api_base_url="test_api_url",
+            language="en",
+            prompt="test-prompt",
+            response_format="json",
+            temperature="0.5",
         )
         data = transcriber.to_dict()
         assert data == {
             "type": "RemoteWhisperTranscriber",
             "init_parameters": {
                 "model_name": "whisper-1",
-                "api_base": "https://my.api.base/something_else/v3",
-                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
+                "organization": "test-org",
+                "api_base_url": "test_api_url",
+                "language": "en",
+                "prompt": "test-prompt",
+                "response_format": "json",
+                "temperature": "0.5",
             },
         }
     @pytest.mark.unit
-    def test_run_with_path(self, preview_samples_path):
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
-        comp = RemoteWhisperTranscriber(api_key="whatever")
-
-        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
-            mocked_requests.request.return_value = mock_response
-
-            result = comp.run(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
-        expected = Document(
-            text="test transcription",
-            metadata={
-                "audio_file": preview_samples_path / "audio" / "this is the content of the document.wav",
-                "other_metadata": ["other", "meta", "data"],
-            },
-        )
-        assert result["documents"] == [expected]
+    def test_from_dict_with_defualt_parameters(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+
+        data = {
+            "type": "RemoteWhisperTranscriber",
+            "init_parameters": {
+                "model_name": "whisper-1",
+                "api_base_url": "https://api.openai.com/v1",
+                "organization": None,
+                "response_format": "json",
+            },
+        }
+
+        transcriber = RemoteWhisperTranscriber.from_dict(data)
+
+        assert openai.api_key == "test_api_key"
+        assert transcriber.model_name == "whisper-1"
+        assert transcriber.organization is None
+        assert transcriber.api_base_url == "https://api.openai.com/v1"
+        assert transcriber.whisper_params == {"response_format": "json"}
+
+    def test_from_dict_with_custom_init_parameters(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+
+        data = {
+            "type": "RemoteWhisperTranscriber",
+            "init_parameters": {
+                "model_name": "whisper-1",
+                "organization": "test-org",
+                "api_base_url": "test_api_url",
+                "language": "en",
+                "prompt": "test-prompt",
+                "response_format": "json",
+                "temperature": "0.5",
+            },
+        }
+        transcriber = RemoteWhisperTranscriber.from_dict(data)
+
+        assert openai.api_key == "test_api_key"
+        assert transcriber.model_name == "whisper-1"
+        assert transcriber.organization == "test-org"
+        assert transcriber.api_base_url == "test_api_url"
+        assert transcriber.whisper_params == {
+            "language": "en",
+            "prompt": "test-prompt",
+            "response_format": "json",
+            "temperature": "0.5",
+        }
+
+    def test_from_dict_with_defualt_parameters_no_env_var(self, monkeypatch):
+        openai.api_key = None
+        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+
+        data = {
+            "type": "RemoteWhisperTranscriber",
+            "init_parameters": {
+                "model_name": "whisper-1",
+                "api_base_url": "https://api.openai.com/v1",
+                "organization": None,
+                "response_format": "json",
+            },
+        }
+
+        with pytest.raises(ValueError, match="RemoteWhisperTranscriber expects an OpenAI API key."):
+            RemoteWhisperTranscriber.from_dict(data)
     @pytest.mark.unit
-    def test_run_with_str(self, preview_samples_path):
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
-        comp = RemoteWhisperTranscriber(api_key="whatever")
-
-        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
-            mocked_requests.request.return_value = mock_response
-
-            result = comp.run(
-                audio_files=[
-                    str((preview_samples_path / "audio" / "this is the content of the document.wav").absolute())
-                ]
-            )
-        expected = Document(
-            text="test transcription",
-            metadata={
-                "audio_file": str(
-                    (preview_samples_path / "audio" / "this is the content of the document.wav").absolute()
-                ),
-                "other_metadata": ["other", "meta", "data"],
-            },
-        )
-        assert result["documents"] == [expected]
-
-    @pytest.mark.unit
-    def test_transcribe_with_stream(self, preview_samples_path):
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
-        comp = RemoteWhisperTranscriber(api_key="whatever")
-
-        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
-            mocked_requests.request.return_value = mock_response
-
-            with open(preview_samples_path / "audio" / "this is the content of the document.wav", "rb") as audio_stream:
-                result = comp.transcribe(audio_files=[audio_stream])
-        expected = Document(
-            text="test transcription",
-            metadata={"audio_file": "<<binary stream>>", "other_metadata": ["other", "meta", "data"]},
-        )
-        assert result == [expected]
-
-    @pytest.mark.unit
-    def test_api_transcription(self, preview_samples_path):
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
-        comp = RemoteWhisperTranscriber(api_key="whatever")
-
-        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
-            mocked_requests.request.return_value = mock_response
-
-            comp.run(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
-        requests_params = mocked_requests.request.call_args.kwargs
-        requests_params.pop("files")
-        assert requests_params == {
-            "method": "post",
-            "url": "https://api.openai.com/v1/audio/transcriptions",
-            "data": {"model": "whisper-1"},
-            "headers": {"Authorization": "Bearer whatever"},
-            "timeout": OPENAI_TIMEOUT,
-        }
-
-    @pytest.mark.unit
-    def test_api_translation(self, preview_samples_path):
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
-        comp = RemoteWhisperTranscriber(api_key="whatever")
-
-        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
-            mocked_requests.request.return_value = mock_response
-
-            comp.run(
-                audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"],
-                whisper_params={"translate": True},
-            )
-        requests_params = mocked_requests.request.call_args.kwargs
-        requests_params.pop("files")
-        assert requests_params == {
-            "method": "post",
-            "url": "https://api.openai.com/v1/audio/translations",
-            "data": {"model": "whisper-1"},
-            "headers": {"Authorization": "Bearer whatever"},
-            "timeout": OPENAI_TIMEOUT,
-        }
-
-    @pytest.mark.unit
-    @patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
-    def test_default_api_base(self, mock_request, preview_samples_path):
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
-        mock_request.return_value = mock_response
-
-        transcriber = RemoteWhisperTranscriber(api_key="just a test")
-        assert transcriber.api_base == "https://api.openai.com/v1"
-
-        transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
-        assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/audio/transcriptions"
-
-    @pytest.mark.unit
-    @patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
-    def test_custom_api_base(self, mock_request, preview_samples_path):
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
-        mock_request.return_value = mock_response
-
-        transcriber = RemoteWhisperTranscriber(api_key="just a test", api_base="https://fake_api_base.com")
-        assert transcriber.api_base == "https://fake_api_base.com"
-
-        transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
-        assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/audio/transcriptions"
+    def test_run(self, preview_samples_path):
+        with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
+            model = "whisper-1"
+            file_path = preview_samples_path / "audio" / "this is the content of the document.wav"
+            openai_audio_patch.transcribe.side_effect = mock_openai_response
+
+            transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
+            with open(file_path, "rb") as audio_stream:
+                byte_stream = audio_stream.read()
+                audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})
+
+                result = transcriber.run(streams=[audio_file])
+
+                assert result["documents"][0].text == "test transcription"
+                assert result["documents"][0].metadata["file_path"] == str(file_path.absolute())
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
@@ -199,28 +204,37 @@ class TestRemoteWhisperTranscriber:
     )
     @pytest.mark.integration
     def test_whisper_remote_transcriber(self, preview_samples_path):
-        comp = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
+        transcriber = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))

-        output = comp.run(
-            audio_files=[
-                preview_samples_path / "audio" / "this is the content of the document.wav",
-                str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute()),
-                open(preview_samples_path / "audio" / "answer.wav", "rb"),
-            ]
-        )
+        paths = [
+            preview_samples_path / "audio" / "this is the content of the document.wav",
+            preview_samples_path / "audio" / "the context for this answer is here.wav",
+            preview_samples_path / "audio" / "answer.wav",
+        ]
+
+        audio_files = []
+        for file_path in paths:
+            with open(file_path, "rb") as audio_stream:
+                byte_stream = audio_stream.read()
+                audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})
+                audio_files.append(audio_file)
+
+        output = transcriber.run(streams=audio_files)
+
         docs = output["documents"]
         assert len(docs) == 3

         assert docs[0].text.strip().lower() == "this is the content of the document."
         assert (
-            preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]
+            str((preview_samples_path / "audio" / "this is the content of the document.wav").absolute())
+            == docs[0].metadata["file_path"]
         )

         assert docs[1].text.strip().lower() == "the context for this answer is here."
         assert (
             str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
-            == docs[1].metadata["audio_file"]
+            == docs[1].metadata["file_path"]
         )

         assert docs[2].text.strip().lower() == "answer."
-        assert docs[2].metadata["audio_file"] == "<<binary stream>>"
+        assert str((preview_samples_path / "audio" / "answer.wav").absolute()) == docs[2].metadata["file_path"]