mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-18 03:18:42 +00:00
Enable Opensearch unit tests in Windows CI (#2936)
* enable Opensearch unit tests under Win
* move unit tests into a dedicated job
* skip audio tests on missing dependencies
* avoid failing test collection when soundfile is not available
* Update .github/workflows/tests.yml

Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
This commit is contained in:
parent
1b238c880b
commit
40d07c2038
22
.github/workflows/tests.yml
vendored
22
.github/workflows/tests.yml
vendored
@ -75,6 +75,27 @@ jobs:
|
||||
run: |
|
||||
pylint -ry -j 0 haystack/ rest_api/ ui/
|
||||
|
||||
unit-tests:
|
||||
name: Unit / ${{ matrix.os }}
|
||||
needs:
|
||||
- mypy
|
||||
- pylint
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: ./.github/actions/python_cache/
|
||||
|
||||
- name: Install Haystack
|
||||
run: pip install .[all]
|
||||
|
||||
- name: Run
|
||||
run: pytest -m "unit" test/
|
||||
|
||||
unit-tests-linux:
|
||||
needs:
|
||||
@ -88,7 +109,6 @@ jobs:
|
||||
- "pipelines"
|
||||
- "modeling"
|
||||
- "others"
|
||||
- "document_stores/test_opensearch.py"
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
|
@ -90,6 +90,7 @@ minversion = "6.0"
|
||||
addopts = "--strict-markers"
|
||||
markers = [
|
||||
"integration: integration tests",
|
||||
"unit: unit tests",
|
||||
|
||||
"generator: generator tests",
|
||||
"summarizer: summarizer tests",
|
||||
|
@ -18,14 +18,9 @@ from haystack.document_stores.opensearch import (
|
||||
from haystack.schema import Document, Label, Answer
|
||||
from haystack.errors import DocumentStoreError
|
||||
|
||||
|
||||
# Skip OpenSearchDocumentStore tests on Windows
|
||||
pytestmark = pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Opensearch not running on Windows CI")
|
||||
|
||||
# Being all the tests in this module, ideally we wouldn't need a marker here,
|
||||
# but this is to allow this test suite to be skipped when running (e.g.)
|
||||
# `pytest test/document_stores --document-store-type=faiss`
|
||||
@pytest.mark.opensearch
|
||||
class TestOpenSearchDocumentStore:
|
||||
|
||||
# Constants
|
||||
@ -210,6 +205,7 @@ class TestOpenSearchDocumentStore:
|
||||
|
||||
# Unit tests
|
||||
|
||||
@pytest.mark.unit
|
||||
def test___init___api_key_raises_warning(self, mocked_document_store, caplog):
|
||||
with caplog.at_level(logging.WARN, logger="haystack.document_stores.opensearch"):
|
||||
mocked_document_store.__init__(api_key="foo")
|
||||
@ -220,6 +216,7 @@ class TestOpenSearchDocumentStore:
|
||||
for r in caplog.records:
|
||||
assert r.levelname == "WARNING"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test___init___connection_test_fails(self, mocked_document_store):
|
||||
failing_client = MagicMock()
|
||||
failing_client.indices.get.side_effect = Exception("The client failed!")
|
||||
@ -227,6 +224,7 @@ class TestOpenSearchDocumentStore:
|
||||
with pytest.raises(ConnectionError):
|
||||
mocked_document_store.__init__()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test___init___client_params(self, mocked_open_search_init, _init_client_params):
|
||||
"""
|
||||
Ensure the Opensearch-py client was initialized with the right params
|
||||
@ -244,18 +242,21 @@ class TestOpenSearchDocumentStore:
|
||||
"connection_class": RequestsHttpConnection,
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__init_client_use_system_proxy_use_sys_proxy(self, mocked_open_search_init, _init_client_params):
|
||||
_init_client_params["use_system_proxy"] = False
|
||||
OpenSearchDocumentStore._init_client(**_init_client_params)
|
||||
_, kwargs = mocked_open_search_init.call_args
|
||||
assert kwargs["connection_class"] == Urllib3HttpConnection
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__init_client_use_system_proxy_dont_use_sys_proxy(self, mocked_open_search_init, _init_client_params):
|
||||
_init_client_params["use_system_proxy"] = True
|
||||
OpenSearchDocumentStore._init_client(**_init_client_params)
|
||||
_, kwargs = mocked_open_search_init.call_args
|
||||
assert kwargs["connection_class"] == RequestsHttpConnection
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__init_client_auth_methods_username_password(self, mocked_open_search_init, _init_client_params):
|
||||
_init_client_params["username"] = "user"
|
||||
_init_client_params["aws4auth"] = None
|
||||
@ -263,6 +264,7 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_open_search_init.call_args
|
||||
assert kwargs["http_auth"] == ("user", "pass")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__init_client_auth_methods_aws_iam(self, mocked_open_search_init, _init_client_params):
|
||||
_init_client_params["username"] = ""
|
||||
_init_client_params["aws4auth"] = "foo"
|
||||
@ -270,6 +272,7 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_open_search_init.call_args
|
||||
assert kwargs["http_auth"] == "foo"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__init_client_auth_methods_no_auth(self, mocked_open_search_init, _init_client_params):
|
||||
_init_client_params["username"] = ""
|
||||
_init_client_params["aws4auth"] = None
|
||||
@ -277,11 +280,13 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_open_search_init.call_args
|
||||
assert "http_auth" not in kwargs
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_by_embedding_raises_if_missing_field(self, mocked_document_store):
|
||||
mocked_document_store.embedding_field = ""
|
||||
with pytest.raises(DocumentStoreError):
|
||||
mocked_document_store.query_by_embedding(self.query_emb)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_by_embedding_filters(self, mocked_document_store):
|
||||
expected_filters = {"type": "article", "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}}
|
||||
mocked_document_store.query_by_embedding(self.query_emb, filters=expected_filters)
|
||||
@ -293,6 +298,7 @@ class TestOpenSearchDocumentStore:
|
||||
{"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}},
|
||||
]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_by_embedding_return_embedding_false(self, mocked_document_store):
|
||||
mocked_document_store.return_embedding = False
|
||||
mocked_document_store.query_by_embedding(self.query_emb)
|
||||
@ -300,6 +306,7 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_by_embedding_excluded_meta_data_return_embedding_true(self, mocked_document_store):
|
||||
"""
|
||||
Test that when `return_embedding==True` the field should NOT be excluded even if it
|
||||
@ -312,6 +319,7 @@ class TestOpenSearchDocumentStore:
|
||||
# we expect "embedding" was removed from the final query
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["foo"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_query_by_embedding_excluded_meta_data_return_embedding_false(self, mocked_document_store):
|
||||
"""
|
||||
Test that when `return_embedding==False`, the final query excludes the `embedding` field
|
||||
@ -324,6 +332,7 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_document_store.client.search.call_args
|
||||
assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_alias(self, mocked_document_store, caplog):
|
||||
mocked_document_store.client.indices.exists_alias.return_value = True
|
||||
|
||||
@ -332,6 +341,7 @@ class TestOpenSearchDocumentStore:
|
||||
|
||||
assert f"Index name {self.index_name} is an alias." in caplog.text
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_wrong_mapping_raises(self, mocked_document_store, index):
|
||||
"""
|
||||
Ensure the method raises if we specify a field in `search_fields` that's not text
|
||||
@ -342,6 +352,7 @@ class TestOpenSearchDocumentStore:
|
||||
with pytest.raises(Exception, match=f"The search_field 'age' of index '{self.index_name}' with type 'integer'"):
|
||||
mocked_document_store._create_document_index(self.index_name)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_create_mapping_if_missing(self, mocked_document_store, index):
|
||||
mocked_document_store.client.indices.exists.return_value = True
|
||||
mocked_document_store.client.indices.get.return_value = {self.index_name: index}
|
||||
@ -354,6 +365,7 @@ class TestOpenSearchDocumentStore:
|
||||
assert kwargs["index"] == self.index_name
|
||||
assert "doesnt_have_a_mapping" in kwargs["body"]["properties"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_bad_field_raises(self, mocked_document_store, index):
|
||||
mocked_document_store.client.indices.exists.return_value = True
|
||||
mocked_document_store.client.indices.get.return_value = {self.index_name: index}
|
||||
@ -364,6 +376,7 @@ class TestOpenSearchDocumentStore:
|
||||
):
|
||||
mocked_document_store._create_document_index(self.index_name)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_existing_mapping_but_no_method(self, mocked_document_store, index):
|
||||
"""
|
||||
We call the method passing a properly mapped field but without the `method` specified in the mapping
|
||||
@ -381,6 +394,7 @@ class TestOpenSearchDocumentStore:
|
||||
# False but I'm not sure this is by design
|
||||
assert mocked_document_store.embeddings_field_supports_similarity is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_existing_mapping_similarity(self, mocked_document_store, index):
|
||||
mocked_document_store.client.indices.exists.return_value = True
|
||||
mocked_document_store.client.indices.get.return_value = {self.index_name: index}
|
||||
@ -390,6 +404,7 @@ class TestOpenSearchDocumentStore:
|
||||
mocked_document_store._create_document_index(self.index_name)
|
||||
assert mocked_document_store.embeddings_field_supports_similarity is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_existing_mapping_similarity_mismatch(
|
||||
self, mocked_document_store, index, caplog
|
||||
):
|
||||
@ -403,6 +418,7 @@ class TestOpenSearchDocumentStore:
|
||||
assert "Embedding field 'vec' is optimized for similarity 'dot_product'." in caplog.text
|
||||
assert mocked_document_store.embeddings_field_supports_similarity is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_existing_mapping_adjust_params_hnsw_default(
|
||||
self, mocked_document_store, index
|
||||
):
|
||||
@ -420,6 +436,7 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_document_store.client.indices.put_settings.call_args
|
||||
assert kwargs["body"] == {"knn.algo_param.ef_search": 20}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_existing_mapping_adjust_params_hnsw(self, mocked_document_store, index):
|
||||
"""
|
||||
Test a value of `knn.algo_param` that needs to be adjusted
|
||||
@ -436,6 +453,7 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_document_store.client.indices.put_settings.call_args
|
||||
assert kwargs["body"] == {"knn.algo_param.ef_search": 20}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_existing_mapping_adjust_params_flat_default(
|
||||
self, mocked_document_store, index
|
||||
):
|
||||
@ -451,6 +469,7 @@ class TestOpenSearchDocumentStore:
|
||||
|
||||
mocked_document_store.client.indices.put_settings.assert_not_called
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_with_existing_mapping_adjust_params_hnsw(self, mocked_document_store, index):
|
||||
"""
|
||||
Test a value of `knn.algo_param` that needs to be adjusted
|
||||
@ -467,6 +486,7 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_document_store.client.indices.put_settings.call_args
|
||||
assert kwargs["body"] == {"knn.algo_param.ef_search": 512}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_no_index_custom_mapping(self, mocked_document_store):
|
||||
mocked_document_store.client.indices.exists.return_value = False
|
||||
mocked_document_store.custom_mapping = {"mappings": {"properties": {"a_number": {"type": "integer"}}}}
|
||||
@ -475,6 +495,7 @@ class TestOpenSearchDocumentStore:
|
||||
_, kwargs = mocked_document_store.client.indices.create.call_args
|
||||
assert kwargs["body"] == {"mappings": {"properties": {"a_number": {"type": "integer"}}}}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_no_index_no_mapping(self, mocked_document_store):
|
||||
mocked_document_store.client.indices.exists.return_value = False
|
||||
mocked_document_store._create_document_index(self.index_name)
|
||||
@ -502,6 +523,7 @@ class TestOpenSearchDocumentStore:
|
||||
"settings": {"analysis": {"analyzer": {"default": {"type": "standard"}}}, "index": {"knn": True}},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_no_index_no_mapping_with_synonyms(self, mocked_document_store):
|
||||
mocked_document_store.client.indices.exists.return_value = False
|
||||
mocked_document_store.search_fields = ["occupation"]
|
||||
@ -542,6 +564,7 @@ class TestOpenSearchDocumentStore:
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_no_index_no_mapping_with_embedding_field(self, mocked_document_store):
|
||||
mocked_document_store.client.indices.exists.return_value = False
|
||||
mocked_document_store.embedding_field = "vec"
|
||||
@ -575,6 +598,7 @@ class TestOpenSearchDocumentStore:
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_document_index_client_failure(self, mocked_document_store):
|
||||
mocked_document_store.client.indices.exists.return_value = False
|
||||
mocked_document_store.client.indices.create.side_effect = RequestError
|
||||
@ -582,6 +606,7 @@ class TestOpenSearchDocumentStore:
|
||||
with pytest.raises(RequestError):
|
||||
mocked_document_store._create_document_index(self.index_name)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__get_embedding_field_mapping_flat(self, mocked_document_store):
|
||||
mocked_document_store.index_type = "flat"
|
||||
|
||||
@ -596,6 +621,7 @@ class TestOpenSearchDocumentStore:
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__get_embedding_field_mapping_hnsw(self, mocked_document_store):
|
||||
mocked_document_store.index_type = "hnsw"
|
||||
|
||||
@ -610,6 +636,7 @@ class TestOpenSearchDocumentStore:
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__get_embedding_field_mapping_wrong(self, mocked_document_store, caplog):
|
||||
mocked_document_store.index_type = "foo"
|
||||
|
||||
@ -623,12 +650,14 @@ class TestOpenSearchDocumentStore:
|
||||
"method": {"space_type": "innerproduct", "name": "hnsw", "engine": "nmslib"},
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_label_index_already_exists(self, mocked_document_store):
|
||||
mocked_document_store.client.indices.exists.return_value = True
|
||||
|
||||
mocked_document_store._create_label_index("foo")
|
||||
mocked_document_store.client.indices.create.assert_not_called()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__create_label_index_client_error(self, mocked_document_store):
|
||||
mocked_document_store.client.indices.exists.return_value = False
|
||||
mocked_document_store.client.indices.create.side_effect = RequestError
|
||||
@ -636,6 +665,7 @@ class TestOpenSearchDocumentStore:
|
||||
with pytest.raises(RequestError):
|
||||
mocked_document_store._create_label_index("foo")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__get_vector_similarity_query_support_true(self, mocked_document_store):
|
||||
mocked_document_store.embedding_field = "FooField"
|
||||
mocked_document_store.embeddings_field_supports_similarity = True
|
||||
@ -644,6 +674,7 @@ class TestOpenSearchDocumentStore:
|
||||
"bool": {"must": [{"knn": {"FooField": {"vector": self.query_emb.tolist(), "k": 3}}}]}
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__get_vector_similarity_query_support_false(self, mocked_document_store):
|
||||
mocked_document_store.embedding_field = "FooField"
|
||||
mocked_document_store.embeddings_field_supports_similarity = False
|
||||
@ -664,15 +695,18 @@ class TestOpenSearchDocumentStore:
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__get_raw_similarity_score_dot(self, mocked_document_store):
|
||||
mocked_document_store.similarity = "dot_product"
|
||||
assert mocked_document_store._get_raw_similarity_score(2) == 1
|
||||
assert mocked_document_store._get_raw_similarity_score(-2) == 1.5
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__get_raw_similarity_score_l2(self, mocked_document_store):
|
||||
mocked_document_store.similarity = "l2"
|
||||
assert mocked_document_store._get_raw_similarity_score(1) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test__get_raw_similarity_score_cosine(self, mocked_document_store):
|
||||
mocked_document_store.similarity = "cosine"
|
||||
mocked_document_store.embeddings_field_supports_similarity = True
|
||||
@ -680,12 +714,14 @@ class TestOpenSearchDocumentStore:
|
||||
mocked_document_store.embeddings_field_supports_similarity = False
|
||||
assert mocked_document_store._get_raw_similarity_score(1) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_clone_embedding_field_duplicate_mapping(self, mocked_document_store, index):
|
||||
mocked_document_store.client.indices.get.return_value = {self.index_name: index}
|
||||
mocked_document_store.index = self.index_name
|
||||
with pytest.raises(Exception, match="age already exists with mapping"):
|
||||
mocked_document_store.clone_embedding_field("age", "cosine")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_clone_embedding_field_update_mapping(self, mocked_document_store, index, monkeypatch):
|
||||
mocked_document_store.client.indices.get.return_value = {self.index_name: index}
|
||||
mocked_document_store.index = self.index_name
|
||||
@ -709,6 +745,7 @@ class TestOpenSearchDocumentStore:
|
||||
|
||||
|
||||
class TestOpenDistroElasticsearchDocumentStore:
|
||||
@pytest.mark.unit
|
||||
def test_deprecation_notice(self, monkeypatch, caplog):
|
||||
klass = OpenDistroElasticsearchDocumentStore
|
||||
monkeypatch.setattr(klass, "_init_client", MagicMock())
|
||||
|
@ -1,7 +1,14 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
try:
|
||||
import soundfile as sf
|
||||
|
||||
soundfile_not_found = False
|
||||
except:
|
||||
soundfile_not_found = True
|
||||
|
||||
from haystack.schema import Span, Answer, SpeechAnswer, Document, SpeechDocument
|
||||
from haystack.nodes.audio import AnswerToSpeech, DocumentToSpeech
|
||||
@ -10,111 +17,108 @@ from haystack.nodes.audio._text_to_speech import TextToSpeech
|
||||
from ..conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
def test_text_to_speech_audio_data():
|
||||
text2speech = TextToSpeech(
|
||||
model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
expected_audio_data, _ = sf.read(SAMPLES_PATH / "audio" / "answer.wav")
|
||||
audio_data = text2speech.text_to_audio_data(text="answer")
|
||||
@pytest.mark.skipif(soundfile_not_found, reason="soundfile not found")
|
||||
class TestTextToSpeech:
|
||||
def test_text_to_speech_audio_data(self):
|
||||
text2speech = TextToSpeech(
|
||||
model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
expected_audio_data, _ = sf.read(SAMPLES_PATH / "audio" / "answer.wav")
|
||||
audio_data = text2speech.text_to_audio_data(text="answer")
|
||||
|
||||
assert np.allclose(expected_audio_data, audio_data, atol=0.001)
|
||||
assert np.allclose(expected_audio_data, audio_data, atol=0.001)
|
||||
|
||||
def test_text_to_speech_audio_file(self, tmp_path):
|
||||
text2speech = TextToSpeech(
|
||||
model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
expected_audio_data, _ = sf.read(SAMPLES_PATH / "audio" / "answer.wav")
|
||||
audio_file = text2speech.text_to_audio_file(text="answer", generated_audio_dir=tmp_path / "test_audio")
|
||||
assert os.path.exists(audio_file)
|
||||
assert np.allclose(expected_audio_data, sf.read(audio_file)[0], atol=0.001)
|
||||
|
||||
def test_text_to_speech_audio_file(tmp_path):
|
||||
text2speech = TextToSpeech(
|
||||
model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
expected_audio_data, _ = sf.read(SAMPLES_PATH / "audio" / "answer.wav")
|
||||
audio_file = text2speech.text_to_audio_file(text="answer", generated_audio_dir=tmp_path / "test_audio")
|
||||
assert os.path.exists(audio_file)
|
||||
assert np.allclose(expected_audio_data, sf.read(audio_file)[0], atol=0.001)
|
||||
def test_text_to_speech_compress_audio(self, tmp_path):
|
||||
text2speech = TextToSpeech(
|
||||
model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
expected_audio_file = SAMPLES_PATH / "audio" / "answer.wav"
|
||||
audio_file = text2speech.text_to_audio_file(
|
||||
text="answer", generated_audio_dir=tmp_path / "test_audio", audio_format="mp3"
|
||||
)
|
||||
assert os.path.exists(audio_file)
|
||||
assert audio_file.suffix == ".mp3"
|
||||
# FIXME find a way to make sure the compressed audio is similar enough to the wav version.
|
||||
# At a manual inspection, the code seems to be working well.
|
||||
|
||||
def test_text_to_speech_naming_function(self, tmp_path):
|
||||
text2speech = TextToSpeech(
|
||||
model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
expected_audio_file = SAMPLES_PATH / "audio" / "answer.wav"
|
||||
audio_file = text2speech.text_to_audio_file(
|
||||
text="answer", generated_audio_dir=tmp_path / "test_audio", audio_naming_function=lambda text: text
|
||||
)
|
||||
assert os.path.exists(audio_file)
|
||||
assert audio_file.name == expected_audio_file.name
|
||||
assert np.allclose(sf.read(expected_audio_file)[0], sf.read(audio_file)[0], atol=0.001)
|
||||
|
||||
def test_text_to_speech_compress_audio(tmp_path):
|
||||
text2speech = TextToSpeech(
|
||||
model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
expected_audio_file = SAMPLES_PATH / "audio" / "answer.wav"
|
||||
audio_file = text2speech.text_to_audio_file(
|
||||
text="answer", generated_audio_dir=tmp_path / "test_audio", audio_format="mp3"
|
||||
)
|
||||
assert os.path.exists(audio_file)
|
||||
assert audio_file.suffix == ".mp3"
|
||||
# FIXME find a way to make sure the compressed audio is similar enough to the wav version.
|
||||
# At a manual inspection, the code seems to be working well.
|
||||
def test_answer_to_speech(self, tmp_path):
|
||||
text_answer = Answer(
|
||||
answer="answer",
|
||||
type="extractive",
|
||||
context="the context for this answer is here",
|
||||
offsets_in_document=[Span(31, 37)],
|
||||
offsets_in_context=[Span(21, 27)],
|
||||
meta={"some_meta": "some_value"},
|
||||
)
|
||||
expected_audio_answer = SAMPLES_PATH / "audio" / "answer.wav"
|
||||
expected_audio_context = SAMPLES_PATH / "audio" / "the context for this answer is here.wav"
|
||||
|
||||
answer2speech = AnswerToSpeech(
|
||||
generated_audio_dir=tmp_path / "test_audio",
|
||||
audio_params={"audio_naming_function": lambda text: text},
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
results, _ = answer2speech.run(answers=[text_answer])
|
||||
|
||||
def test_text_to_speech_naming_function(tmp_path):
|
||||
text2speech = TextToSpeech(
|
||||
model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
expected_audio_file = SAMPLES_PATH / "audio" / "answer.wav"
|
||||
audio_file = text2speech.text_to_audio_file(
|
||||
text="answer", generated_audio_dir=tmp_path / "test_audio", audio_naming_function=lambda text: text
|
||||
)
|
||||
assert os.path.exists(audio_file)
|
||||
assert audio_file.name == expected_audio_file.name
|
||||
assert np.allclose(sf.read(expected_audio_file)[0], sf.read(audio_file)[0], atol=0.001)
|
||||
audio_answer: SpeechAnswer = results["answers"][0]
|
||||
assert isinstance(audio_answer, SpeechAnswer)
|
||||
assert audio_answer.type == "generative"
|
||||
assert audio_answer.answer_audio.name == expected_audio_answer.name
|
||||
assert audio_answer.context_audio.name == expected_audio_context.name
|
||||
assert audio_answer.answer == "answer"
|
||||
assert audio_answer.context == "the context for this answer is here"
|
||||
assert audio_answer.offsets_in_document == [Span(31, 37)]
|
||||
assert audio_answer.offsets_in_context == [Span(21, 27)]
|
||||
assert audio_answer.meta["some_meta"] == "some_value"
|
||||
assert audio_answer.meta["audio_format"] == "wav"
|
||||
|
||||
assert np.allclose(sf.read(audio_answer.answer_audio)[0], sf.read(expected_audio_answer)[0], atol=0.001)
|
||||
assert np.allclose(sf.read(audio_answer.context_audio)[0], sf.read(expected_audio_context)[0], atol=0.001)
|
||||
|
||||
def test_answer_to_speech(tmp_path):
|
||||
text_answer = Answer(
|
||||
answer="answer",
|
||||
type="extractive",
|
||||
context="the context for this answer is here",
|
||||
offsets_in_document=[Span(31, 37)],
|
||||
offsets_in_context=[Span(21, 27)],
|
||||
meta={"some_meta": "some_value"},
|
||||
)
|
||||
expected_audio_answer = SAMPLES_PATH / "audio" / "answer.wav"
|
||||
expected_audio_context = SAMPLES_PATH / "audio" / "the context for this answer is here.wav"
|
||||
def test_document_to_speech(self, tmp_path):
|
||||
text_doc = Document(
|
||||
content="this is the content of the document", content_type="text", meta={"name": "test_document.txt"}
|
||||
)
|
||||
expected_audio_content = SAMPLES_PATH / "audio" / "this is the content of the document.wav"
|
||||
|
||||
answer2speech = AnswerToSpeech(
|
||||
generated_audio_dir=tmp_path / "test_audio",
|
||||
audio_params={"audio_naming_function": lambda text: text},
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
results, _ = answer2speech.run(answers=[text_answer])
|
||||
doc2speech = DocumentToSpeech(
|
||||
generated_audio_dir=tmp_path / "test_audio",
|
||||
audio_params={"audio_naming_function": lambda text: text},
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
results, _ = doc2speech.run(documents=[text_doc])
|
||||
|
||||
audio_answer: SpeechAnswer = results["answers"][0]
|
||||
assert isinstance(audio_answer, SpeechAnswer)
|
||||
assert audio_answer.type == "generative"
|
||||
assert audio_answer.answer_audio.name == expected_audio_answer.name
|
||||
assert audio_answer.context_audio.name == expected_audio_context.name
|
||||
assert audio_answer.answer == "answer"
|
||||
assert audio_answer.context == "the context for this answer is here"
|
||||
assert audio_answer.offsets_in_document == [Span(31, 37)]
|
||||
assert audio_answer.offsets_in_context == [Span(21, 27)]
|
||||
assert audio_answer.meta["some_meta"] == "some_value"
|
||||
assert audio_answer.meta["audio_format"] == "wav"
|
||||
audio_doc: SpeechDocument = results["documents"][0]
|
||||
assert isinstance(audio_doc, SpeechDocument)
|
||||
assert audio_doc.content_type == "audio"
|
||||
assert audio_doc.content_audio.name == expected_audio_content.name
|
||||
assert audio_doc.content == "this is the content of the document"
|
||||
assert audio_doc.meta["name"] == "test_document.txt"
|
||||
assert audio_doc.meta["audio_format"] == "wav"
|
||||
|
||||
assert np.allclose(sf.read(audio_answer.answer_audio)[0], sf.read(expected_audio_answer)[0], atol=0.001)
|
||||
assert np.allclose(sf.read(audio_answer.context_audio)[0], sf.read(expected_audio_context)[0], atol=0.001)
|
||||
|
||||
|
||||
def test_document_to_speech(tmp_path):
|
||||
text_doc = Document(
|
||||
content="this is the content of the document", content_type="text", meta={"name": "test_document.txt"}
|
||||
)
|
||||
expected_audio_content = SAMPLES_PATH / "audio" / "this is the content of the document.wav"
|
||||
|
||||
doc2speech = DocumentToSpeech(
|
||||
generated_audio_dir=tmp_path / "test_audio",
|
||||
audio_params={"audio_naming_function": lambda text: text},
|
||||
transformers_params={"seed": 777, "always_fix_seed": True},
|
||||
)
|
||||
results, _ = doc2speech.run(documents=[text_doc])
|
||||
|
||||
audio_doc: SpeechDocument = results["documents"][0]
|
||||
assert isinstance(audio_doc, SpeechDocument)
|
||||
assert audio_doc.content_type == "audio"
|
||||
assert audio_doc.content_audio.name == expected_audio_content.name
|
||||
assert audio_doc.content == "this is the content of the document"
|
||||
assert audio_doc.meta["name"] == "test_document.txt"
|
||||
assert audio_doc.meta["audio_format"] == "wav"
|
||||
|
||||
assert np.allclose(sf.read(audio_doc.content_audio)[0], sf.read(expected_audio_content)[0], atol=0.001)
|
||||
assert np.allclose(sf.read(audio_doc.content_audio)[0], sf.read(expected_audio_content)[0], atol=0.001)
|
||||
|
Loading…
x
Reference in New Issue
Block a user