Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-07-16 05:20:51 +00:00)
Add Whisper node (#4335)

* Add Whisper node
* Add support for audio path, improve tests
* Add docs
* Improve tests

This commit is contained in:
parent 28724e2e25
commit 98256ecf57
@@ -48,6 +48,8 @@ from haystack.nodes.retriever import (
 from haystack.nodes.summarizer import BaseSummarizer, TransformersSummarizer
 from haystack.nodes.translator import BaseTranslator, TransformersTranslator
 
+from haystack.nodes.audio import WhisperTranscriber, WhisperModel
+
 Crawler = safe_import("haystack.nodes.connector.crawler", "Crawler", "crawler")  # Has optional dependencies
 AnswerToSpeech = safe_import(
     "haystack.nodes.audio.answer_to_speech", "AnswerToSpeech", "audio"
@@ -1,4 +1,5 @@
 from haystack.utils.import_utils import safe_import
+from haystack.nodes.audio.whisper_transcriber import WhisperTranscriber, WhisperModel
 
 AnswerToSpeech = safe_import(
     "haystack.nodes.audio.answer_to_speech", "AnswerToSpeech", "audio"
haystack/nodes/audio/whisper_transcriber.py (new file, 187 lines)
@@ -0,0 +1,187 @@
import json

from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal

import requests
import torch
from requests import PreparedRequest

from haystack.errors import OpenAIError, OpenAIRateLimitError
from haystack.nodes.base import BaseComponent
from haystack.utils.import_utils import is_whisper_available


WhisperModel = Literal["tiny", "small", "medium", "large", "large-v2"]


class WhisperTranscriber(BaseComponent):
    """
    Transcribes audio files using OpenAI's Whisper. This class supports two underlying implementations:

    - API (default): Uses the OpenAI API and requires an API key. See the
      [API documentation](https://beta.openai.com/docs/api-reference/whisper) for more details.
    - Local (requires installation of whisper): Uses a local installation
      of [whisper](https://github.com/openai/whisper).

    If you are using a local installation of whisper, install it following the instructions available on
    the Whisper [GitHub repo](https://github.com/openai/whisper) and omit the api_key parameter.

    If you are using the API implementation, you need to provide an api_key. You can get one by signing up
    for an OpenAI account [here](https://beta.openai.com/).

    For the supported audio formats, languages, and other parameters, see the Whisper API
    [documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
    [GitHub repo](https://github.com/openai/whisper).
    """

    # If it's not a decision component, there is only one outgoing edge
    outgoing_edges = 1

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name_or_path: WhisperModel = "medium",
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        """
        Creates a WhisperTranscriber instance.

        :param api_key: OpenAI API key. If None, a local installation of whisper is used.
        :param model_name_or_path: Name of the model to use. If using a local installation of whisper, this
            value has to be one of the following: "tiny", "small", "medium", "large", "large-v2". If using
            the API, this value has to be "whisper-1" (default).
        :param device: Device to use for inference. This parameter is only used with a local
            installation of whisper. If None, the device is automatically selected.
        """
        super().__init__()
        self.api_key = api_key

        self.use_local_whisper = is_whisper_available() and self.api_key is None

        if self.use_local_whisper:
            import whisper

            self._model = whisper.load_model(model_name_or_path, device=device)
        else:
            if api_key is None:
                raise ValueError(
                    "Please provide a valid api_key for OpenAI API. Alternatively, "
                    "install OpenAI whisper (see https://github.com/openai/whisper for more details)."
                )

    def transcribe(
        self,
        audio_file: Union[str, BinaryIO],
        language: Optional[str] = None,
        return_segments: bool = False,
        translate: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Transcribe an audio file.

        :param audio_file: Path to the audio file or a binary file-like object.
        :param language: Language of the audio file. If None, the language is automatically detected.
        :param return_segments: If True, returns the transcription for each segment of the audio file.
        :param translate: If True, translates the transcription to English.
        """
        transcript: Dict[str, Any] = {}

        new_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if language is not None:
            new_kwargs["language"] = language

        if self.use_local_whisper:
            new_kwargs["return_segments"] = return_segments
            transcript = self._invoke_local(audio_file, translate, **new_kwargs)
        elif self.api_key:
            transcript = self._invoke_api(audio_file, translate, **new_kwargs)
        return transcript

    def _invoke_api(
        self, audio_file: Union[str, BinaryIO], translate: Optional[bool] = False, **kwargs
    ) -> Dict[str, Any]:
        if isinstance(audio_file, str):
            with open(audio_file, "rb") as f:
                return self._invoke_api(f, translate, **kwargs)
        else:
            headers = {"Authorization": f"Bearer {self.api_key}"}
            request = PreparedRequest()
            url: str = (
                "https://api.openai.com/v1/audio/transcriptions"
                if not translate
                else "https://api.openai.com/v1/audio/translations"
            )

            request.prepare(
                method="POST",
                url=url,
                headers=headers,
                data={"model": "whisper-1", **kwargs},
                files=[("file", (audio_file.name, audio_file, "application/octet-stream"))],
            )
            response = requests.post(url, data=request.body, headers=request.headers, timeout=600)

            if response.status_code != 200:
                openai_error: OpenAIError
                if response.status_code == 429:
                    openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
                else:
                    openai_error = OpenAIError(
                        f"OpenAI returned an error.\n"
                        f"Status code: {response.status_code}\n"
                        f"Response body: {response.text}",
                        status_code=response.status_code,
                    )
                raise openai_error

            return json.loads(response.content)

    def _invoke_local(
        self, audio_file: Union[str, BinaryIO], translate: Optional[bool] = False, **kwargs
    ) -> Dict[str, Any]:
        if isinstance(audio_file, str):
            with open(audio_file, "rb") as f:
                return self._invoke_local(f, translate, **kwargs)
        else:
            return_segments = kwargs.pop("return_segments", None)
            kwargs["task"] = "translate" if translate else "transcribe"
            transcription = self._model.transcribe(audio_file.name, **kwargs)
            if not return_segments:
                transcription.pop("segments", None)

            return transcription

    def run(self, audio_file: Union[str, BinaryIO], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
        """
        Transcribe an audio file.

        :param audio_file: Path to the audio file or a binary file-like object.
        :param language: Language of the audio file. If None, the language is automatically detected.
        :param return_segments: If True, returns the transcription for each segment of the audio file.
        :param translate: If True, translates the transcription to English.
        """
        document = self.transcribe(audio_file, language, return_segments, translate)

        output = {"documents": [document]}

        return output, "output_1"

    def run_batch(self, audio_files: List[Union[str, BinaryIO]], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
        """
        Transcribe audio files.

        :param audio_files: List of paths to audio files or binary file-like objects.
        :param language: Language of the audio files. If None, the language is automatically detected.
        :param return_segments: If True, returns the transcription for each segment of the audio files.
        :param translate: If True, translates the transcription to English.
        """
        documents = []
        for audio in audio_files:
            document = self.transcribe(audio, language, return_segments, translate)
            documents.append(document)

        output = {"documents": documents}

        return output, "output_1"
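
For orientation, a minimal usage sketch of the node added above (not part of the commit; the audio path is illustrative, and the API key is assumed to be available in the environment, as in the tests further down):

import os

from haystack.nodes.audio import WhisperTranscriber

# Local backend: no api_key, requires `openai-whisper` to be installed.
local_whisper = WhisperTranscriber(model_name_or_path="medium")
transcript = local_whisper.transcribe(audio_file="audio/sample.wav", language="en", return_segments=True)
print(transcript["text"])

# API backend: needs an OpenAI API key; the hosted model is always "whisper-1".
api_whisper = WhisperTranscriber(api_key=os.environ["OPENAI_API_KEY"])
output, edge = api_whisper.run(audio_file="audio/sample.wav", translate=True)
print(edge)                            # "output_1"
print(output["documents"][0]["text"])  # transcription translated to English
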
@@ -6,6 +6,7 @@ import tarfile
 import zipfile
 import logging
 import importlib
+import importlib.util
 from pathlib import Path
 
 import requests
@@ -118,3 +119,7 @@ def fetch_archive_from_http(
             )
 
     return True
+
+
+def is_whisper_available():
+    return importlib.util.find_spec("whisper") is not None
@@ -156,6 +156,7 @@ audio = [
     "protobuf<=3.20.1",
     "soundfile< 0.12.0",
     "numpy<1.24",  # Keep compatibility with latest numba
+    "openai-whisper"
 ]
 beir = [
     "beir; platform_system != 'Windows'",
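
Since the audio extra now pulls in openai-whisper, the local backend becomes usable whenever that extra (or openai-whisper itself) is installed. A small sketch, based on the constructor logic and the new is_whisper_available helper above, of how the optional dependency plays out at runtime:

from haystack.utils.import_utils import is_whisper_available
from haystack.nodes.audio import WhisperTranscriber

if is_whisper_available():
    # `openai-whisper` is importable, so no API key is needed (the local model is loaded).
    transcriber = WhisperTranscriber()
else:
    # Without a local whisper install, the constructor insists on an OpenAI API key.
    try:
        transcriber = WhisperTranscriber()
    except ValueError as err:
        print(err)  # asks for an api_key or a local whisper installation
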
test/nodes/test_whisper.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import os

import pytest

from haystack.nodes.audio import WhisperTranscriber
from haystack.utils.import_utils import is_whisper_available
from ..conftest import SAMPLES_PATH


@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
@pytest.mark.integration
def test_whisper_api_transcribe():
    w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
    audio_object_transcript, audio_path_transcript = transcribe_test_helper(w)
    assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript


@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
@pytest.mark.integration
def test_whisper_api_transcribe_with_params():
    w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
    audio_object_transcript, audio_path_transcript = transcribe_test_helper(w)
    assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript


@pytest.mark.integration
@pytest.mark.skipif(not is_whisper_available(), reason="Whisper is not installed")
def test_whisper_local_transcribe():
    w = WhisperTranscriber()
    audio_object_transcript, audio_path_transcript = transcribe_test_helper(w, language="en")
    assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript


@pytest.mark.integration
@pytest.mark.skipif(not is_whisper_available(), reason="Whisper is not installed")
def test_whisper_local_transcribe_with_params():
    w = WhisperTranscriber()
    audio_object, audio_path = transcribe_test_helper(w, language="en", return_segments=True)
    assert len(audio_object["segments"]) == 1 and len(audio_path["segments"]) == 1


def transcribe_test_helper(whisper, **kwargs):
    # this file is 1 second long and contains the word "answer"
    file_path = str(SAMPLES_PATH / "audio" / "answer.wav")

    # using audio object
    with open(file_path, mode="rb") as audio_file:
        audio_object_transcript = whisper.transcribe(audio_file=audio_file, **kwargs)
        assert "answer" in audio_object_transcript["text"].lower()

    # using path to audio file
    audio_path_transcript = whisper.transcribe(audio_file=file_path, **kwargs)
    assert "answer" in audio_path_transcript["text"].lower()
    return audio_object_transcript, audio_path_transcript
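
A note on running these tests: all four are marked integration, the API ones are skipped unless OPENAI_API_KEY is set, and the local ones are skipped unless openai-whisper is importable. Something along the lines of OPENAI_API_KEY=sk-... pytest -m integration test/nodes/test_whisper.py (the exact invocation depends on the repository's pytest configuration) exercises both backends against the one-second answer.wav sample resolved through SAMPLES_PATH.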