Add Whisper node (#4335)

* Add Whisper node

* Add support for audio path, improve tests

* Add docs

* Improve tests
Vladimir Blagojevic 2023-03-13 16:17:07 +01:00 committed by GitHub
parent 28724e2e25
commit 98256ecf57
6 changed files with 250 additions and 0 deletions
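
For context, a minimal usage sketch of the new node; the OPENAI_API_KEY environment variable and the audio path are illustrative placeholders, not part of this commit:

import os
from haystack.nodes.audio import WhisperTranscriber

# API mode: pass an OpenAI API key; the API backend always uses the "whisper-1" model.
transcriber = WhisperTranscriber(api_key=os.environ["OPENAI_API_KEY"])

# Local mode: install openai-whisper, omit api_key, and pick a model size instead:
# transcriber = WhisperTranscriber(model_name_or_path="medium")

# transcribe() accepts a file path or an open binary file object.
transcript = transcriber.transcribe("audio/sample.wav", language="en")
print(transcript["text"])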

haystack/nodes/__init__.py

@@ -48,6 +48,8 @@ from haystack.nodes.retriever import (
from haystack.nodes.summarizer import BaseSummarizer, TransformersSummarizer
from haystack.nodes.translator import BaseTranslator, TransformersTranslator
from haystack.nodes.audio import WhisperTranscriber, WhisperModel
Crawler = safe_import("haystack.nodes.connector.crawler", "Crawler", "crawler") # Has optional dependencies
AnswerToSpeech = safe_import(
"haystack.nodes.audio.answer_to_speech", "AnswerToSpeech", "audio"

haystack/nodes/audio/__init__.py

@@ -1,4 +1,5 @@
from haystack.utils.import_utils import safe_import
from haystack.nodes.audio.whisper_transcriber import WhisperTranscriber, WhisperModel
AnswerToSpeech = safe_import(
"haystack.nodes.audio.answer_to_speech", "AnswerToSpeech", "audio"

haystack/nodes/audio/whisper_transcriber.py

@@ -0,0 +1,187 @@
import json
from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal
import requests
import torch
from requests import PreparedRequest
from haystack.errors import OpenAIError, OpenAIRateLimitError
from haystack.nodes.base import BaseComponent
from haystack.utils.import_utils import is_whisper_available
WhisperModel = Literal["tiny", "small", "medium", "large", "large-v2"]
class WhisperTranscriber(BaseComponent):
"""
Transcribes audio files using OpenAI's Whisper. This class supports two underlying implementations:
- API (default): Uses the OpenAI API and requires an API key. See the
[OpenAI documentation](https://beta.openai.com/docs/api-reference/whisper) for more details.
- Local (requires a local installation of whisper): Uses a local installation
of [whisper](https://github.com/openai/whisper).
To use the local implementation, install whisper following the instructions on
the Whisper [GitHub repo](https://github.com/openai/whisper) and omit the api_key parameter.
To use the API implementation, you need to provide an api_key. You can get one by signing up
for an OpenAI account [here](https://beta.openai.com/).
For the supported audio formats, languages, and other parameters, see the Whisper API
[documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[GitHub repo](https://github.com/openai/whisper).
"""
# If it's not a decision component, there is only one outgoing edge
outgoing_edges = 1
def __init__(
self,
api_key: Optional[str] = None,
model_name_or_path: WhisperModel = "medium",
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Creates a WhisperTranscriber instance.
:param api_key: OpenAI API key. If None, a local installation of whisper is used.
:param model_name_or_path: Name of the model to use. With a local installation of whisper, this
value has to be one of the following: "tiny", "small", "medium" (default), "large", "large-v2".
With the API, this parameter is ignored and the "whisper-1" model is always used.
:param device: Device to use for inference. This parameter is only used with a local
installation of whisper. If None, the device is selected automatically.
"""
super().__init__()
self.api_key = api_key
self.use_local_whisper = is_whisper_available() and self.api_key is None
if self.use_local_whisper:
import whisper
self._model = whisper.load_model(model_name_or_path, device=device)
else:
if api_key is None:
raise ValueError(
"Please provide a valid api_key for OpenAI API. Alternatively, "
"install OpenAI whisper (see https://github.com/openai/whisper for more details)."
)
def transcribe(
self,
audio_file: Union[str, BinaryIO],
language: Optional[str] = None,
return_segments: bool = False,
translate: bool = False,
**kwargs,
) -> Dict[str, Any]:
"""
Transcribe an audio file.
:param audio_file: Path to the audio file or a binary file-like object.
:param language: Language of the audio file. If None, the language is detected automatically.
:param return_segments: If True, also returns segment-level transcriptions (supported only with a local installation of whisper).
:param translate: If True, translates the transcription to English.
:return: A dictionary with the transcription; the full text is available under the "text" key.
"""
transcript: Dict[str, Any] = {}
new_kwargs = {k: v for k, v in kwargs.items() if v is not None}
if language is not None:
new_kwargs["language"] = language
if self.use_local_whisper:
new_kwargs["return_segments"] = return_segments
transcript = self._invoke_local(audio_file, translate, **new_kwargs)
elif self.api_key:
transcript = self._invoke_api(audio_file, translate, **new_kwargs)
return transcript
def _invoke_api(
self, audio_file: Union[str, BinaryIO], translate: Optional[bool] = False, **kwargs
) -> Dict[str, Any]:
if isinstance(audio_file, str):
with open(audio_file, "rb") as f:
return self._invoke_api(f, translate, **kwargs)
else:
headers = {"Authorization": f"Bearer {self.api_key}"}
request = PreparedRequest()
url: str = (
"https://api.openai.com/v1/audio/transcriptions"
if not translate
else "https://api.openai.com/v1/audio/translations"
)
request.prepare(
method="POST",
url=url,
headers=headers,
data={"model": "whisper-1", **kwargs},
files=[("file", (audio_file.name, audio_file, "application/octet-stream"))],
)
response = requests.post(url, data=request.body, headers=request.headers, timeout=600)
if response.status_code != 200:
openai_error: OpenAIError
if response.status_code == 429:
openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
else:
openai_error = OpenAIError(
f"OpenAI returned an error.\n"
f"Status code: {response.status_code}\n"
f"Response body: {response.text}",
status_code=response.status_code,
)
raise openai_error
return json.loads(response.content)
def _invoke_local(
self, audio_file: Union[str, BinaryIO], translate: Optional[bool] = False, **kwargs
) -> Dict[str, Any]:
if isinstance(audio_file, str):
with open(audio_file, "rb") as f:
return self._invoke_local(f, translate, **kwargs)
else:
return_segments = kwargs.pop("return_segments", None)
kwargs["task"] = "translate" if translate else "transcribe"
transcription = self._model.transcribe(audio_file.name, **kwargs)
if not return_segments:
transcription.pop("segments", None)
return transcription
def run(self, audio_file: Union[str, BinaryIO], language: Optional[str] = None, return_segments: bool = False, translate: bool = False): # type: ignore
"""
Transcribe an audio file.
:param audio_file: Path to the audio file or a binary file-like object.
:param language: Language of the audio file. If None, the language is detected automatically.
:param return_segments: If True, also returns segment-level transcriptions (supported only with a local installation of whisper).
:param translate: If True, translates the transcription to English.
"""
document = self.transcribe(audio_file, language, return_segments, translate)
output = {"documents": [document]}
return output, "output_1"
def run_batch(self, audio_files: List[Union[str, BinaryIO]], language: Optional[str] = None, return_segments: bool = False, translate: bool = False): # type: ignore
"""
Transcribe audio files.
:param audio_files: List of paths to audio files or binary file-like objects.
:param language: Language of the audio files. If None, the language is detected automatically.
:param return_segments: If True, also returns segment-level transcriptions (supported only with a local installation of whisper).
:param translate: If True, translates the transcriptions to English.
"""
documents = []
for audio in audio_files:
document = self.transcribe(audio, language, return_segments, translate)
documents.append(document)
output = {"documents": documents}
return output, "output_1"
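
Called outside a pipeline, run() and run_batch() are thin wrappers around transcribe(); a short sketch of their contract (API key and file paths illustrative):

node = WhisperTranscriber(api_key="sk-...")  # or a local-mode instance

# run() wraps the transcription dict in a "documents" list and routes it to "output_1".
output, edge = node.run(audio_file="audio/sample.wav")
assert edge == "output_1"
print(output["documents"][0]["text"])

# run_batch() returns one transcription dict per input file.
batch_output, _ = node.run_batch(audio_files=["audio/a.wav", "audio/b.wav"])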

haystack/utils/import_utils.py

@@ -6,6 +6,7 @@ import tarfile
import zipfile
import logging
import importlib
import importlib.util
from pathlib import Path
import requests
@@ -118,3 +119,7 @@ def fetch_archive_from_http(
)
return True
def is_whisper_available():
return importlib.util.find_spec("whisper") is not None
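
The helper only probes whether the whisper package is importable; the transcriber's constructor uses it to choose between the local and API backends. The same guard pattern works anywhere:

from haystack.utils.import_utils import is_whisper_available

if is_whisper_available():
    import whisper
    model = whisper.load_model("tiny")  # mirrors what WhisperTranscriber's constructor does
else:
    print("whisper is not installed; fall back to the OpenAI API")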

pyproject.toml

@@ -156,6 +156,7 @@ audio = [
"protobuf<=3.20.1",
"soundfile< 0.12.0",
"numpy<1.24", # Keep compatibility with latest numba
"openai-whisper"
]
beir = [
"beir; platform_system != 'Windows'",

test/nodes/test_whisper.py

@@ -0,0 +1,54 @@
import os
import pytest
from haystack.nodes.audio import WhisperTranscriber
from haystack.utils.import_utils import is_whisper_available
from ..conftest import SAMPLES_PATH
@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
@pytest.mark.integration
def test_whisper_api_transcribe():
w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
audio_object_transcript, audio_path_transcript = transcribe_test_helper(w)
assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript
@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
@pytest.mark.integration
def test_whisper_api_transcribe_with_params():
w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
audio_object_transcript, audio_path_transcript = transcribe_test_helper(w, language="en")
assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript
@pytest.mark.integration
@pytest.mark.skipif(not is_whisper_available(), reason="Whisper is not installed")
def test_whisper_local_transcribe():
w = WhisperTranscriber()
audio_object_transcript, audio_path_transcript = transcribe_test_helper(w, language="en")
assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript
@pytest.mark.integration
@pytest.mark.skipif(not is_whisper_available(), reason="Whisper is not installed")
def test_whisper_local_transcribe_with_params():
w = WhisperTranscriber()
audio_object, audio_path = transcribe_test_helper(w, language="en", return_segments=True)
assert len(audio_object["segments"]) == 1 and len(audio_path["segments"]) == 1
def transcribe_test_helper(whisper, **kwargs):
# this file is 1 second long and contains the word "answer"
file_path = str(SAMPLES_PATH / "audio" / "answer.wav")
# using audio object
with open(file_path, mode="rb") as audio_file:
audio_object_transcript = whisper.transcribe(audio_file=audio_file, **kwargs)
assert "answer" in audio_object_transcript["text"].lower()
# using path to audio file
audio_path_transcript = whisper.transcribe(audio_file=file_path, **kwargs)
assert "answer" in audio_path_transcript["text"].lower()
return audio_object_transcript, audio_path_transcript
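
Note that all four tests are integration tests: they can be selected with the marker used in the file (for example pytest -m integration -k whisper), and the two API variants additionally require OPENAI_API_KEY to be set in the environment.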