haystack/test/components/audio/test_whisper_remote.py

255 lines
10 KiB
Python
Raw Normal View History

import os
from unittest.mock import patch
from pathlib import Path
import openai
import pytest
from openai.util import convert_to_openai_object
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.preview.dataclasses import ByteStream
def mock_openai_response(response_format="json", **kwargs) -> openai.openai_object.OpenAIObject:
    """
    Build a canned OpenAI response object for the mocked transcription call.

    :param response_format: Requested response format. Only ``"json"`` produces a
        payload (``{"text": "test transcription"}``); any other value produces an
        empty payload.
    :param kwargs: Accepted and ignored, so the function can stand in for a call
        that receives additional keyword arguments.
    :return: The payload converted with ``convert_to_openai_object``.
    """
    # Currently only "json" is supported.
    payload = {"text": "test transcription"} if response_format == "json" else {}
    return convert_to_openai_object(payload)
class TestRemoteWhisperTranscriber:
    """
    Tests for ``RemoteWhisperTranscriber``.

    Covers API key resolution (explicit argument, ``openai.api_key`` module global,
    ``OPENAI_API_KEY`` environment variable), init parameters, ``to_dict`` /
    ``from_dict`` serialization, and ``run`` on ``str``, ``Path`` and ``ByteStream``
    sources (mocked for unit tests, real API call for the integration test).
    """

    @pytest.fixture(autouse=True)
    def _reset_openai_api_key(self, monkeypatch):
        """
        Start every test without a module-level OpenAI key and restore it afterwards.

        Several tests assert on ``openai.api_key`` after building the component, and the
        module global takes preference over the environment variable. Without this reset
        a key left behind by a previous test would make those tests order-dependent.
        ``monkeypatch`` undoes the change (and any later reassignment) on teardown.
        """
        monkeypatch.setattr(openai, "api_key", None)

    @pytest.mark.unit
    def test_init_no_key(self, monkeypatch):
        """Without an explicit key, a module global or an env var, init raises ``ValueError``."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        error_msg = "RemoteWhisperTranscriber expects an OpenAI API key."
        with pytest.raises(ValueError, match=error_msg):
            RemoteWhisperTranscriber(api_key=None)

    @pytest.mark.unit
    def test_init_key_env_var(self, monkeypatch):
        """With no explicit key and no module global, the key is read from ``OPENAI_API_KEY``."""
        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
        RemoteWhisperTranscriber(api_key=None)
        assert openai.api_key == "test_api_key"

    @pytest.mark.unit
    def test_init_key_module_env_and_global_var(self, monkeypatch):
        """When both the env var and the module global are set, the module global wins."""
        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key_2")
        monkeypatch.setattr(openai, "api_key", "test_api_key_1")
        RemoteWhisperTranscriber(api_key=None)
        # The module global variable takes preference
        assert openai.api_key == "test_api_key_1"

    @pytest.mark.unit
    def test_init_default(self):
        """An explicit key is stored on the ``openai`` module and defaults are applied."""
        transcriber = RemoteWhisperTranscriber(api_key="test_api_key")

        assert openai.api_key == "test_api_key"
        assert transcriber.model_name == "whisper-1"
        assert transcriber.organization is None
        assert transcriber.api_base_url == "https://api.openai.com/v1"
        assert transcriber.whisper_params == {"response_format": "json"}

    @pytest.mark.unit
    def test_init_custom_parameters(self):
        """Custom init arguments are exposed as attributes and collected in ``whisper_params``."""
        transcriber = RemoteWhisperTranscriber(
            api_key="test_api_key",
            model_name="whisper-1",
            organization="test-org",
            api_base_url="test_api_url",
            language="en",
            prompt="test-prompt",
            response_format="json",
            temperature="0.5",
        )

        assert openai.api_key == "test_api_key"
        assert transcriber.model_name == "whisper-1"
        assert transcriber.organization == "test-org"
        assert transcriber.api_base_url == "test_api_url"
        assert transcriber.whisper_params == {
            "language": "en",
            "prompt": "test-prompt",
            "response_format": "json",
            "temperature": "0.5",
        }

    @pytest.mark.unit
    def test_to_dict_default_parameters(self):
        """``to_dict`` with defaults serializes the init parameters and omits the API key."""
        transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
        data = transcriber.to_dict()
        assert data == {
            "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "api_base_url": "https://api.openai.com/v1",
                "organization": None,
                "response_format": "json",
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        """``to_dict`` with custom arguments serializes all of them and omits the API key."""
        transcriber = RemoteWhisperTranscriber(
            api_key="test_api_key",
            model_name="whisper-1",
            organization="test-org",
            api_base_url="test_api_url",
            language="en",
            prompt="test-prompt",
            response_format="json",
            temperature="0.5",
        )
        data = transcriber.to_dict()
        assert data == {
            "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "organization": "test-org",
                "api_base_url": "test_api_url",
                "language": "en",
                "prompt": "test-prompt",
                "response_format": "json",
                "temperature": "0.5",
            },
        }

    @pytest.mark.unit
    def test_from_dict_with_defualt_parameters(self, monkeypatch):
        """``from_dict`` with default parameters picks the API key up from the env var."""
        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")

        data = {
            "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "api_base_url": "https://api.openai.com/v1",
                "organization": None,
                "response_format": "json",
            },
        }

        transcriber = RemoteWhisperTranscriber.from_dict(data)

        assert openai.api_key == "test_api_key"
        assert transcriber.model_name == "whisper-1"
        assert transcriber.organization is None
        assert transcriber.api_base_url == "https://api.openai.com/v1"
        assert transcriber.whisper_params == {"response_format": "json"}

    @pytest.mark.unit
    def test_from_dict_with_custom_init_parameters(self, monkeypatch):
        """``from_dict`` with custom parameters restores them and reads the key from the env var."""
        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")

        data = {
            "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "organization": "test-org",
                "api_base_url": "test_api_url",
                "language": "en",
                "prompt": "test-prompt",
                "response_format": "json",
                "temperature": "0.5",
            },
        }
        transcriber = RemoteWhisperTranscriber.from_dict(data)

        assert openai.api_key == "test_api_key"
        assert transcriber.model_name == "whisper-1"
        assert transcriber.organization == "test-org"
        assert transcriber.api_base_url == "test_api_url"
        assert transcriber.whisper_params == {
            "language": "en",
            "prompt": "test-prompt",
            "response_format": "json",
            "temperature": "0.5",
        }

    @pytest.mark.unit
    def test_from_dict_with_defualt_parameters_no_env_var(self, monkeypatch):
        """``from_dict`` raises ``ValueError`` when no API key is available anywhere."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)

        data = {
            "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "api_base_url": "https://api.openai.com/v1",
                "organization": None,
                "response_format": "json",
            },
        }

        with pytest.raises(ValueError, match="RemoteWhisperTranscriber expects an OpenAI API key."):
            RemoteWhisperTranscriber.from_dict(data)

    @pytest.mark.unit
    def test_run_str(self, preview_samples_path):
        """``run`` accepts a ``str`` path and records it unchanged in ``meta["file_path"]``."""
        with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
            model = "whisper-1"
            file_path = str(preview_samples_path / "audio" / "this is the content of the document.wav")
            # Route the mocked API call to the canned response builder.
            openai_audio_patch.transcribe.side_effect = mock_openai_response

            transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
            result = transcriber.run(sources=[file_path])

            assert result["documents"][0].content == "test transcription"
            assert result["documents"][0].meta["file_path"] == file_path

    @pytest.mark.unit
    def test_run_path(self, preview_samples_path):
        """``run`` accepts a path object and records it unchanged in ``meta["file_path"]``."""
        with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
            model = "whisper-1"
            file_path = preview_samples_path / "audio" / "this is the content of the document.wav"
            # Route the mocked API call to the canned response builder.
            openai_audio_patch.transcribe.side_effect = mock_openai_response

            transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
            result = transcriber.run(sources=[file_path])

            assert result["documents"][0].content == "test transcription"
            assert result["documents"][0].meta["file_path"] == file_path

    @pytest.mark.unit
    def test_run_bytestream(self, preview_samples_path):
        """``run`` accepts a ``ByteStream`` and propagates its ``file_path`` metadata."""
        with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
            model = "whisper-1"
            file_path = preview_samples_path / "audio" / "this is the content of the document.wav"
            # Route the mocked API call to the canned response builder.
            openai_audio_patch.transcribe.side_effect = mock_openai_response

            transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
            with open(file_path, "rb") as audio_stream:
                byte_stream = audio_stream.read()
                audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})

                result = transcriber.run(sources=[audio_file])

                assert result["documents"][0].content == "test transcription"
                assert result["documents"][0].meta["file_path"] == str(file_path.absolute())

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_whisper_remote_transcriber(self, preview_samples_path):
        """Transcribe three sample files (path object, ``str`` and ``ByteStream``) with the real API."""
        transcriber = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))

        paths = [
            preview_samples_path / "audio" / "this is the content of the document.wav",
            str(preview_samples_path / "audio" / "the context for this answer is here.wav"),
            ByteStream.from_file_path(preview_samples_path / "audio" / "answer.wav"),
        ]

        output = transcriber.run(sources=paths)

        docs = output["documents"]
        assert len(docs) == 3
        assert docs[0].content.strip().lower() == "this is the content of the document."
        assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].meta["file_path"]

        assert docs[1].content.strip().lower() == "the context for this answer is here."
        assert (
            str(preview_samples_path / "audio" / "the context for this answer is here.wav") == docs[1].meta["file_path"]
        )

        assert docs[2].content.strip().lower() == "answer."