# haystack/test/preview/components/audio/test_whisper_remote.py
#
# (code-hosting page residue: 246 lines · 11 KiB · Python · Raw / Normal View / History)

import os
from unittest.mock import MagicMock, patch
import pytest
from haystack.preview.dataclasses import Document
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber, OPENAI_TIMEOUT
class TestRemoteWhisperTranscriber:
    """Unit tests for ``RemoteWhisperTranscriber``.

    Every test that triggers a transcription patches the HTTP layer, either the
    ``requests`` name inside ``haystack.preview.utils.requests_utils`` or the
    ``request_with_retry`` name imported by the component module, and feeds it
    the canned response built by ``_ok_response``.
    """

    # Dotted path of the ``requests`` module as seen by the request helper.
    _REQUESTS_TARGET = "haystack.preview.utils.requests_utils.requests"
    # JSON body returned by every mocked API call.
    _RESPONSE_BODY = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
    # Audio sample (relative to ``<samples>/audio``) used by the mocked runs.
    _SAMPLE_NAME = "this is the content of the document.wav"

    def _ok_response(self):
        """Return a mock HTTP response with status 200 and the canned JSON body."""
        response = MagicMock()
        response.status_code = 200
        response.content = self._RESPONSE_BODY
        return response

    def _sample_path(self, preview_samples_path):
        """Return the path of the audio sample under ``preview_samples_path``."""
        return preview_samples_path / "audio" / self._SAMPLE_NAME

    @staticmethod
    def _sent_kwargs_without_files(mocked_requests):
        """Return the keyword arguments of the last mocked request, minus ``files``.

        The recorded call is left untouched: ``files`` must be present (asserted),
        but it is filtered out of a copy instead of being popped from the mock's
        own ``call_args``.
        """
        sent = mocked_requests.request.call_args.kwargs
        assert "files" in sent
        return {key: value for key, value in sent.items() if key != "files"}

    @pytest.mark.unit
    def test_init_unknown_model(self):
        """An unrecognized model name is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="not recognized"):
            RemoteWhisperTranscriber(model_name="anything", api_key="something")

    @pytest.mark.unit
    def test_init_default(self):
        """Only passing an API key yields the default model name and API base."""
        transcriber = RemoteWhisperTranscriber(api_key="just a test")
        assert transcriber.model_name == "whisper-1"
        assert transcriber.api_key == "just a test"
        assert transcriber.api_base == "https://api.openai.com/v1"

    @pytest.mark.unit
    def test_init_no_key(self):
        """Passing ``api_key=None`` is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="API key is None"):
            RemoteWhisperTranscriber(api_key=None)

    @pytest.mark.unit
    def test_to_dict(self):
        """``to_dict`` on a default-configured component reports the default init parameters."""
        transcriber = RemoteWhisperTranscriber(api_key="test")
        data = transcriber.to_dict()
        assert data == {
            "type": "RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "api_key": "test",
                "api_base": "https://api.openai.com/v1",
                "whisper_params": {},
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        """``to_dict`` reports custom ``api_base`` and ``whisper_params`` unchanged."""
        transcriber = RemoteWhisperTranscriber(
            api_key="test",
            model_name="whisper-1",
            api_base="https://my.api.base/something_else/v3",
            whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
        )
        data = transcriber.to_dict()
        assert data == {
            "type": "RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "api_key": "test",
                "api_base": "https://my.api.base/something_else/v3",
                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
            },
        }

    @pytest.mark.unit
    def test_from_dict(self):
        """``from_dict`` rebuilds a component whose attributes match the serialized parameters."""
        data = {
            "type": "RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "api_key": "test",
                "api_base": "https://my.api.base/something_else/v3",
                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
            },
        }
        transcriber = RemoteWhisperTranscriber.from_dict(data)
        assert transcriber.model_name == "whisper-1"
        assert transcriber.api_key == "test"
        assert transcriber.api_base == "https://my.api.base/something_else/v3"
        assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}

    @pytest.mark.unit
    def test_run_with_path(self, preview_samples_path):
        """``run`` given a path object returns one document whose metadata keeps that path object."""
        audio_path = self._sample_path(preview_samples_path)
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch(self._REQUESTS_TARGET) as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            result = comp.run(audio_files=[audio_path])

        expected = Document(
            text="test transcription",
            metadata={"audio_file": audio_path, "other_metadata": ["other", "meta", "data"]},
        )
        assert result["documents"] == [expected]

    @pytest.mark.unit
    def test_run_with_str(self, preview_samples_path):
        """``run`` given an absolute path string returns one document whose metadata keeps that string."""
        audio_path = str(self._sample_path(preview_samples_path).absolute())
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch(self._REQUESTS_TARGET) as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            result = comp.run(audio_files=[audio_path])

        expected = Document(
            text="test transcription",
            metadata={"audio_file": audio_path, "other_metadata": ["other", "meta", "data"]},
        )
        assert result["documents"] == [expected]

    @pytest.mark.unit
    def test_transcribe_with_stream(self, preview_samples_path):
        """``transcribe`` given an open binary stream labels the source as ``<<binary stream>>``."""
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch(self._REQUESTS_TARGET) as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            with open(self._sample_path(preview_samples_path), "rb") as audio_stream:
                result = comp.transcribe(audio_files=[audio_stream])

        expected = Document(
            text="test transcription",
            metadata={"audio_file": "<<binary stream>>", "other_metadata": ["other", "meta", "data"]},
        )
        assert result == [expected]

    @pytest.mark.unit
    def test_api_transcription(self, preview_samples_path):
        """A default ``run`` posts to the ``audio/transcriptions`` endpoint with the expected arguments."""
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch(self._REQUESTS_TARGET) as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            comp.run(audio_files=[self._sample_path(preview_samples_path)])
            requests_params = self._sent_kwargs_without_files(mocked_requests)

        assert requests_params == {
            "method": "post",
            "url": "https://api.openai.com/v1/audio/transcriptions",
            "data": {"model": "whisper-1"},
            "headers": {"Authorization": "Bearer whatever"},
            "timeout": OPENAI_TIMEOUT,
        }

    @pytest.mark.unit
    def test_api_translation(self, preview_samples_path):
        """``whisper_params={"translate": True}`` switches the request to ``audio/translations``."""
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch(self._REQUESTS_TARGET) as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            comp.run(
                audio_files=[self._sample_path(preview_samples_path)],
                whisper_params={"translate": True},
            )
            requests_params = self._sent_kwargs_without_files(mocked_requests)

        assert requests_params == {
            "method": "post",
            "url": "https://api.openai.com/v1/audio/translations",
            "data": {"model": "whisper-1"},
            "headers": {"Authorization": "Bearer whatever"},
            "timeout": OPENAI_TIMEOUT,
        }

    @pytest.mark.unit
    @patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
    def test_default_api_base(self, mock_request, preview_samples_path):
        """With the default API base, ``transcribe`` targets the OpenAI transcriptions URL."""
        mock_request.return_value = self._ok_response()

        transcriber = RemoteWhisperTranscriber(api_key="just a test")
        assert transcriber.api_base == "https://api.openai.com/v1"

        transcriber.transcribe(audio_files=[self._sample_path(preview_samples_path)])
        assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/audio/transcriptions"

    @pytest.mark.unit
    @patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
    def test_custom_api_base(self, mock_request, preview_samples_path):
        """With a custom API base, ``transcribe`` builds the request URL on top of that base."""
        mock_request.return_value = self._ok_response()

        transcriber = RemoteWhisperTranscriber(api_key="just a test", api_base="https://fake_api_base.com")
        assert transcriber.api_base == "https://fake_api_base.com"

        transcriber.transcribe(audio_files=[self._sample_path(preview_samples_path)])
        assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/audio/transcriptions"


@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_whisper_remote_transcriber(preview_samples_path):
    """Transcribe three audio samples, unmocked, given as a path, a string and a binary stream.

    The API key is read from the ``OPENAI_API_KEY`` environment variable; the
    test is skipped when that variable is unset or empty. Each returned
    document must carry the expected text and, in its metadata, the original
    path object, the original string, or ``"<<binary stream>>"`` respectively.
    """
    audio_dir = preview_samples_path / "audio"
    path_input = audio_dir / "this is the content of the document.wav"
    str_input = str((audio_dir / "the context for this answer is here.wav").absolute())

    comp = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))

    # Open the stream in a context manager so the file handle is closed even
    # when run() raises; a bare open() inside the list literal is never closed.
    with open(audio_dir / "answer.wav", "rb") as stream_input:
        output = comp.run(audio_files=[path_input, str_input, stream_input])

    docs = output["documents"]
    assert len(docs) == 3

    assert docs[0].text.strip().lower() == "this is the content of the document."
    assert path_input == docs[0].metadata["audio_file"]

    assert docs[1].text.strip().lower() == "the context for this answer is here."
    assert str_input == docs[1].metadata["audio_file"]

    assert docs[2].text.strip().lower() == "answer."
    assert docs[2].metadata["audio_file"] == "<<binary stream>>"