2023-09-20 14:48:09 +02:00
|
|
|
import os
|
2023-05-22 16:02:58 +02:00
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from haystack.preview.dataclasses import Document
|
|
|
|
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber, OPENAI_TIMEOUT
|
|
|
|
|
|
|
|
|
2023-08-23 17:03:37 +02:00
|
|
|
class TestRemoteWhisperTranscriber:
    """
    Tests for ``RemoteWhisperTranscriber``.

    Every unit test below replaces the outgoing HTTP call with a mock. Either
    ``haystack.preview.utils.requests_utils.requests`` or
    ``haystack.preview.components.audio.whisper_remote.request_with_retry`` is
    patched, so no network traffic happens.
    """

    # JSON payload returned by every mocked OpenAI endpoint in this class.
    _FAKE_RESPONSE_BODY = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
    # Sample clip, relative to ``<preview_samples_path>/audio``, that the unit tests upload.
    _SAMPLE_WAV = "this is the content of the document.wav"

    @classmethod
    def _ok_response(cls) -> MagicMock:
        """Build a fake HTTP 200 response whose ``content`` is ``_FAKE_RESPONSE_BODY``."""
        response = MagicMock()
        response.status_code = 200
        response.content = cls._FAKE_RESPONSE_BODY
        return response

    @classmethod
    def _sample_wav(cls, samples_root):
        """Return the path of the sample clip below ``samples_root / "audio"``."""
        return samples_root / "audio" / cls._SAMPLE_WAV

    @staticmethod
    def _sent_kwargs_without_files(mocked_requests):
        """
        Return the keyword arguments of the last ``requests.request`` call, minus ``files``.

        Raises ``KeyError`` if no ``files`` argument was sent.
        """
        sent = mocked_requests.request.call_args.kwargs
        del sent["files"]
        return sent

    @pytest.mark.unit
    def test_init_unknown_model(self):
        with pytest.raises(ValueError, match="not recognized"):
            RemoteWhisperTranscriber(model_name="anything", api_key="something")

    @pytest.mark.unit
    def test_init_default(self):
        transcriber = RemoteWhisperTranscriber(api_key="just a test")

        assert transcriber.model_name == "whisper-1"
        assert transcriber.api_key == "just a test"
        assert transcriber.api_base == "https://api.openai.com/v1"

    @pytest.mark.unit
    def test_init_no_key(self):
        with pytest.raises(ValueError, match="API key is None"):
            RemoteWhisperTranscriber(api_key=None)

    @pytest.mark.unit
    def test_to_dict(self):
        expected_init_parameters = {
            "model_name": "whisper-1",
            "api_key": "test",
            "api_base": "https://api.openai.com/v1",
            "whisper_params": {},
        }

        serialized = RemoteWhisperTranscriber(api_key="test").to_dict()

        assert serialized == {"type": "RemoteWhisperTranscriber", "init_parameters": expected_init_parameters}

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        transcriber = RemoteWhisperTranscriber(
            model_name="whisper-1",
            api_key="test",
            api_base="https://my.api.base/something_else/v3",
            whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
        )
        expected_init_parameters = {
            "model_name": "whisper-1",
            "api_key": "test",
            "api_base": "https://my.api.base/something_else/v3",
            "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
        }

        serialized = transcriber.to_dict()

        assert serialized == {"type": "RemoteWhisperTranscriber", "init_parameters": expected_init_parameters}

    @pytest.mark.unit
    def test_from_dict(self):
        serialized = {
            "type": "RemoteWhisperTranscriber",
            "init_parameters": {
                "model_name": "whisper-1",
                "api_key": "test",
                "api_base": "https://my.api.base/something_else/v3",
                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
            },
        }

        restored = RemoteWhisperTranscriber.from_dict(serialized)

        assert restored.model_name == "whisper-1"
        assert restored.api_key == "test"
        assert restored.api_base == "https://my.api.base/something_else/v3"
        assert restored.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}

    @pytest.mark.unit
    def test_run_with_path(self, preview_samples_path):
        audio_path = self._sample_wav(preview_samples_path)
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            result = comp.run(audio_files=[audio_path])

        expected = Document(
            text="test transcription",
            metadata={"audio_file": audio_path, "other_metadata": ["other", "meta", "data"]},
        )
        assert result["documents"] == [expected]

    @pytest.mark.unit
    def test_run_with_str(self, preview_samples_path):
        audio_file = str(self._sample_wav(preview_samples_path).absolute())
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            result = comp.run(audio_files=[audio_file])

        expected = Document(
            text="test transcription",
            metadata={"audio_file": audio_file, "other_metadata": ["other", "meta", "data"]},
        )
        assert result["documents"] == [expected]

    @pytest.mark.unit
    def test_transcribe_with_stream(self, preview_samples_path):
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests, open(
            self._sample_wav(preview_samples_path), "rb"
        ) as audio_stream:
            mocked_requests.request.return_value = self._ok_response()
            result = comp.transcribe(audio_files=[audio_stream])

        # Streams have no path, so the component records a placeholder instead.
        expected = Document(
            text="test transcription",
            metadata={"audio_file": "<<binary stream>>", "other_metadata": ["other", "meta", "data"]},
        )
        assert result == [expected]

    @pytest.mark.unit
    def test_api_transcription(self, preview_samples_path):
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            comp.run(audio_files=[self._sample_wav(preview_samples_path)])

        # The uploaded file handle is not comparable, so it is dropped before the check.
        assert self._sent_kwargs_without_files(mocked_requests) == {
            "method": "post",
            "url": "https://api.openai.com/v1/audio/transcriptions",
            "data": {"model": "whisper-1"},
            "headers": {"Authorization": "Bearer whatever"},
            "timeout": OPENAI_TIMEOUT,
        }

    @pytest.mark.unit
    def test_api_translation(self, preview_samples_path):
        comp = RemoteWhisperTranscriber(api_key="whatever")

        with patch("haystack.preview.utils.requests_utils.requests") as mocked_requests:
            mocked_requests.request.return_value = self._ok_response()
            comp.run(audio_files=[self._sample_wav(preview_samples_path)], whisper_params={"translate": True})

        # ``translate=True`` must route the call to the translations endpoint.
        assert self._sent_kwargs_without_files(mocked_requests) == {
            "method": "post",
            "url": "https://api.openai.com/v1/audio/translations",
            "data": {"model": "whisper-1"},
            "headers": {"Authorization": "Bearer whatever"},
            "timeout": OPENAI_TIMEOUT,
        }

    @pytest.mark.unit
    @patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
    def test_default_api_base(self, mock_request, preview_samples_path):
        mock_request.return_value = self._ok_response()

        transcriber = RemoteWhisperTranscriber(api_key="just a test")
        assert transcriber.api_base == "https://api.openai.com/v1"

        transcriber.transcribe(audio_files=[self._sample_wav(preview_samples_path)])
        assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/audio/transcriptions"

    @pytest.mark.unit
    @patch("haystack.preview.components.audio.whisper_remote.request_with_retry")
    def test_custom_api_base(self, mock_request, preview_samples_path):
        mock_request.return_value = self._ok_response()

        transcriber = RemoteWhisperTranscriber(api_key="just a test", api_base="https://fake_api_base.com")
        assert transcriber.api_base == "https://fake_api_base.com"

        transcriber.transcribe(audio_files=[self._sample_wav(preview_samples_path)])
        assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/audio/transcriptions"
|
2023-09-20 14:48:09 +02:00
|
|
|
|
|
|
|
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_whisper_remote_transcriber(preview_samples_path):
    """
    Transcribe three sample clips through the real OpenAI API.

    One clip is passed in each input form the component accepts here: a
    ``Path``, an absolute path ``str`` and an open binary stream. The test
    checks each transcript and the recorded ``audio_file`` metadata.
    """
    comp = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))

    audio_dir = preview_samples_path / "audio"
    path_input = audio_dir / "this is the content of the document.wav"
    str_input = str((audio_dir / "the context for this answer is here.wav").absolute())

    # The stream used to be a bare ``open(...)`` that was never closed. Scoping it with
    # ``with`` releases the handle even if the API call raises. The returned metadata holds
    # the "<<binary stream>>" placeholder, not the stream, so closing after ``run`` is safe.
    with open(audio_dir / "answer.wav", "rb") as stream_input:
        output = comp.run(audio_files=[path_input, str_input, stream_input])

    docs = output["documents"]
    assert len(docs) == 3

    assert docs[0].text.strip().lower() == "this is the content of the document."
    assert path_input == docs[0].metadata["audio_file"]

    assert docs[1].text.strip().lower() == "the context for this answer is here."
    assert str_input == docs[1].metadata["audio_file"]

    assert docs[2].text.strip().lower() == "answer."
    assert docs[2].metadata["audio_file"] == "<<binary stream>>"
|