diff --git a/haystack/components/audio/whisper_local.py b/haystack/components/audio/whisper_local.py
index abef048e8..00f73fe9f 100644
--- a/haystack/components/audio/whisper_local.py
+++ b/haystack/components/audio/whisper_local.py
@@ -16,7 +16,20 @@ with LazyImport("Run 'pip install \"openai-whisper>=20231106\"' to install whisp
 
 logger = logging.getLogger(__name__)
 
-WhisperLocalModel = Literal["tiny", "small", "medium", "large", "large-v2"]
+WhisperLocalModel = Literal[
+    "base",
+    "base.en",
+    "large",
+    "large-v1",
+    "large-v2",
+    "large-v3",
+    "medium",
+    "medium.en",
+    "small",
+    "small.en",
+    "tiny",
+    "tiny.en",
+]
 
 
 @component
@@ -161,25 +174,21 @@ class LocalWhisperTranscriber:
             raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
 
         return_segments = kwargs.pop("return_segments", False)
-        transcriptions: Dict[Path, Any] = {}
+        transcriptions = {}
+
         for source in sources:
-            if not isinstance(source, ByteStream):
-                path = Path(source)
-                source = ByteStream.from_file_path(path)
-                source.meta["file_path"] = path
-            else:
-                # If we received a ByteStream instance that doesn't have the "file_path" metadata set,
-                # we dump the bytes into a temporary file.
-                path = source.meta.get("file_path")
-                if path is None:
-                    fp = tempfile.NamedTemporaryFile(delete=False)
+            path = Path(source) if not isinstance(source, ByteStream) else source.meta.get("file_path")
+
+            if isinstance(source, ByteStream) and path is None:
+                with tempfile.NamedTemporaryFile(delete=False) as fp:
                     path = Path(fp.name)
                     source.to_file(path)
-                    source.meta["file_path"] = path
 
             transcription = self._model.transcribe(str(path), **kwargs)
+
             if not return_segments:
                 transcription.pop("segments", None)
+
             transcriptions[path] = transcription
 
         return transcriptions
diff --git a/test/components/audio/test_whisper_local.py b/test/components/audio/test_whisper_local.py
index 6cbd43575..ee2c07c31 100644
--- a/test/components/audio/test_whisper_local.py
+++ b/test/components/audio/test_whisper_local.py
@@ -8,6 +8,8 @@ from unittest.mock import patch, MagicMock
 import pytest
 import torch
 
+from haystack import Pipeline
+from haystack.components.fetchers import LinkContentFetcher
 from haystack.dataclasses import Document, ByteStream
 from haystack.components.audio import LocalWhisperTranscriber
 from haystack.utils.device import ComponentDevice, Device
@@ -169,7 +171,7 @@ class TestLocalWhisperTranscriber:
     @pytest.mark.integration
     @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI")
     def test_whisper_local_transcriber(self, test_files_path):
-        comp = LocalWhisperTranscriber(model="medium", whisper_params={"language": "english"})
+        comp = LocalWhisperTranscriber(model="tiny", whisper_params={"language": "english"})
         comp.warm_up()
         output = comp.run(
             sources=[
@@ -181,13 +183,34 @@
         docs = output["documents"]
         assert len(docs) == 3
 
-        assert docs[0].content.strip().lower() == "this is the content of the document."
+        assert all(
+            word in docs[0].content.strip().lower() for word in {"content", "the", "document"}
+        ), f"Expected words not found in: {docs[0].content.strip().lower()}"
         assert test_files_path / "audio" / "this is the content of the document.wav" == docs[0].meta["audio_file"]
 
-        assert docs[1].content.strip().lower() == "the context for this answer is here."
+        assert all(
+            word in docs[1].content.strip().lower() for word in {"context", "answer"}
+        ), f"Expected words not found in: {docs[1].content.strip().lower()}"
         path = test_files_path / "audio" / "the context for this answer is here.wav"
         assert path.absolute() == docs[1].meta["audio_file"]
 
         assert docs[2].content.strip().lower() == "answer."
         # meta.audio_file should contain the temp path where we dumped the audio bytes
         assert docs[2].meta["audio_file"]
+
+    @pytest.mark.integration
+    @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI")
+    def test_whisper_local_transcriber_pipeline_and_url_source(self):
+        pipe = Pipeline()
+        pipe.add_component("fetcher", LinkContentFetcher())
+        pipe.add_component("transcriber", LocalWhisperTranscriber(model="tiny"))
+
+        pipe.connect("fetcher", "transcriber")
+        result = pipe.run(
+            data={
+                "fetcher": {
+                    "urls": ["https://ia903102.us.archive.org/19/items/100-Best--Speeches/EK_19690725_64kb.mp3"]
+                }
+            }
+        )
+        assert "Massachusetts" in result["transcriber"]["documents"][0].content
diff --git a/test/components/audio/test_whisper_remote.py b/test/components/audio/test_whisper_remote.py
index 69c001159..f543d3305 100644
--- a/test/components/audio/test_whisper_remote.py
+++ b/test/components/audio/test_whisper_remote.py
@@ -4,7 +4,10 @@
 import os
 
 import pytest
 
+from haystack import Pipeline
+from haystack.components.audio import LocalWhisperTranscriber
 from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber
+from haystack.components.fetchers import LinkContentFetcher
 from haystack.dataclasses import ByteStream
 from haystack.utils import Secret
@@ -186,3 +189,19 @@
         assert str(test_files_path / "audio" / "the context for this answer is here.wav") == docs[1].meta["file_path"]
 
         assert docs[2].content.strip().lower() == "answer."
+
+    @pytest.mark.integration
+    def test_whisper_local_transcriber_pipeline_and_url_source(self):
+        pipe = Pipeline()
+        pipe.add_component("fetcher", LinkContentFetcher())
+        pipe.add_component("transcriber", RemoteWhisperTranscriber())
+
+        pipe.connect("fetcher", "transcriber")
+        result = pipe.run(
+            data={
+                "fetcher": {
+                    "urls": ["https://ia903102.us.archive.org/19/items/100-Best--Speeches/EK_19690725_64kb.mp3"]
+                }
+            }
+        )
+        assert "Massachusetts" in result["transcriber"]["documents"][0].content