refactor: Migrate RemoteWhisperTranscriber to OpenAI SDK. (#6149)

* Migrate RemoteWhisperTranscriber to OpenAI SDK

* Migrate RemoteWhisperTranscriber to OpenAI SDK

* Remove unnecessary imports

* Add release notes

* Fix api_key serialization

* Fix linting

* Apply suggestions from code review

Co-authored-by: ZanSara <sarazanzo94@gmail.com>

* Add additional tests for api_key

* Adapt .run() to take ByteStream inputs

* Update docstrings

* Rework implementation to use io.BytesIO

* Update error message

* Add default file name

---------

Co-authored-by: ZanSara <sarazanzo94@gmail.com>
Ashwin Mathur 2023-10-26 19:55:23 +05:30 committed by GitHub
parent 26a22045e4
commit 5f35e7d04a
3 changed files with 278 additions and 259 deletions
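Before the file diffs, a minimal usage sketch of the component as it looks after this migration, pieced together from the new code and tests below. This is illustrative only: the sample WAV path is hypothetical, and the API key is assumed to come from the OPENAI_API_KEY environment variable (it can also be passed explicitly as api_key=...).

    import os
    from haystack.preview.dataclasses import ByteStream
    from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber

    # Assumes OPENAI_API_KEY is already exported in the environment.
    transcriber = RemoteWhisperTranscriber(model_name="whisper-1", response_format="json")

    # Wrap a local audio file (hypothetical path) in a ByteStream, keeping the path in metadata.
    with open("audio/sample.wav", "rb") as f:
        stream = ByteStream(f.read(), metadata={"file_path": "audio/sample.wav"})

    result = transcriber.run(streams=[stream])
    print(result["documents"][0].text)  # transcription text; metadata carries file_path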


@@ -1,20 +1,17 @@
from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence
import os
import json
import io
import logging
from pathlib import Path
import os
from typing import Any, Dict, List, Optional
from haystack.preview.utils import request_with_retry
from haystack.preview import component, Document, default_to_dict
import openai
from haystack.preview import Document, component, default_from_dict, default_to_dict
from haystack.preview.dataclasses import ByteStream
logger = logging.getLogger(__name__)
OPENAI_TIMEOUT = float(os.environ.get("HAYSTACK_OPENAI_TIMEOUT_SEC", 600))
WhisperRemoteModel = Literal["whisper-1"]
API_BASE_URL = "https://api.openai.com/v1"
@component
@@ -30,108 +27,112 @@ class RemoteWhisperTranscriber:
def __init__(
self,
api_key: str,
model_name: WhisperRemoteModel = "whisper-1",
api_base: str = "https://api.openai.com/v1",
whisper_params: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
model_name: str = "whisper-1",
organization: Optional[str] = None,
api_base_url: str = API_BASE_URL,
**kwargs,
):
"""
Transcribes a list of audio files into a list of Documents.
:param api_key: OpenAI API key.
:param model_name: Name of the model to use. It now accepts only `whisper-1`.
:param organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
:param api_base: OpenAI base URL, defaults to `"https://api.openai.com/v1"`.
:param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/audio) for more details.
Some of the supported parameters:
- `language`: The language of the input audio.
Supplying the input language in ISO-639-1 format
will improve accuracy and latency.
- `prompt`: An optional text to guide the model's
style or continue a previous audio segment.
The prompt should match the audio language.
- `response_format`: The format of the transcript
output, in one of these options: json, text, srt,
verbose_json, or vtt. Defaults to "json". Currently only "json" is supported.
- `temperature`: The sampling temperature, between 0
and 1. Higher values like 0.8 will make the output more
random, while lower values like 0.2 will make it more
focused and deterministic. If set to 0, the model will
use log probability to automatically increase the
temperature until certain thresholds are hit.
"""
if model_name not in get_args(WhisperRemoteModel):
raise ValueError(
f"Model name not recognized. Choose one among: " f"{', '.join(get_args(WhisperRemoteModel))}."
)
if not api_key:
raise ValueError("API key is None.")
# if the user does not provide the API key, check if it is set in the module client
api_key = api_key or openai.api_key
if api_key is None:
try:
api_key = os.environ["OPENAI_API_KEY"]
except KeyError as e:
raise ValueError(
"RemoteWhisperTranscriber expects an OpenAI API key. "
"Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
) from e
openai.api_key = api_key
self.organization = organization
self.model_name = model_name
self.api_key = api_key
self.api_base = api_base
self.whisper_params = whisper_params or {}
self.api_base_url = api_base_url
@component.output_types(documents=List[Document])
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
"""
Transcribe the audio files into a list of Documents, one for each input file.
For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[github repo](https://github.com/openai/whisper).
:param audio_files: a list of paths or binary streams to transcribe
:returns: a list of Documents, one for each file. The content of the document is the transcription text,
while the document's metadata contains all the other values returned by the Whisper model, such as the
alignment data. Another key called `audio_file` contains the path to the audio file used for the
transcription.
"""
if whisper_params is None:
whisper_params = self.whisper_params
documents = self.transcribe(audio_files, **whisper_params)
return {"documents": documents}
def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
"""
Transcribe the audio files into a list of Documents, one for each input file.
For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[github repo](https://github.com/openai/whisper).
:param audio_files: a list of paths or binary streams to transcribe
:returns: a list of transcriptions.
"""
transcriptions = self._raw_transcribe(audio_files=audio_files, **kwargs)
documents = []
for audio, transcript in zip(audio_files, transcriptions):
content = transcript.pop("text")
if not isinstance(audio, (str, Path)):
audio = "<<binary stream>>"
doc = Document(text=content, metadata={"audio_file": audio, **transcript})
documents.append(doc)
return documents
def _raw_transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]:
"""
Transcribe the given audio files. Returns a list of strings.
For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[github repo](https://github.com/openai/whisper).
:param audio_files: a list of paths or binary streams to transcribe.
:param kwargs: any other parameters that Whisper API can understand.
:returns: a list of transcriptions as they are produced by the Whisper API (JSON).
"""
translate = kwargs.pop("translate", False)
url = f"{self.api_base}/audio/{'translations' if translate else 'transcriptions'}"
data = {"model": self.model_name, **kwargs}
headers = {"Authorization": f"Bearer {self.api_key}"}
transcriptions = []
for audio_file in audio_files:
if isinstance(audio_file, (str, Path)):
audio_file = open(audio_file, "rb")
request_files = ("file", (audio_file.name, audio_file, "application/octet-stream"))
response = request_with_retry(
method="post", url=url, data=data, headers=headers, files=[request_files], timeout=OPENAI_TIMEOUT
# Only response_format = "json" is supported
whisper_params = kwargs
if whisper_params.get("response_format") != "json":
logger.warning(
"RemoteWhisperTranscriber only supports 'response_format: json'. This parameter will be overwritten."
)
transcription = json.loads(response.content)
whisper_params["response_format"] = "json"
self.whisper_params = whisper_params
transcriptions.append(transcription)
return transcriptions
if organization is not None:
openai.organization = organization
def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
to the constructor.
Serialize this component to a dictionary.
This method overrides the default serializer in order to
avoid leaking the `api_key` value passed to the constructor.
"""
return default_to_dict(
self, model_name=self.model_name, api_base=self.api_base, whisper_params=self.whisper_params
self,
model_name=self.model_name,
organization=self.organization,
api_base_url=self.api_base_url,
**self.whisper_params,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, streams: List[ByteStream]):
"""
Transcribe the audio files into a list of Documents, one for each input file.
For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[github repo](https://github.com/openai/whisper).
:param streams: a list of ByteStream objects to transcribe.
:returns: a list of Documents, one for each file. The content of the document is the transcription text.
"""
documents = []
for stream in streams:
file = io.BytesIO(stream.data)
try:
file.name = stream.metadata["file_path"]
except KeyError:
file.name = "audio_input.wav"
content = openai.Audio.transcribe(file=file, model=self.model_name, **self.whisper_params)
doc = Document(text=content["text"], metadata=stream.metadata)
documents.append(doc)
return {"documents": documents}


@@ -0,0 +1,4 @@
---
preview:
- |
Migrate RemoteWhisperTranscriber to OpenAI SDK.
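The tests in the next file also cover how extra Whisper parameters are handled. The constructor forwards any extra keyword arguments to the Whisper endpoint, but only response_format="json" is supported; other values are logged as a warning and overwritten. A quick sketch of that behavior (api_key value illustrative):

    from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber

    transcriber = RemoteWhisperTranscriber(
        api_key="test_api_key",   # illustrative
        language="en",
        response_format="text",   # unsupported: a warning is logged and the value is overwritten
    )
    assert transcriber.whisper_params["response_format"] == "json"
    assert transcriber.whisper_params["language"] == "en"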


@@ -1,197 +1,202 @@
import os
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import openai
import pytest
from openai.util import convert_to_openai_object
from haystack.preview.dataclasses import Document
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber, OPENAI_TIMEOUT
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.preview.dataclasses import ByteStream
def mock_openai_response(response_format="json", **kwargs) -> openai.openai_object.OpenAIObject:
if response_format == "json":
dict_response = {"text": "test transcription"}
# Currently only "json" is supported.
else:
dict_response = {}
return convert_to_openai_object(dict_response)
class TestRemoteWhisperTranscriber:
@pytest.mark.unit
def test_init_unknown_model(self):
with pytest.raises(ValueError, match="not recognized"):
RemoteWhisperTranscriber(model_name="anything", api_key="something")
def test_init_no_key(self, monkeypatch):
openai.api_key = None
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
error_msg = "RemoteWhisperTranscriber expects an OpenAI API key."
with pytest.raises(ValueError, match=error_msg):
RemoteWhisperTranscriber(api_key=None)
def test_init_key_env_var(self, monkeypatch):
openai.api_key = None
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
RemoteWhisperTranscriber(api_key=None)
assert openai.api_key == "test_api_key"
def test_init_key_module_env_and_global_var(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key_2")
openai.api_key = "test_api_key_1"
RemoteWhisperTranscriber(api_key=None)
# The module global variable takes precedence
assert openai.api_key == "test_api_key_1"
@pytest.mark.unit
def test_init_default(self):
transcriber = RemoteWhisperTranscriber(api_key="just a test")
transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
assert openai.api_key == "test_api_key"
assert transcriber.model_name == "whisper-1"
assert transcriber.api_key == "just a test"
assert transcriber.api_base == "https://api.openai.com/v1"
assert transcriber.organization is None
assert transcriber.api_base_url == "https://api.openai.com/v1"
assert transcriber.whisper_params == {"response_format": "json"}
@pytest.mark.unit
def test_init_no_key(self):
with pytest.raises(ValueError, match="API key is None"):
RemoteWhisperTranscriber(api_key=None)
def test_init_custom_parameters(self):
transcriber = RemoteWhisperTranscriber(
api_key="test_api_key",
model_name="whisper-1",
organization="test-org",
api_base_url="test_api_url",
language="en",
prompt="test-prompt",
response_format="json",
temperature="0.5",
)
assert openai.api_key == "test_api_key"
assert transcriber.model_name == "whisper-1"
assert transcriber.organization == "test-org"
assert transcriber.api_base_url == "test_api_url"
assert transcriber.whisper_params == {
"language": "en",
"prompt": "test-prompt",
"response_format": "json",
"temperature": "0.5",
}
@pytest.mark.unit
def test_to_dict(self):
transcriber = RemoteWhisperTranscriber(api_key="test")
def test_to_dict_default_parameters(self):
transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
data = transcriber.to_dict()
assert data == {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_base": "https://api.openai.com/v1",
"whisper_params": {},
"api_base_url": "https://api.openai.com/v1",
"organization": None,
"response_format": "json",
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
transcriber = RemoteWhisperTranscriber(
api_key="test",
api_key="test_api_key",
model_name="whisper-1",
api_base="https://my.api.base/something_else/v3",
whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
organization="test-org",
api_base_url="test_api_url",
language="en",
prompt="test-prompt",
response_format="json",
temperature="0.5",
)
data = transcriber.to_dict()
assert data == {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_base": "https://my.api.base/something_else/v3",
"whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
"organization": "test-org",
"api_base_url": "test_api_url",
"language": "en",
"prompt": "test-prompt",
"response_format": "json",
"temperature": "0.5",
},
}
@pytest.mark.unit
def test_run_with_path(self, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")
def test_from_dict_with_default_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response
data = {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_base_url": "https://api.openai.com/v1",
"organization": None,
"response_format": "json",
},
}
result = comp.run(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
expected = Document(
text="test transcription",
metadata={
"audio_file": preview_samples_path / "audio" / "this is the content of the document.wav",
"other_metadata": ["other", "meta", "data"],
},
)
assert result["documents"] == [expected]
transcriber = RemoteWhisperTranscriber.from_dict(data)
assert openai.api_key == "test_api_key"
assert transcriber.model_name == "whisper-1"
assert transcriber.organization is None
assert transcriber.api_base_url == "https://api.openai.com/v1"
assert transcriber.whisper_params == {"response_format": "json"}
def test_from_dict_with_custom_init_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
data = {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"organization": "test-org",
"api_base_url": "test_api_url",
"language": "en",
"prompt": "test-prompt",
"response_format": "json",
"temperature": "0.5",
},
}
transcriber = RemoteWhisperTranscriber.from_dict(data)
assert openai.api_key == "test_api_key"
assert transcriber.model_name == "whisper-1"
assert transcriber.organization == "test-org"
assert transcriber.api_base_url == "test_api_url"
assert transcriber.whisper_params == {
"language": "en",
"prompt": "test-prompt",
"response_format": "json",
"temperature": "0.5",
}
def test_from_dict_with_default_parameters_no_env_var(self, monkeypatch):
openai.api_key = None
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_base_url": "https://api.openai.com/v1",
"organization": None,
"response_format": "json",
},
}
with pytest.raises(ValueError, match="RemoteWhisperTranscriber expects an OpenAI API key."):
RemoteWhisperTranscriber.from_dict(data)
@pytest.mark.unit
def test_run_with_str(self, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")
def test_run(self, preview_samples_path):
with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
model = "whisper-1"
file_path = preview_samples_path / "audio" / "this is the content of the document.wav"
openai_audio_patch.transcribe.side_effect = mock_openai_response
with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response
transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
with open(file_path, "rb") as audio_stream:
byte_stream = audio_stream.read()
audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})
result = comp.run(
audio_files=[
str((preview_samples_path / "audio" / "this is the content of the document.wav").absolute())
]
)
expected = Document(
text="test transcription",
metadata={
"audio_file": str(
(preview_samples_path / "audio" / "this is the content of the document.wav").absolute()
),
"other_metadata": ["other", "meta", "data"],
},
)
assert result["documents"] == [expected]
result = transcriber.run(streams=[audio_file])
@pytest.mark.unit
def test_transcribe_with_stream(self, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")
with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response
with open(preview_samples_path / "audio" / "this is the content of the document.wav", "rb") as audio_stream:
result = comp.transcribe(audio_files=[audio_stream])
expected = Document(
text="test transcription",
metadata={"audio_file": "<<binary stream>>", "other_metadata": ["other", "meta", "data"]},
)
assert result == [expected]
@pytest.mark.unit
def test_api_transcription(self, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")
with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response
comp.run(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
requests_params = mocked_requests.request.call_args.kwargs
requests_params.pop("files")
assert requests_params == {
"method": "post",
"url": "https://api.openai.com/v1/audio/transcriptions",
"data": {"model": "whisper-1"},
"headers": {"Authorization": "Bearer whatever"},
"timeout": OPENAI_TIMEOUT,
}
@pytest.mark.unit
def test_api_translation(self, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")
with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response
comp.run(
audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"],
whisper_params={"translate": True},
)
requests_params = mocked_requests.request.call_args.kwargs
requests_params.pop("files")
assert requests_params == {
"method": "post",
"url": "https://api.openai.com/v1/audio/translations",
"data": {"model": "whisper-1"},
"headers": {"Authorization": "Bearer whatever"},
"timeout": OPENAI_TIMEOUT,
}
@pytest.mark.unit
@patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
def test_default_api_base(self, mock_request, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mock_request.return_value = mock_response
transcriber = RemoteWhisperTranscriber(api_key="just a test")
assert transcriber.api_base == "https://api.openai.com/v1"
transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/audio/transcriptions"
@pytest.mark.unit
@patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
def test_custom_api_base(self, mock_request, preview_samples_path):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mock_request.return_value = mock_response
transcriber = RemoteWhisperTranscriber(api_key="just a test", api_base="https://fake_api_base.com")
assert transcriber.api_base == "https://fake_api_base.com"
transcriber.transcribe(audio_files=[preview_samples_path / "audio" / "this is the content of the document.wav"])
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/audio/transcriptions"
assert result["documents"][0].text == "test transcription"
assert result["documents"][0].metadata["file_path"] == str(file_path.absolute())
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
@@ -199,28 +204,37 @@ class TestRemoteWhisperTranscriber:
)
@pytest.mark.integration
def test_whisper_remote_transcriber(self, preview_samples_path):
comp = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
transcriber = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
paths = [
preview_samples_path / "audio" / "this is the content of the document.wav",
preview_samples_path / "audio" / "the context for this answer is here.wav",
preview_samples_path / "audio" / "answer.wav",
"rb",
]
audio_files = []
for file_path in paths:
with open(file_path, "rb") as audio_stream:
byte_stream = audio_stream.read()
audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})
audio_files.append(audio_file)
output = transcriber.run(streams=audio_files)
output = comp.run(
audio_files=[
preview_samples_path / "audio" / "this is the content of the document.wav",
str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute()),
open(preview_samples_path / "audio" / "answer.wav", "rb"),
]
)
docs = output["documents"]
assert len(docs) == 3
assert docs[0].text.strip().lower() == "this is the content of the document."
assert (
preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]
str((preview_samples_path / "audio" / "this is the content of the document.wav").absolute())
== docs[0].metadata["file_path"]
)
assert docs[1].text.strip().lower() == "the context for this answer is here."
assert (
str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
== docs[1].metadata["audio_file"]
== docs[1].metadata["file_path"]
)
assert docs[2].text.strip().lower() == "answer."
assert docs[2].metadata["audio_file"] == "<<binary stream>>"
assert str((preview_samples_path / "audio" / "answer.wav").absolute()) == docs[2].metadata["file_path"]