# File: haystack/test/components/audio/test_whisper_local.py (Python, 217 lines, 9.0 KiB)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import sys
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
import torch
from haystack import Pipeline
from haystack.components.fetchers import LinkContentFetcher
from haystack.dataclasses import Document, ByteStream
# (history-view timestamp: 2023-11-24 14:48:43 +01:00)
from haystack.components.audio import LocalWhisperTranscriber
from haystack.utils.device import ComponentDevice, Device
# Directory of shared sample files used by the unit tests below: the
# "test_files" folder located three directory levels above this module.
SAMPLES_PATH = Path(__file__).parent.parent.parent / "test_files"
class TestLocalWhisperTranscriber:
    """Unit and integration tests for ``LocalWhisperTranscriber``.

    The unit tests never load a real Whisper model: they either patch the
    ``whisper`` module or replace ``_model`` with a ``MagicMock``. Only the
    tests marked ``integration`` load the "tiny" model and transcribe audio.
    """

    def test_init(self):
        """The constructor stores its arguments and leaves ``_model`` unset."""
        transcriber = LocalWhisperTranscriber(
            model="large-v2"
        )  # Doesn't matter if it's huge, the model is not loaded in init.
        assert transcriber.model == "large-v2"
        assert transcriber.device == ComponentDevice.resolve_device(None)
        assert transcriber._model is None

    def test_init_wrong_model(self):
        """An unknown model name is rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="Model name 'whisper-1' not recognized"):
            LocalWhisperTranscriber(model="whisper-1")

    def test_to_dict(self):
        """A default-constructed component serializes with its default parameters."""
        transcriber = LocalWhisperTranscriber()
        data = transcriber.to_dict()
        assert data == {
            "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
            "init_parameters": {
                "model": "large",
                "device": ComponentDevice.resolve_device(None).to_dict(),
                "whisper_params": {},
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        """Custom model, device and whisper parameters survive serialization."""
        transcriber = LocalWhisperTranscriber(
            model="tiny",
            device=ComponentDevice.from_str("cuda:0"),
            whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
        )
        data = transcriber.to_dict()
        assert data == {
            "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
            "init_parameters": {
                "model": "tiny",
                "device": ComponentDevice.from_str("cuda:0").to_dict(),
                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
            },
        }

    def test_from_dict(self):
        """Deserialization restores model, device and parameters without loading a model."""
        data = {
            "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
            "init_parameters": {
                "model": "tiny",
                "device": ComponentDevice.from_single(Device.cpu()).to_dict(),
                "whisper_params": {},
            },
        }
        transcriber = LocalWhisperTranscriber.from_dict(data)
        assert transcriber.model == "tiny"
        assert transcriber.device == ComponentDevice.from_single(Device.cpu())
        assert transcriber.whisper_params == {}
        assert transcriber._model is None

    def test_from_dict_none_device(self):
        """A serialized ``device`` of ``None`` falls back to the resolved default device."""
        data = {
            "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
            "init_parameters": {"model": "tiny", "device": None, "whisper_params": {}},
        }
        transcriber = LocalWhisperTranscriber.from_dict(data)
        assert transcriber.model == "tiny"
        assert transcriber.device == ComponentDevice.resolve_device(None)
        assert transcriber.whisper_params == {}
        assert transcriber._model is None

    def test_warmup(self):
        """``warm_up`` is the first point at which ``whisper.load_model`` is called."""
        with patch("haystack.components.audio.whisper_local.whisper") as mocked_whisper:
            transcriber = LocalWhisperTranscriber(model="large-v2", device=ComponentDevice.from_str("cpu"))
            # Construction alone must not trigger a model load.
            mocked_whisper.load_model.assert_not_called()
            transcriber.warm_up()
            mocked_whisper.load_model.assert_called_once_with("large-v2", device=torch.device(type="cpu"))

    def test_warmup_doesnt_reload(self):
        """Calling ``warm_up`` twice loads the model only once."""
        with patch("haystack.components.audio.whisper_local.whisper") as mocked_whisper:
            transcriber = LocalWhisperTranscriber(model="large-v2")
            transcriber.warm_up()
            transcriber.warm_up()
            mocked_whisper.load_model.assert_called_once()

    def test_run_with_path(self):
        """``run`` with a ``Path`` source returns one document carrying that path in its meta."""
        comp = LocalWhisperTranscriber(model="large-v2")
        # Stand-in for the loaded model so no real transcription happens.
        comp._model = MagicMock()
        comp._model.transcribe.return_value = {
            "text": "test transcription",
            "other_metadata": ["other", "meta", "data"],
        }
        results = comp.run(sources=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
        expected = Document(
            content="test transcription",
            meta={
                "audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav",
                "other_metadata": ["other", "meta", "data"],
            },
        )
        assert results["documents"] == [expected]

    def test_run_with_str(self):
        """``run`` with a string source reports the file as a ``Path`` in the document meta."""
        comp = LocalWhisperTranscriber(model="large-v2")
        # Stand-in for the loaded model so no real transcription happens.
        comp._model = MagicMock()
        comp._model.transcribe.return_value = {
            "text": "test transcription",
            "other_metadata": ["other", "meta", "data"],
        }
        results = comp.run(
            sources=[str((SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute())]
        )
        expected = Document(
            content="test transcription",
            meta={
                "audio_file": (SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute(),
                "other_metadata": ["other", "meta", "data"],
            },
        )
        assert results["documents"] == [expected]

    def test_transcribe(self):
        """``transcribe`` returns the list of documents directly (not wrapped in a dict)."""
        comp = LocalWhisperTranscriber(model="large-v2")
        # Stand-in for the loaded model so no real transcription happens.
        comp._model = MagicMock()
        comp._model.transcribe.return_value = {
            "text": "test transcription",
            "other_metadata": ["other", "meta", "data"],
        }
        results = comp.transcribe(sources=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
        expected = Document(
            content="test transcription",
            meta={
                "audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav",
                "other_metadata": ["other", "meta", "data"],
            },
        )
        assert results == [expected]

    def test_transcribe_stream(self):
        """A ``ByteStream`` source with ``file_path`` in its meta reports that path as ``audio_file``."""
        comp = LocalWhisperTranscriber(model="large-v2")
        # Stand-in for the loaded model so no real transcription happens.
        comp._model = MagicMock()
        comp._model.transcribe.return_value = {
            "text": "test transcription",
            "other_metadata": ["other", "meta", "data"],
        }
        path = SAMPLES_PATH / "audio" / "this is the content of the document.wav"
        bs = ByteStream.from_file_path(path)
        bs.meta["file_path"] = path
        results = comp.transcribe(sources=[bs])
        expected = Document(
            content="test transcription", meta={"audio_file": path, "other_metadata": ["other", "meta", "data"]}
        )
        assert results == [expected]

    @pytest.mark.integration
    @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI")
    def test_whisper_local_transcriber(self, test_files_path):
        """End-to-end transcription of a ``Path``, a ``str`` path and a ``ByteStream`` source.

        :param test_files_path: pytest fixture supplying the directory of sample files.
        """
        comp = LocalWhisperTranscriber(model="tiny", whisper_params={"language": "english"})
        comp.warm_up()
        output = comp.run(
            sources=[
                test_files_path / "audio" / "this is the content of the document.wav",
                str((test_files_path / "audio" / "the context for this answer is here.wav").absolute()),
                # No second positional argument here: ByteStream.from_file_path takes a
                # MIME type in that position, not an open() mode such as "rb".
                ByteStream.from_file_path(test_files_path / "audio" / "answer.wav"),
            ]
        )
        docs = output["documents"]
        assert len(docs) == 3

        assert all(
            word in docs[0].content.strip().lower() for word in {"content", "the", "document"}
        ), f"Expected words not found in: {docs[0].content.strip().lower()}"
        assert test_files_path / "audio" / "this is the content of the document.wav" == docs[0].meta["audio_file"]

        assert all(
            word in docs[1].content.strip().lower() for word in {"context", "answer"}
        ), f"Expected words not found in: {docs[1].content.strip().lower()}"
        path = test_files_path / "audio" / "the context for this answer is here.wav"
        assert path.absolute() == docs[1].meta["audio_file"]

        assert docs[2].content.strip().lower() == "answer."
        # meta.audio_file should contain the temp path where we dumped the audio bytes
        assert docs[2].meta["audio_file"]

    @pytest.mark.integration
    @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI")
    def test_whisper_local_transcriber_pipeline_and_url_source(self):
        """A fetcher -> transcriber pipeline transcribes audio downloaded from a URL."""
        pipe = Pipeline()
        pipe.add_component("fetcher", LinkContentFetcher())
        pipe.add_component("transcriber", LocalWhisperTranscriber(model="tiny"))

        pipe.connect("fetcher", "transcriber")
        result = pipe.run(
            data={
                "fetcher": {
                    "urls": ["https://ia903102.us.archive.org/19/items/100-Best--Speeches/EK_19690725_64kb.mp3"]
                }
            }
        )
        assert "Massachusetts" in result["transcriber"]["documents"][0].content