refactor: Adjust WhisperTranscriber to pipeline run methods (#4510)

* Retrofit WhisperTranscriber run methods
* Add pipeline unit test
---------
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
Vladimir Blagojevic 2023-03-28 13:52:21 +02:00 committed by GitHub
parent 098342da32
commit 7c9f719496
2 changed files with 61 additions and 27 deletions
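For context, here is a minimal sketch of the retrofitted interface from a caller's point of view. It is not part of the diff below: the audio path is a placeholder, and a valid OPENAI_API_KEY in the environment is assumed.

import os

from haystack.nodes.audio import WhisperTranscriber

whisper = WhisperTranscriber(api_key=os.environ["OPENAI_API_KEY"])

# run() now follows the standard node signature: audio files arrive via
# file_paths, while query/labels/documents/meta are accepted but ignored.
output, edge = whisper.run(file_paths=["audio/answer.wav"])  # placeholder path
print(output["documents"][0].content)  # the transcribed text; edge is "output_1"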


@@ -6,11 +6,11 @@ import requests
 import torch
 from requests import PreparedRequest
+from haystack import MultiLabel, Document
 from haystack.errors import OpenAIError, OpenAIRateLimitError
 from haystack.nodes.base import BaseComponent
 from haystack.utils.import_utils import is_whisper_available
 
 WhisperModel = Literal["tiny", "small", "medium", "large", "large-v2"]
@@ -153,35 +153,57 @@ class WhisperTranscriber(BaseComponent):
         return transcription
 
-    def run(self, audio_file: Union[str, BinaryIO], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
-        """
-        Transcribe audio file.
-        :param audio_file: Path to audio file or a binary file-like object.
-        :param language: Language of the audio file. If None, the language is automatically detected.
-        :param return_segments: If True, returns the transcription for each segment of the audio file.
-        :param translate: If True, translates the transcription to English.
-        """
-        document = self.transcribe(audio_file, language, return_segments, translate)
-        output = {"documents": [document]}
-        return output, "output_1"
-    def run_batch(self, audio_files: List[Union[str, BinaryIO]], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
+    def run(
+        self,
+        query: Optional[str] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[MultiLabel] = None,
+        documents: Optional[List[Document]] = None,
+        meta: Optional[dict] = None,
+    ):  # type: ignore
"""
Transcribe audio files.
:param audio_files: List of paths to audio files or binary file-like objects.
:param language: Language of the audio files. If None, the language is automatically detected.
:param return_segments: If True, returns the transcription for each segment of the audio files.
:param translate: If True, translates the transcription to English.
:param query: Ignored
:param file_paths: List of paths to audio files.
:param labels: Ignored
:param documents: Ignored
:param meta: Ignored
"""
-        documents = []
-        for audio in audio_files:
-            document = self.transcribe(audio, language, return_segments, translate)
-            documents.append(document)
-        output = {"documents": documents}
+        transcribed_documents: List[Document] = []
+        if file_paths:
+            for file_path in file_paths:
+                transcription = self.transcribe(file_path)
+                d = Document.from_dict(transcription, field_map={"text": "content"})
+                transcribed_documents.append(d)
+        output = {"documents": transcribed_documents}
         return output, "output_1"
+    def run_batch(
+        self,
+        queries: Optional[Union[str, List[str]]] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        params: Optional[dict] = None,
+        debug: Optional[bool] = None,
+    ):  # type: ignore
+        """
+        Transcribe audio files.
+        :param queries: Ignored
+        :param file_paths: List of paths to audio files.
+        :param labels: Ignored
+        :param documents: Ignored
+        :param meta: Ignored
+        :param params: Ignored
+        :param debug: Ignored
+        """
+        if file_paths and isinstance(file_paths[0], list):
+            all_files = []
+            for files_list in file_paths:
+                all_files += files_list
+            return self.run(file_paths=all_files)
+        return self.run(file_paths=file_paths)
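The new run_batch simply normalizes its input and delegates to run(): a list of lists of paths is flattened into one flat list, so grouping is not preserved and all transcriptions come back in a single documents list. A quick sketch with placeholder file names, reusing the whisper instance from the example above:

# Both calls produce the same output dictionary.
output, _ = whisper.run_batch(file_paths=[["a.wav", "b.wav"], ["c.wav"]])
output, _ = whisper.run(file_paths=["a.wav", "b.wav", "c.wav"])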


@@ -2,6 +2,7 @@ import os
 import pytest
+from haystack import Pipeline
 from haystack.nodes.audio import WhisperTranscriber
 from haystack.utils.import_utils import is_whisper_available
 from ..conftest import SAMPLES_PATH
@@ -54,3 +55,14 @@ def transcribe_test_helper(whisper, **kwargs):
     audio_path_transcript = whisper.transcribe(audio_file=file_path, **kwargs)
     assert "answer" in audio_path_transcript["text"].lower()
     return audio_object_transcript, audio_path_transcript
+
+
+@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
+@pytest.mark.integration
+def test_whisper_pipeline():
+    w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
+    pipeline = Pipeline()
+    pipeline.add_node(component=w, name="whisper", inputs=["File"])
+    res = pipeline.run(file_paths=[str(SAMPLES_PATH / "audio" / "answer.wav")])
+    assert res["documents"] and len(res["documents"]) == 1
+    assert "answer" in res["documents"][0].content.lower()
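The integration test wires the transcriber to the File root node: Pipeline.run() forwards its file_paths argument to components connected to that root, which is why run() adopts the standard pipeline signature (file_paths plus the ignored query/labels/documents/meta) in place of the audio-specific arguments it replaces.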