Mirror of https://github.com/deepset-ai/haystack.git
Add Whisper node (#4335)
* Add Whisper node
* Add support for audio path, improve tests
* Add docs
* Improve tests
parent 28724e2e25
commit 98256ecf57
haystack/nodes/__init__.py

@@ -48,6 +48,8 @@ from haystack.nodes.retriever import (
from haystack.nodes.summarizer import BaseSummarizer, TransformersSummarizer
from haystack.nodes.translator import BaseTranslator, TransformersTranslator

from haystack.nodes.audio import WhisperTranscriber, WhisperModel

Crawler = safe_import("haystack.nodes.connector.crawler", "Crawler", "crawler")  # Has optional dependencies
AnswerToSpeech = safe_import(
    "haystack.nodes.audio.answer_to_speech", "AnswerToSpeech", "audio"
haystack/nodes/audio/__init__.py

@@ -1,4 +1,5 @@
from haystack.utils.import_utils import safe_import
from haystack.nodes.audio.whisper_transcriber import WhisperTranscriber, WhisperModel

AnswerToSpeech = safe_import(
    "haystack.nodes.audio.answer_to_speech", "AnswerToSpeech", "audio"
haystack/nodes/audio/whisper_transcriber.py (new file, 187 lines)
@@ -0,0 +1,187 @@
import json

from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal

import requests
import torch
from requests import PreparedRequest

from haystack.errors import OpenAIError, OpenAIRateLimitError
from haystack.nodes.base import BaseComponent
from haystack.utils.import_utils import is_whisper_available


WhisperModel = Literal["tiny", "small", "medium", "large", "large-v2"]
class WhisperTranscriber(BaseComponent):
    """
    Transcribes audio files using OpenAI's Whisper. This class supports two underlying implementations:

    - API (default): Uses the OpenAI API and requires an API key. See the blog
      [post](https://beta.openai.com/docs/api-reference/whisper) for more details.
    - Local (requires installing whisper): Uses a local installation
      of [whisper](https://github.com/openai/whisper).

    If you are using a local installation of whisper, install it following the instructions on
    the Whisper [GitHub repo](https://github.com/openai/whisper) and omit the api_key parameter.

    If you are using the API implementation, you need to provide an api_key. You can get one by signing up
    for an OpenAI account [here](https://beta.openai.com/).

    For the supported audio formats, languages, and other parameters, see the Whisper API
    [documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
    [GitHub repo](https://github.com/openai/whisper).
    """
    # If it's not a decision component, there is only one outgoing edge
    outgoing_edges = 1

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name_or_path: WhisperModel = "medium",
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        """
        Creates a WhisperTranscriber instance.

        :param api_key: OpenAI API key. If None, a local installation of whisper is used.
        :param model_name_or_path: Name of the model to use. With a local installation of whisper, this
            value has to be one of the following: "tiny", "small", "medium" (default), "large", "large-v2".
            With the API, the "whisper-1" model is always used.
        :param device: Device to use for inference. This parameter is only used with a local
            installation of whisper. If None, the device is selected automatically.
        """
        super().__init__()
        self.api_key = api_key

        # Use the local model only if whisper is installed and no API key was provided
        self.use_local_whisper = is_whisper_available() and self.api_key is None

        if self.use_local_whisper:
            import whisper

            self._model = whisper.load_model(model_name_or_path, device=device)
        else:
            if api_key is None:
                raise ValueError(
                    "Please provide a valid api_key for OpenAI API. Alternatively, "
                    "install OpenAI whisper (see https://github.com/openai/whisper for more details)."
                )
    def transcribe(
        self,
        audio_file: Union[str, BinaryIO],
        language: Optional[str] = None,
        return_segments: bool = False,
        translate: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Transcribe an audio file.

        :param audio_file: Path to the audio file or a binary file-like object.
        :param language: Language of the audio file. If None, the language is detected automatically.
        :param return_segments: If True, also returns the transcription for each segment of the audio file.
        :param translate: If True, translates the transcription to English.
        """
        transcript: Dict[str, Any] = {}

        new_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if language is not None:
            new_kwargs["language"] = language

        if self.use_local_whisper:
            new_kwargs["return_segments"] = return_segments
            transcript = self._invoke_local(audio_file, translate, **new_kwargs)
        elif self.api_key:
            transcript = self._invoke_api(audio_file, translate, **new_kwargs)
        return transcript
    def _invoke_api(
        self, audio_file: Union[str, BinaryIO], translate: Optional[bool] = False, **kwargs
    ) -> Dict[str, Any]:
        if isinstance(audio_file, str):
            with open(audio_file, "rb") as f:
                return self._invoke_api(f, translate, **kwargs)
        else:
            headers = {"Authorization": f"Bearer {self.api_key}"}
            request = PreparedRequest()
            url: str = (
                "https://api.openai.com/v1/audio/transcriptions"
                if not translate
                else "https://api.openai.com/v1/audio/translations"
            )

            # Build the multipart form body (file upload plus parameters) with PreparedRequest,
            # then send the prepared body and headers via requests.post
            request.prepare(
                method="POST",
                url=url,
                headers=headers,
                data={"model": "whisper-1", **kwargs},
                files=[("file", (audio_file.name, audio_file, "application/octet-stream"))],
            )
            response = requests.post(url, data=request.body, headers=request.headers, timeout=600)

            if response.status_code != 200:
                openai_error: OpenAIError
                if response.status_code == 429:
                    openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
                else:
                    openai_error = OpenAIError(
                        f"OpenAI returned an error.\n"
                        f"Status code: {response.status_code}\n"
                        f"Response body: {response.text}",
                        status_code=response.status_code,
                    )
                raise openai_error

            return json.loads(response.content)
    def _invoke_local(
        self, audio_file: Union[str, BinaryIO], translate: Optional[bool] = False, **kwargs
    ) -> Dict[str, Any]:
        if isinstance(audio_file, str):
            with open(audio_file, "rb") as f:
                return self._invoke_local(f, translate, **kwargs)
        else:
            return_segments = kwargs.pop("return_segments", None)
            kwargs["task"] = "translate" if translate else "transcribe"
            # The local whisper model reads from disk, so pass the underlying file path
            transcription = self._model.transcribe(audio_file.name, **kwargs)
            if not return_segments:
                transcription.pop("segments", None)

            return transcription
    def run(self, audio_file: Union[str, BinaryIO], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
        """
        Transcribe an audio file.

        :param audio_file: Path to the audio file or a binary file-like object.
        :param language: Language of the audio file. If None, the language is detected automatically.
        :param return_segments: If True, also returns the transcription for each segment of the audio file.
        :param translate: If True, translates the transcription to English.
        """
        document = self.transcribe(audio_file, language, return_segments, translate)

        output = {"documents": [document]}

        return output, "output_1"

    def run_batch(self, audio_files: List[Union[str, BinaryIO]], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
        """
        Transcribe audio files.

        :param audio_files: List of paths to audio files or binary file-like objects.
        :param language: Language of the audio files. If None, the language is detected automatically.
        :param return_segments: If True, also returns the transcription for each segment of the audio files.
        :param translate: If True, translates the transcription to English.
        """
        documents = []
        for audio in audio_files:
            document = self.transcribe(audio, language, return_segments, translate)
            documents.append(document)

        output = {"documents": documents}

        return output, "output_1"
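For orientation, here is a minimal usage sketch of the new node. It is not part of the diff; the audio path and the environment variable holding the API key are placeholders.

import os

from haystack.nodes.audio import WhisperTranscriber

# API-backed transcription: requires an OpenAI API key ("audio/sample.wav" is a placeholder path).
api_transcriber = WhisperTranscriber(api_key=os.environ["OPENAI_API_KEY"])
transcript = api_transcriber.transcribe(audio_file="audio/sample.wav")
print(transcript["text"])

# Local transcription: requires the openai-whisper package; omit api_key to load a local model.
local_transcriber = WhisperTranscriber(model_name_or_path="tiny")
transcript = local_transcriber.transcribe(audio_file="audio/sample.wav", language="en", return_segments=True)
print(transcript["text"], len(transcript.get("segments", [])))

# Used as a node, run() wraps the same call and returns ({"documents": [<transcript dict>]}, "output_1").
output, edge = local_transcriber.run(audio_file="audio/sample.wav")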
haystack/utils/import_utils.py

@@ -6,6 +6,7 @@ import tarfile
import zipfile
import logging
import importlib
import importlib.util
from pathlib import Path

import requests

@@ -118,3 +119,7 @@ def fetch_archive_from_http(
        )

    return True


def is_whisper_available():
    return importlib.util.find_spec("whisper") is not None
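A small sketch (again not part of the commit) of how this helper pairs with the node's constructor: the local backend is only taken when the optional whisper package is importable and no api_key is passed; otherwise an API key is required.

import os

from haystack.nodes.audio import WhisperTranscriber
from haystack.utils.import_utils import is_whisper_available

# Pick a backend depending on whether the optional local dependency is installed;
# the OPENAI_API_KEY environment variable name is only an illustration.
if is_whisper_available():
    transcriber = WhisperTranscriber()  # local model, "medium" by default
else:
    transcriber = WhisperTranscriber(api_key=os.environ["OPENAI_API_KEY"])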
pyproject.toml

@@ -156,6 +156,7 @@ audio = [
    "protobuf<=3.20.1",
    "soundfile<0.12.0",
    "numpy<1.24",  # Keep compatibility with latest numba
    "openai-whisper"
]
beir = [
    "beir; platform_system != 'Windows'",
test/nodes/test_whisper.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import os

import pytest

from haystack.nodes.audio import WhisperTranscriber
from haystack.utils.import_utils import is_whisper_available
from ..conftest import SAMPLES_PATH


@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
@pytest.mark.integration
def test_whisper_api_transcribe():
    w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
    audio_object_transcript, audio_path_transcript = transcribe_test_helper(w)
    assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript


@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
@pytest.mark.integration
def test_whisper_api_transcribe_with_params():
    w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
    audio_object_transcript, audio_path_transcript = transcribe_test_helper(w)
    assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript
@pytest.mark.integration
@pytest.mark.skipif(not is_whisper_available(), reason="Whisper is not installed")
def test_whisper_local_transcribe():
    w = WhisperTranscriber()
    audio_object_transcript, audio_path_transcript = transcribe_test_helper(w, language="en")
    assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript


@pytest.mark.integration
@pytest.mark.skipif(not is_whisper_available(), reason="Whisper is not installed")
def test_whisper_local_transcribe_with_params():
    w = WhisperTranscriber()
    audio_object, audio_path = transcribe_test_helper(w, language="en", return_segments=True)
    assert len(audio_object["segments"]) == 1 and len(audio_path["segments"]) == 1


def transcribe_test_helper(whisper, **kwargs):
    # this file is 1 second long and contains the word "answer"
    file_path = str(SAMPLES_PATH / "audio" / "answer.wav")

    # using audio object
    with open(file_path, mode="rb") as audio_file:
        audio_object_transcript = whisper.transcribe(audio_file=audio_file, **kwargs)
        assert "answer" in audio_object_transcript["text"].lower()

    # using path to audio file
    audio_path_transcript = whisper.transcribe(audio_file=file_path, **kwargs)
    assert "answer" in audio_path_transcript["text"].lower()
    return audio_object_transcript, audio_path_transcript