mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 07:29:06 +00:00
feat: Update LocalWhisperTranscriber, add tests (#7935)
* Update LocalWhisperTranscriber, add tests
* Final touches
* Update haystack/components/audio/whisper_local.py
  Co-authored-by: David S. Batista <dsbatista@gmail.com>
* Fix prev commit
* Relax test for tiny model to work
---------
Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
parent
1f7786d6dd
commit
569b2a87cb
@ -16,7 +16,20 @@ with LazyImport("Run 'pip install \"openai-whisper>=20231106\"' to install whisp
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
WhisperLocalModel = Literal["tiny", "small", "medium", "large", "large-v2"]
|
||||
WhisperLocalModel = Literal[
|
||||
"base",
|
||||
"base.en",
|
||||
"large",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
"medium",
|
||||
"medium.en",
|
||||
"small",
|
||||
"small.en",
|
||||
"tiny",
|
||||
"tiny.en",
|
||||
]
|
||||
|
||||
|
||||
@component
|
||||
@ -161,25 +174,21 @@ class LocalWhisperTranscriber:
|
||||
raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
|
||||
|
||||
return_segments = kwargs.pop("return_segments", False)
|
||||
transcriptions: Dict[Path, Any] = {}
|
||||
transcriptions = {}
|
||||
|
||||
for source in sources:
|
||||
if not isinstance(source, ByteStream):
|
||||
path = Path(source)
|
||||
source = ByteStream.from_file_path(path)
|
||||
source.meta["file_path"] = path
|
||||
else:
|
||||
# If we received a ByteStream instance that doesn't have the "file_path" metadata set,
|
||||
# we dump the bytes into a temporary file.
|
||||
path = source.meta.get("file_path")
|
||||
if path is None:
|
||||
fp = tempfile.NamedTemporaryFile(delete=False)
|
||||
path = Path(source) if not isinstance(source, ByteStream) else source.meta.get("file_path")
|
||||
|
||||
if isinstance(source, ByteStream) and path is None:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as fp:
|
||||
path = Path(fp.name)
|
||||
source.to_file(path)
|
||||
source.meta["file_path"] = path
|
||||
|
||||
transcription = self._model.transcribe(str(path), **kwargs)
|
||||
|
||||
if not return_segments:
|
||||
transcription.pop("segments", None)
|
||||
|
||||
transcriptions[path] = transcription
|
||||
|
||||
return transcriptions
|
||||
|
||||
@ -8,6 +8,8 @@ from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from haystack import Pipeline
|
||||
from haystack.components.fetchers import LinkContentFetcher
|
||||
from haystack.dataclasses import Document, ByteStream
|
||||
from haystack.components.audio import LocalWhisperTranscriber
|
||||
from haystack.utils.device import ComponentDevice, Device
|
||||
@ -169,7 +171,7 @@ class TestLocalWhisperTranscriber:
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI")
|
||||
def test_whisper_local_transcriber(self, test_files_path):
|
||||
comp = LocalWhisperTranscriber(model="medium", whisper_params={"language": "english"})
|
||||
comp = LocalWhisperTranscriber(model="tiny", whisper_params={"language": "english"})
|
||||
comp.warm_up()
|
||||
output = comp.run(
|
||||
sources=[
|
||||
@ -181,13 +183,34 @@ class TestLocalWhisperTranscriber:
|
||||
docs = output["documents"]
|
||||
assert len(docs) == 3
|
||||
|
||||
assert docs[0].content.strip().lower() == "this is the content of the document."
|
||||
assert all(
|
||||
word in docs[0].content.strip().lower() for word in {"content", "the", "document"}
|
||||
), f"Expected words not found in: {docs[0].content.strip().lower()}"
|
||||
assert test_files_path / "audio" / "this is the content of the document.wav" == docs[0].meta["audio_file"]
|
||||
|
||||
assert docs[1].content.strip().lower() == "the context for this answer is here."
|
||||
assert all(
|
||||
word in docs[1].content.strip().lower() for word in {"context", "answer"}
|
||||
), f"Expected words not found in: {docs[1].content.strip().lower()}"
|
||||
path = test_files_path / "audio" / "the context for this answer is here.wav"
|
||||
assert path.absolute() == docs[1].meta["audio_file"]
|
||||
|
||||
assert docs[2].content.strip().lower() == "answer."
|
||||
# meta.audio_file should contain the temp path where we dumped the audio bytes
|
||||
assert docs[2].meta["audio_file"]
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI")
|
||||
def test_whisper_local_transcriber_pipeline_and_url_source(self):
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("fetcher", LinkContentFetcher())
|
||||
pipe.add_component("transcriber", LocalWhisperTranscriber(model="tiny"))
|
||||
|
||||
pipe.connect("fetcher", "transcriber")
|
||||
result = pipe.run(
|
||||
data={
|
||||
"fetcher": {
|
||||
"urls": ["https://ia903102.us.archive.org/19/items/100-Best--Speeches/EK_19690725_64kb.mp3"]
|
||||
}
|
||||
}
|
||||
)
|
||||
assert "Massachusetts" in result["transcriber"]["documents"][0].content
|
||||
|
||||
@ -4,7 +4,10 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from haystack import Pipeline
|
||||
from haystack.components.audio import LocalWhisperTranscriber
|
||||
from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber
|
||||
from haystack.components.fetchers import LinkContentFetcher
|
||||
from haystack.dataclasses import ByteStream
|
||||
from haystack.utils import Secret
|
||||
|
||||
@ -186,3 +189,19 @@ class TestRemoteWhisperTranscriber:
|
||||
assert str(test_files_path / "audio" / "the context for this answer is here.wav") == docs[1].meta["file_path"]
|
||||
|
||||
assert docs[2].content.strip().lower() == "answer."
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_whisper_local_transcriber_pipeline_and_url_source(self):
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("fetcher", LinkContentFetcher())
|
||||
pipe.add_component("transcriber", RemoteWhisperTranscriber())
|
||||
|
||||
pipe.connect("fetcher", "transcriber")
|
||||
result = pipe.run(
|
||||
data={
|
||||
"fetcher": {
|
||||
"urls": ["https://ia903102.us.archive.org/19/items/100-Best--Speeches/EK_19690725_64kb.mp3"]
|
||||
}
|
||||
}
|
||||
)
|
||||
assert "Massachusetts" in result["transcriber"]["documents"][0].content
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user