diff --git a/haystack/nodes/audio/whisper_transcriber.py b/haystack/nodes/audio/whisper_transcriber.py
index f22141399..dc1dc3eed 100644
--- a/haystack/nodes/audio/whisper_transcriber.py
+++ b/haystack/nodes/audio/whisper_transcriber.py
@@ -6,11 +6,11 @@ import requests
 import torch
 from requests import PreparedRequest
 
+from haystack import MultiLabel, Document
 from haystack.errors import OpenAIError, OpenAIRateLimitError
 from haystack.nodes.base import BaseComponent
 from haystack.utils.import_utils import is_whisper_available
 
-
 WhisperModel = Literal["tiny", "small", "medium", "large", "large-v2"]
 
 
@@ -153,35 +153,57 @@ class WhisperTranscriber(BaseComponent):
 
         return transcription
 
-    def run(self, audio_file: Union[str, BinaryIO], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
-        """
-        Transcribe audio file.
-
-        :param audio_file: Path to audio file or a binary file-like object.
-        :param language: Language of the audio file. If None, the language is automatically detected.
-        :param return_segments: If True, returns the transcription for each segment of the audio file.
-        :param translate: If True, translates the transcription to English.
-        """
-        document = self.transcribe(audio_file, language, return_segments, translate)
-
-        output = {"documents": [document]}
-
-        return output, "output_1"
-
-    def run_batch(self, audio_files: List[Union[str, BinaryIO]], language: Optional[str] = None, return_segments: bool = False, translate: bool = False):  # type: ignore
+    def run(
+        self,
+        query: Optional[str] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[MultiLabel] = None,
+        documents: Optional[List[Document]] = None,
+        meta: Optional[dict] = None,
+    ):  # type: ignore
         """
         Transcribe audio files.
 
-        :param audio_files: List of paths to audio files or binary file-like objects.
-        :param language: Language of the audio files. If None, the language is automatically detected.
-        :param return_segments: If True, returns the transcription for each segment of the audio files.
-        :param translate: If True, translates the transcription to English.
+        :param query: Ignored
+        :param file_paths: List of paths to audio files.
+        :param labels: Ignored
+        :param documents: Ignored
+        :param meta: Ignored
         """
-        documents = []
-        for audio in audio_files:
-            document = self.transcribe(audio, language, return_segments, translate)
-            documents.append(document)
-
-        output = {"documents": documents}
+        transcribed_documents: List[Document] = []
+        if file_paths:
+            for file_path in file_paths:
+                transcription = self.transcribe(file_path)
+                d = Document.from_dict(transcription, field_map={"text": "content"})
+                transcribed_documents.append(d)
 
+        output = {"documents": transcribed_documents}
         return output, "output_1"
+
+    def run_batch(
+        self,
+        queries: Optional[Union[str, List[str]]] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        params: Optional[dict] = None,
+        debug: Optional[bool] = None,
+    ):  # type: ignore
+        """
+        Transcribe audio files.
+
+        :param queries: Ignored
+        :param file_paths: List of paths to audio files.
+        :param labels: Ignored
+        :param documents: Ignored
+        :param meta: Ignored
+        :param params: Ignored
+        :param debug: Ignored
+        """
+        if file_paths and isinstance(file_paths[0], list):
+            all_files = []
+            for files_list in file_paths:
+                all_files += files_list
+            return self.run(file_paths=all_files)
+        return self.run(file_paths=file_paths)
diff --git a/test/nodes/test_whisper.py b/test/nodes/test_whisper.py
index 19dffa724..0e6aa4244 100644
--- a/test/nodes/test_whisper.py
+++ b/test/nodes/test_whisper.py
@@ -2,6 +2,7 @@ import os
 
 import pytest
 
+from haystack import Pipeline
 from haystack.nodes.audio import WhisperTranscriber
 from haystack.utils.import_utils import is_whisper_available
 from ..conftest import SAMPLES_PATH
@@ -54,3 +55,14 @@ def transcribe_test_helper(whisper, **kwargs):
     audio_path_transcript = whisper.transcribe(audio_file=file_path, **kwargs)
     assert "answer" in audio_path_transcript["text"].lower()
     return audio_object_transcript, audio_path_transcript
+
+
+@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found")
+@pytest.mark.integration
+def test_whisper_pipeline():
+    w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
+    pipeline = Pipeline()
+    pipeline.add_node(component=w, name="whisper", inputs=["File"])
+    res = pipeline.run(file_paths=[str(SAMPLES_PATH / "audio" / "answer.wav")])
+    assert res["documents"] and len(res["documents"]) == 1
+    assert "answer" in res["documents"][0].content.lower()
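
The diff above makes WhisperTranscriber consume the standard `file_paths` pipeline argument, transcribe each file, and wrap the results as Documents. Outside the test suite, usage would look roughly like the sketch below; this is a minimal illustration, not part of the change itself, and it assumes OPENAI_API_KEY is set in the environment and that "my_recording.wav" (a hypothetical file name) exists locally.

    import os

    from haystack import Pipeline
    from haystack.nodes.audio import WhisperTranscriber

    # One-node pipeline; "File" is the root input that feeds file_paths into run().
    whisper = WhisperTranscriber(api_key=os.environ["OPENAI_API_KEY"])
    pipeline = Pipeline()
    pipeline.add_node(component=whisper, name="whisper", inputs=["File"])

    # run() returns Documents whose `content` holds the transcription text.
    result = pipeline.run(file_paths=["my_recording.wav"])  # hypothetical local file
    for doc in result["documents"]:
        print(doc.content)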