feat: Update LocalWhisperTranscriber, add tests (#7935)

* Update LocalWhisperTranscriber, add tests

* Final touches

* Update haystack/components/audio/whisper_local.py

Co-authored-by: David S. Batista <dsbatista@gmail.com>

* Fix prev commit

* Relax test for tiny model to work

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2024-06-27 12:53:41 +02:00 committed by GitHub
parent 1f7786d6dd
commit 569b2a87cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 16 deletions

View File

@@ -16,7 +16,20 @@ with LazyImport("Run 'pip install \"openai-whisper>=20231106\"' to install whisp
logger = logging.getLogger(__name__)
WhisperLocalModel = Literal["tiny", "small", "medium", "large", "large-v2"]
WhisperLocalModel = Literal[
"base",
"base.en",
"large",
"large-v1",
"large-v2",
"large-v3",
"medium",
"medium.en",
"small",
"small.en",
"tiny",
"tiny.en",
]
@component
@@ -161,25 +174,21 @@ class LocalWhisperTranscriber:
raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
return_segments = kwargs.pop("return_segments", False)
transcriptions: Dict[Path, Any] = {}
transcriptions = {}
for source in sources:
if not isinstance(source, ByteStream):
path = Path(source)
source = ByteStream.from_file_path(path)
source.meta["file_path"] = path
else:
# If we received a ByteStream instance that doesn't have the "file_path" metadata set,
# we dump the bytes into a temporary file.
path = source.meta.get("file_path")
if path is None:
fp = tempfile.NamedTemporaryFile(delete=False)
path = Path(source) if not isinstance(source, ByteStream) else source.meta.get("file_path")
if isinstance(source, ByteStream) and path is None:
with tempfile.NamedTemporaryFile(delete=False) as fp:
path = Path(fp.name)
source.to_file(path)
source.meta["file_path"] = path
transcription = self._model.transcribe(str(path), **kwargs)
if not return_segments:
transcription.pop("segments", None)
transcriptions[path] = transcription
return transcriptions

View File

@@ -8,6 +8,8 @@ from unittest.mock import patch, MagicMock
import pytest
import torch
from haystack import Pipeline
from haystack.components.fetchers import LinkContentFetcher
from haystack.dataclasses import Document, ByteStream
from haystack.components.audio import LocalWhisperTranscriber
from haystack.utils.device import ComponentDevice, Device
@@ -169,7 +171,7 @@ class TestLocalWhisperTranscriber:
@pytest.mark.integration
@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI")
def test_whisper_local_transcriber(self, test_files_path):
comp = LocalWhisperTranscriber(model="medium", whisper_params={"language": "english"})
comp = LocalWhisperTranscriber(model="tiny", whisper_params={"language": "english"})
comp.warm_up()
output = comp.run(
sources=[
@@ -181,13 +183,34 @@ class TestLocalWhisperTranscriber:
docs = output["documents"]
assert len(docs) == 3
assert docs[0].content.strip().lower() == "this is the content of the document."
assert all(
word in docs[0].content.strip().lower() for word in {"content", "the", "document"}
), f"Expected words not found in: {docs[0].content.strip().lower()}"
assert test_files_path / "audio" / "this is the content of the document.wav" == docs[0].meta["audio_file"]
assert docs[1].content.strip().lower() == "the context for this answer is here."
assert all(
word in docs[1].content.strip().lower() for word in {"context", "answer"}
), f"Expected words not found in: {docs[1].content.strip().lower()}"
path = test_files_path / "audio" / "the context for this answer is here.wav"
assert path.absolute() == docs[1].meta["audio_file"]
assert docs[2].content.strip().lower() == "answer."
# meta.audio_file should contain the temp path where we dumped the audio bytes
assert docs[2].meta["audio_file"]
# Integration test: needs network access (fetches an MP3 over HTTP) and ffmpeg.
@pytest.mark.integration
@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI")
def test_whisper_local_transcriber_pipeline_and_url_source(self):
# End-to-end pipeline check: LinkContentFetcher downloads audio from a URL and
# pipes the bytes into LocalWhisperTranscriber running the small "tiny" model,
# so the whole fetch -> transcribe path is exercised, not just the component.
pipe = Pipeline()
pipe.add_component("fetcher", LinkContentFetcher())
pipe.add_component("transcriber", LocalWhisperTranscriber(model="tiny"))
pipe.connect("fetcher", "transcriber")
result = pipe.run(
data={
"fetcher": {
# presumably a stable archive.org speech recording — verify URL stays live
"urls": ["https://ia903102.us.archive.org/19/items/100-Best--Speeches/EK_19690725_64kb.mp3"]
}
}
)
# Loose content check: even the tiny model is expected to pick out the
# distinctive word "Massachusetts" from this recording.
assert "Massachusetts" in result["transcriber"]["documents"][0].content

View File

@@ -4,7 +4,10 @@
import os
import pytest
from haystack import Pipeline
from haystack.components.audio import LocalWhisperTranscriber
from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.components.fetchers import LinkContentFetcher
from haystack.dataclasses import ByteStream
from haystack.utils import Secret
@@ -186,3 +189,19 @@ class TestRemoteWhisperTranscriber:
assert str(test_files_path / "audio" / "the context for this answer is here.wav") == docs[1].meta["file_path"]
assert docs[2].content.strip().lower() == "answer."
# Integration test: needs network access and valid OpenAI credentials for the
# remote Whisper API.
# NOTE(review): the name says "local" but this exercises RemoteWhisperTranscriber
# inside TestRemoteWhisperTranscriber — consider renaming to "..._remote_...".
@pytest.mark.integration
def test_whisper_local_transcriber_pipeline_and_url_source(self):
# End-to-end pipeline check: LinkContentFetcher downloads audio from a URL and
# pipes the bytes into RemoteWhisperTranscriber (hosted Whisper API).
pipe = Pipeline()
pipe.add_component("fetcher", LinkContentFetcher())
pipe.add_component("transcriber", RemoteWhisperTranscriber())
pipe.connect("fetcher", "transcriber")
result = pipe.run(
data={
"fetcher": {
# presumably a stable archive.org speech recording — verify URL stays live
"urls": ["https://ia903102.us.archive.org/19/items/100-Best--Speeches/EK_19690725_64kb.mp3"]
}
}
)
# Loose content check: the transcript is expected to contain "Massachusetts".
assert "Massachusetts" in result["transcriber"]["documents"][0].content