Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-06 03:57:19 +00:00)
refactor: Adjust WhisperTranscriber to pipeline run methods (#4510)
* Retrofit WhisperTranscriber run methods
* Add pipeline unit test

---------

Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
parent 098342da32
commit 7c9f719496
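For context on the diff below: Haystack v1 pipelines call every node through a shared run/run_batch contract defined by BaseComponent. The standard arguments (query, file_paths, labels, documents, meta) come in, and a tuple of an output dict plus an outgoing edge name (here always "output_1") comes out. The old Whisper-specific signature (audio_file, language, ...) did not match that contract, which is what this commit fixes. A minimal sketch of the contract, using a stand-in node class rather than Haystack's actual BaseComponent code:

from typing import Any, Dict, List, Optional, Tuple

# Stand-in node illustrating the run contract the diff below adopts.
# Parameter names mirror the new WhisperTranscriber.run; the class itself
# is a simplified sketch, not Haystack's BaseComponent implementation.
class SketchNode:
    outgoing_edges = 1

    def run(
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[Any] = None,
        documents: Optional[List[Any]] = None,
        meta: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], str]:
        # A real node does its work here (WhisperTranscriber transcribes
        # file_paths) and routes results along the edge named "output_1".
        return {"documents": documents or []}, "output_1"

output, edge = SketchNode().run(documents=[])
assert edge == "output_1"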
@@ -6,11 +6,11 @@ import requests
 import torch
 from requests import PreparedRequest
 
+from haystack import MultiLabel, Document
 from haystack.errors import OpenAIError, OpenAIRateLimitError
 from haystack.nodes.base import BaseComponent
 from haystack.utils.import_utils import is_whisper_available
 
 
 WhisperModel = Literal["tiny", "small", "medium", "large", "large-v2"]
@@ -153,35 +153,57 @@ class WhisperTranscriber(BaseComponent):
 
         return transcription
 
-    def run(self, audio_file: Union[str, BinaryIO], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
-        """
-        Transcribe audio file.
-
-        :param audio_file: Path to audio file or a binary file-like object.
-        :param language: Language of the audio file. If None, the language is automatically detected.
-        :param return_segments: If True, returns the transcription for each segment of the audio file.
-        :param translate: If True, translates the transcription to English.
-        """
-        document = self.transcribe(audio_file, language, return_segments, translate)
-
-        output = {"documents": [document]}
-
-        return output, "output_1"
-
-    def run_batch(self, audio_files: List[Union[str, BinaryIO]], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
+    def run(
+        self,
+        query: Optional[str] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[MultiLabel] = None,
+        documents: Optional[List[Document]] = None,
+        meta: Optional[dict] = None,
+    ):  # type: ignore
         """
         Transcribe audio files.
 
-        :param audio_files: List of paths to audio files or binary file-like objects.
-        :param language: Language of the audio files. If None, the language is automatically detected.
-        :param return_segments: If True, returns the transcription for each segment of the audio files.
-        :param translate: If True, translates the transcription to English.
+        :param query: Ignored
+        :param file_paths: List of paths to audio files.
+        :param labels: Ignored
+        :param documents: Ignored
+        :param meta: Ignored
         """
-        documents = []
-        for audio in audio_files:
-            document = self.transcribe(audio, language, return_segments, translate)
-            documents.append(document)
-
-        output = {"documents": documents}
+        transcribed_documents: List[Document] = []
+        if file_paths:
+            for file_path in file_paths:
+                transcription = self.transcribe(file_path)
+                d = Document.from_dict(transcription, field_map={"text": "content"})
+                transcribed_documents.append(d)
+
+        output = {"documents": transcribed_documents}
         return output, "output_1"
+
+    def run_batch(
+        self,
+        queries: Optional[Union[str, List[str]]] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        params: Optional[dict] = None,
+        debug: Optional[bool] = None,
+    ):  # type: ignore
+        """
+        Transcribe audio files.
+
+        :param queries: Ignored
+        :param file_paths: List of paths to audio files.
+        :param labels: Ignored
+        :param documents: Ignored
+        :param meta: Ignored
+        :param params: Ignored
+        :param debug: Ignored
+        """
+        if file_paths and isinstance(file_paths[0], list):
+            all_files = []
+            for files_list in file_paths:
+                all_files += files_list
+            return self.run(file_paths=all_files)
+        return self.run(file_paths=file_paths)
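Since the new run_batch simply flattens any nested file_paths and delegates to run, both entry points produce the same flat list of Documents. A usage sketch against the new interface (file names are placeholders; assumes either a valid OPENAI_API_KEY or a local Whisper install):

import os

from haystack.nodes.audio import WhisperTranscriber

whisper = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))

# Flat list of paths: handled by run() directly. Paths are placeholders.
output, edge = whisper.run(file_paths=["clip_a.wav", "clip_b.wav"])
assert edge == "output_1"

# Nested lists (one list per query, as Pipeline.run_batch can pass them)
# are flattened and delegated to run(), yielding the same output shape.
batch_output, _ = whisper.run_batch(file_paths=[["clip_a.wav"], ["clip_b.wav"]])

print([doc.content for doc in output["documents"]])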
@@ -2,6 +2,7 @@ import os
 
 import pytest
 
+from haystack import Pipeline
 from haystack.nodes.audio import WhisperTranscriber
 from haystack.utils.import_utils import is_whisper_available
 from ..conftest import SAMPLES_PATH
@@ -54,3 +55,14 @@ def transcribe_test_helper(whisper, **kwargs):
     audio_path_transcript = whisper.transcribe(audio_file=file_path, **kwargs)
     assert "answer" in audio_path_transcript["text"].lower()
     return audio_object_transcript, audio_path_transcript
+
+
+@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
+@pytest.mark.integration
+def test_whisper_pipeline():
+    w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
+    pipeline = Pipeline()
+    pipeline.add_node(component=w, name="whisper", inputs=["File"])
+    res = pipeline.run(file_paths=[str(SAMPLES_PATH / "audio" / "answer.wav")])
+    assert res["documents"] and len(res["documents"]) == 1
+    assert "answer" in res["documents"][0].content.lower()
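The new integration test drives the OpenAI-hosted Whisper API. With the local whisper package installed (is_whisper_available guards this elsewhere in the module), the same pipeline should presumably run offline by passing one of the WhisperModel names instead of an API key; the model_name_or_path parameter name below is an assumption based on the Haystack v1 constructor, not something shown in this diff:

from haystack import Pipeline
from haystack.nodes.audio import WhisperTranscriber

# Assumption: the constructor accepts a local model name from the WhisperModel
# Literal in the diff above; the parameter name is not confirmed by this diff.
whisper = WhisperTranscriber(model_name_or_path="tiny")

pipeline = Pipeline()
pipeline.add_node(component=whisper, name="whisper", inputs=["File"])

# Sample path is a placeholder mirroring the test above.
result = pipeline.run(file_paths=["samples/audio/answer.wav"])
print(result["documents"][0].content)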