diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 956209af6..67a4866d9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -75,6 +75,27 @@ jobs: run: | pylint -ry -j 0 haystack/ rest_api/ ui/ + unit-tests: + name: Unit / ${{ matrix.os }} + needs: + - mypy + - pylint + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + + - name: Setup Python + uses: ./.github/actions/python_cache/ + + - name: Install Haystack + run: pip install .[all] + + - name: Run + run: pytest -m "unit" test/ unit-tests-linux: needs: @@ -88,7 +109,6 @@ jobs: - "pipelines" - "modeling" - "others" - - "document_stores/test_opensearch.py" runs-on: ubuntu-latest timeout-minutes: 30 diff --git a/pyproject.toml b/pyproject.toml index 53ceb5122..9bebb99ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ minversion = "6.0" addopts = "--strict-markers" markers = [ "integration: integration tests", + "unit: unit tests", "generator: generator tests", "summarizer: summarizer tests", diff --git a/test/document_stores/test_opensearch.py b/test/document_stores/test_opensearch.py index 0ad8ca323..899ceb42c 100644 --- a/test/document_stores/test_opensearch.py +++ b/test/document_stores/test_opensearch.py @@ -18,14 +18,9 @@ from haystack.document_stores.opensearch import ( from haystack.schema import Document, Label, Answer from haystack.errors import DocumentStoreError - -# Skip OpenSearchDocumentStore tests on Windows -pytestmark = pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Opensearch not running on Windows CI") - # Being all the tests in this module, ideally we wouldn't need a marker here, # but this is to allow this test suite to be skipped when running (e.g.) 
# `pytest test/document_stores --document-store-type=faiss` -@pytest.mark.opensearch class TestOpenSearchDocumentStore: # Constants @@ -210,6 +205,7 @@ class TestOpenSearchDocumentStore: # Unit tests + @pytest.mark.unit def test___init___api_key_raises_warning(self, mocked_document_store, caplog): with caplog.at_level(logging.WARN, logger="haystack.document_stores.opensearch"): mocked_document_store.__init__(api_key="foo") @@ -220,6 +216,7 @@ class TestOpenSearchDocumentStore: for r in caplog.records: assert r.levelname == "WARNING" + @pytest.mark.unit def test___init___connection_test_fails(self, mocked_document_store): failing_client = MagicMock() failing_client.indices.get.side_effect = Exception("The client failed!") @@ -227,6 +224,7 @@ class TestOpenSearchDocumentStore: with pytest.raises(ConnectionError): mocked_document_store.__init__() + @pytest.mark.unit def test___init___client_params(self, mocked_open_search_init, _init_client_params): """ Ensure the Opensearch-py client was initialized with the right params @@ -244,18 +242,21 @@ class TestOpenSearchDocumentStore: "connection_class": RequestsHttpConnection, } + @pytest.mark.unit def test__init_client_use_system_proxy_use_sys_proxy(self, mocked_open_search_init, _init_client_params): _init_client_params["use_system_proxy"] = False OpenSearchDocumentStore._init_client(**_init_client_params) _, kwargs = mocked_open_search_init.call_args assert kwargs["connection_class"] == Urllib3HttpConnection + @pytest.mark.unit def test__init_client_use_system_proxy_dont_use_sys_proxy(self, mocked_open_search_init, _init_client_params): _init_client_params["use_system_proxy"] = True OpenSearchDocumentStore._init_client(**_init_client_params) _, kwargs = mocked_open_search_init.call_args assert kwargs["connection_class"] == RequestsHttpConnection + @pytest.mark.unit def test__init_client_auth_methods_username_password(self, mocked_open_search_init, _init_client_params): _init_client_params["username"] = "user" 
_init_client_params["aws4auth"] = None @@ -263,6 +264,7 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_open_search_init.call_args assert kwargs["http_auth"] == ("user", "pass") + @pytest.mark.unit def test__init_client_auth_methods_aws_iam(self, mocked_open_search_init, _init_client_params): _init_client_params["username"] = "" _init_client_params["aws4auth"] = "foo" @@ -270,6 +272,7 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_open_search_init.call_args assert kwargs["http_auth"] == "foo" + @pytest.mark.unit def test__init_client_auth_methods_no_auth(self, mocked_open_search_init, _init_client_params): _init_client_params["username"] = "" _init_client_params["aws4auth"] = None @@ -277,11 +280,13 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_open_search_init.call_args assert "http_auth" not in kwargs + @pytest.mark.unit def test_query_by_embedding_raises_if_missing_field(self, mocked_document_store): mocked_document_store.embedding_field = "" with pytest.raises(DocumentStoreError): mocked_document_store.query_by_embedding(self.query_emb) + @pytest.mark.unit def test_query_by_embedding_filters(self, mocked_document_store): expected_filters = {"type": "article", "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}} mocked_document_store.query_by_embedding(self.query_emb, filters=expected_filters) @@ -293,6 +298,7 @@ class TestOpenSearchDocumentStore: {"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}}, ] + @pytest.mark.unit def test_query_by_embedding_return_embedding_false(self, mocked_document_store): mocked_document_store.return_embedding = False mocked_document_store.query_by_embedding(self.query_emb) @@ -300,6 +306,7 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_document_store.client.search.call_args assert kwargs["body"]["_source"] == {"excludes": ["embedding"]} + @pytest.mark.unit def test_query_by_embedding_excluded_meta_data_return_embedding_true(self, mocked_document_store): """ Test that when 
`return_embedding==True` the field should NOT be excluded even if it @@ -312,6 +319,7 @@ class TestOpenSearchDocumentStore: # we expect "embedding" was removed from the final query assert kwargs["body"]["_source"] == {"excludes": ["foo"]} + @pytest.mark.unit def test_query_by_embedding_excluded_meta_data_return_embedding_false(self, mocked_document_store): """ Test that when `return_embedding==False`, the final query excludes the `embedding` field @@ -324,6 +332,7 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_document_store.client.search.call_args assert kwargs["body"]["_source"] == {"excludes": ["foo", "embedding"]} + @pytest.mark.unit def test__create_document_index_with_alias(self, mocked_document_store, caplog): mocked_document_store.client.indices.exists_alias.return_value = True @@ -332,6 +341,7 @@ class TestOpenSearchDocumentStore: assert f"Index name {self.index_name} is an alias." in caplog.text + @pytest.mark.unit def test__create_document_index_wrong_mapping_raises(self, mocked_document_store, index): """ Ensure the method raises if we specify a field in `search_fields` that's not text @@ -342,6 +352,7 @@ class TestOpenSearchDocumentStore: with pytest.raises(Exception, match=f"The search_field 'age' of index '{self.index_name}' with type 'integer'"): mocked_document_store._create_document_index(self.index_name) + @pytest.mark.unit def test__create_document_index_create_mapping_if_missing(self, mocked_document_store, index): mocked_document_store.client.indices.exists.return_value = True mocked_document_store.client.indices.get.return_value = {self.index_name: index} @@ -354,6 +365,7 @@ class TestOpenSearchDocumentStore: assert kwargs["index"] == self.index_name assert "doesnt_have_a_mapping" in kwargs["body"]["properties"] + @pytest.mark.unit def test__create_document_index_with_bad_field_raises(self, mocked_document_store, index): mocked_document_store.client.indices.exists.return_value = True 
mocked_document_store.client.indices.get.return_value = {self.index_name: index} @@ -364,6 +376,7 @@ class TestOpenSearchDocumentStore: ): mocked_document_store._create_document_index(self.index_name) + @pytest.mark.unit def test__create_document_index_with_existing_mapping_but_no_method(self, mocked_document_store, index): """ We call the method passing a properly mapped field but without the `method` specified in the mapping @@ -381,6 +394,7 @@ class TestOpenSearchDocumentStore: # False but I'm not sure this is by design assert mocked_document_store.embeddings_field_supports_similarity is False + @pytest.mark.unit def test__create_document_index_with_existing_mapping_similarity(self, mocked_document_store, index): mocked_document_store.client.indices.exists.return_value = True mocked_document_store.client.indices.get.return_value = {self.index_name: index} @@ -390,6 +404,7 @@ class TestOpenSearchDocumentStore: mocked_document_store._create_document_index(self.index_name) assert mocked_document_store.embeddings_field_supports_similarity is True + @pytest.mark.unit def test__create_document_index_with_existing_mapping_similarity_mismatch( self, mocked_document_store, index, caplog ): @@ -403,6 +418,7 @@ class TestOpenSearchDocumentStore: assert "Embedding field 'vec' is optimized for similarity 'dot_product'." 
in caplog.text assert mocked_document_store.embeddings_field_supports_similarity is False + @pytest.mark.unit def test__create_document_index_with_existing_mapping_adjust_params_hnsw_default( self, mocked_document_store, index ): @@ -420,6 +436,7 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_document_store.client.indices.put_settings.call_args assert kwargs["body"] == {"knn.algo_param.ef_search": 20} + @pytest.mark.unit def test__create_document_index_with_existing_mapping_adjust_params_hnsw(self, mocked_document_store, index): """ Test a value of `knn.algo_param` that needs to be adjusted @@ -436,6 +453,7 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_document_store.client.indices.put_settings.call_args assert kwargs["body"] == {"knn.algo_param.ef_search": 20} + @pytest.mark.unit def test__create_document_index_with_existing_mapping_adjust_params_flat_default( self, mocked_document_store, index ): @@ -451,6 +469,7 @@ class TestOpenSearchDocumentStore: mocked_document_store.client.indices.put_settings.assert_not_called + @pytest.mark.unit def test__create_document_index_with_existing_mapping_adjust_params_hnsw(self, mocked_document_store, index): """ Test a value of `knn.algo_param` that needs to be adjusted @@ -467,6 +486,7 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_document_store.client.indices.put_settings.call_args assert kwargs["body"] == {"knn.algo_param.ef_search": 512} + @pytest.mark.unit def test__create_document_index_no_index_custom_mapping(self, mocked_document_store): mocked_document_store.client.indices.exists.return_value = False mocked_document_store.custom_mapping = {"mappings": {"properties": {"a_number": {"type": "integer"}}}} @@ -475,6 +495,7 @@ class TestOpenSearchDocumentStore: _, kwargs = mocked_document_store.client.indices.create.call_args assert kwargs["body"] == {"mappings": {"properties": {"a_number": {"type": "integer"}}}} + @pytest.mark.unit def test__create_document_index_no_index_no_mapping(self, 
mocked_document_store): mocked_document_store.client.indices.exists.return_value = False mocked_document_store._create_document_index(self.index_name) @@ -502,6 +523,7 @@ class TestOpenSearchDocumentStore: "settings": {"analysis": {"analyzer": {"default": {"type": "standard"}}}, "index": {"knn": True}}, } + @pytest.mark.unit def test__create_document_index_no_index_no_mapping_with_synonyms(self, mocked_document_store): mocked_document_store.client.indices.exists.return_value = False mocked_document_store.search_fields = ["occupation"] @@ -542,6 +564,7 @@ class TestOpenSearchDocumentStore: }, } + @pytest.mark.unit def test__create_document_index_no_index_no_mapping_with_embedding_field(self, mocked_document_store): mocked_document_store.client.indices.exists.return_value = False mocked_document_store.embedding_field = "vec" @@ -575,6 +598,7 @@ class TestOpenSearchDocumentStore: }, } + @pytest.mark.unit def test__create_document_index_client_failure(self, mocked_document_store): mocked_document_store.client.indices.exists.return_value = False mocked_document_store.client.indices.create.side_effect = RequestError @@ -582,6 +606,7 @@ class TestOpenSearchDocumentStore: with pytest.raises(RequestError): mocked_document_store._create_document_index(self.index_name) + @pytest.mark.unit def test__get_embedding_field_mapping_flat(self, mocked_document_store): mocked_document_store.index_type = "flat" @@ -596,6 +621,7 @@ class TestOpenSearchDocumentStore: }, } + @pytest.mark.unit def test__get_embedding_field_mapping_hnsw(self, mocked_document_store): mocked_document_store.index_type = "hnsw" @@ -610,6 +636,7 @@ class TestOpenSearchDocumentStore: }, } + @pytest.mark.unit def test__get_embedding_field_mapping_wrong(self, mocked_document_store, caplog): mocked_document_store.index_type = "foo" @@ -623,12 +650,14 @@ class TestOpenSearchDocumentStore: "method": {"space_type": "innerproduct", "name": "hnsw", "engine": "nmslib"}, } + @pytest.mark.unit def 
test__create_label_index_already_exists(self, mocked_document_store): mocked_document_store.client.indices.exists.return_value = True mocked_document_store._create_label_index("foo") mocked_document_store.client.indices.create.assert_not_called() + @pytest.mark.unit def test__create_label_index_client_error(self, mocked_document_store): mocked_document_store.client.indices.exists.return_value = False mocked_document_store.client.indices.create.side_effect = RequestError @@ -636,6 +665,7 @@ class TestOpenSearchDocumentStore: with pytest.raises(RequestError): mocked_document_store._create_label_index("foo") + @pytest.mark.unit def test__get_vector_similarity_query_support_true(self, mocked_document_store): mocked_document_store.embedding_field = "FooField" mocked_document_store.embeddings_field_supports_similarity = True @@ -644,6 +674,7 @@ class TestOpenSearchDocumentStore: "bool": {"must": [{"knn": {"FooField": {"vector": self.query_emb.tolist(), "k": 3}}}]} } + @pytest.mark.unit def test__get_vector_similarity_query_support_false(self, mocked_document_store): mocked_document_store.embedding_field = "FooField" mocked_document_store.embeddings_field_supports_similarity = False @@ -664,15 +695,18 @@ class TestOpenSearchDocumentStore: } } + @pytest.mark.unit def test__get_raw_similarity_score_dot(self, mocked_document_store): mocked_document_store.similarity = "dot_product" assert mocked_document_store._get_raw_similarity_score(2) == 1 assert mocked_document_store._get_raw_similarity_score(-2) == 1.5 + @pytest.mark.unit def test__get_raw_similarity_score_l2(self, mocked_document_store): mocked_document_store.similarity = "l2" assert mocked_document_store._get_raw_similarity_score(1) == 0 + @pytest.mark.unit def test__get_raw_similarity_score_cosine(self, mocked_document_store): mocked_document_store.similarity = "cosine" mocked_document_store.embeddings_field_supports_similarity = True @@ -680,12 +714,14 @@ class TestOpenSearchDocumentStore: 
mocked_document_store.embeddings_field_supports_similarity = False assert mocked_document_store._get_raw_similarity_score(1) == 0 + @pytest.mark.unit def test_clone_embedding_field_duplicate_mapping(self, mocked_document_store, index): mocked_document_store.client.indices.get.return_value = {self.index_name: index} mocked_document_store.index = self.index_name with pytest.raises(Exception, match="age already exists with mapping"): mocked_document_store.clone_embedding_field("age", "cosine") + @pytest.mark.unit def test_clone_embedding_field_update_mapping(self, mocked_document_store, index, monkeypatch): mocked_document_store.client.indices.get.return_value = {self.index_name: index} mocked_document_store.index = self.index_name @@ -709,6 +745,7 @@ class TestOpenSearchDocumentStore: class TestOpenDistroElasticsearchDocumentStore: + @pytest.mark.unit def test_deprecation_notice(self, monkeypatch, caplog): klass = OpenDistroElasticsearchDocumentStore monkeypatch.setattr(klass, "_init_client", MagicMock()) diff --git a/test/nodes/test_audio.py b/test/nodes/test_audio.py index 774a61409..1fdaf31b7 100644 --- a/test/nodes/test_audio.py +++ b/test/nodes/test_audio.py @@ -1,7 +1,14 @@ import os +import pytest import numpy as np -import soundfile as sf + +try: + import soundfile as sf + + soundfile_not_found = False +except Exception: + soundfile_not_found = True from haystack.schema import Span, Answer, SpeechAnswer, Document, SpeechDocument from haystack.nodes.audio import AnswerToSpeech, DocumentToSpeech @@ -10,111 +17,108 @@ from haystack.nodes.audio._text_to_speech import TextToSpeech from ..conftest import SAMPLES_PATH -def test_text_to_speech_audio_data(): - text2speech = TextToSpeech( - model_name_or_path="espnet/kan-bayashi_ljspeech_vits", - transformers_params={"seed": 777, "always_fix_seed": True}, - ) - expected_audio_data, _ = sf.read(SAMPLES_PATH / "audio" / "answer.wav") - audio_data = text2speech.text_to_audio_data(text="answer") 
+@pytest.mark.skipif(soundfile_not_found, reason="soundfile not found") +class TestTextToSpeech: + def test_text_to_speech_audio_data(self): + text2speech = TextToSpeech( + model_name_or_path="espnet/kan-bayashi_ljspeech_vits", + transformers_params={"seed": 777, "always_fix_seed": True}, + ) + expected_audio_data, _ = sf.read(SAMPLES_PATH / "audio" / "answer.wav") + audio_data = text2speech.text_to_audio_data(text="answer") - assert np.allclose(expected_audio_data, audio_data, atol=0.001) + assert np.allclose(expected_audio_data, audio_data, atol=0.001) + def test_text_to_speech_audio_file(self, tmp_path): + text2speech = TextToSpeech( + model_name_or_path="espnet/kan-bayashi_ljspeech_vits", + transformers_params={"seed": 777, "always_fix_seed": True}, + ) + expected_audio_data, _ = sf.read(SAMPLES_PATH / "audio" / "answer.wav") + audio_file = text2speech.text_to_audio_file(text="answer", generated_audio_dir=tmp_path / "test_audio") + assert os.path.exists(audio_file) + assert np.allclose(expected_audio_data, sf.read(audio_file)[0], atol=0.001) -def test_text_to_speech_audio_file(tmp_path): - text2speech = TextToSpeech( - model_name_or_path="espnet/kan-bayashi_ljspeech_vits", - transformers_params={"seed": 777, "always_fix_seed": True}, - ) - expected_audio_data, _ = sf.read(SAMPLES_PATH / "audio" / "answer.wav") - audio_file = text2speech.text_to_audio_file(text="answer", generated_audio_dir=tmp_path / "test_audio") - assert os.path.exists(audio_file) - assert np.allclose(expected_audio_data, sf.read(audio_file)[0], atol=0.001) + def test_text_to_speech_compress_audio(self, tmp_path): + text2speech = TextToSpeech( + model_name_or_path="espnet/kan-bayashi_ljspeech_vits", + transformers_params={"seed": 777, "always_fix_seed": True}, + ) + expected_audio_file = SAMPLES_PATH / "audio" / "answer.wav" + audio_file = text2speech.text_to_audio_file( + text="answer", generated_audio_dir=tmp_path / "test_audio", audio_format="mp3" + ) + assert os.path.exists(audio_file) + 
assert audio_file.suffix == ".mp3" + # FIXME find a way to make sure the compressed audio is similar enough to the wav version. + # At a manual inspection, the code seems to be working well. + def test_text_to_speech_naming_function(self, tmp_path): + text2speech = TextToSpeech( + model_name_or_path="espnet/kan-bayashi_ljspeech_vits", + transformers_params={"seed": 777, "always_fix_seed": True}, + ) + expected_audio_file = SAMPLES_PATH / "audio" / "answer.wav" + audio_file = text2speech.text_to_audio_file( + text="answer", generated_audio_dir=tmp_path / "test_audio", audio_naming_function=lambda text: text + ) + assert os.path.exists(audio_file) + assert audio_file.name == expected_audio_file.name + assert np.allclose(sf.read(expected_audio_file)[0], sf.read(audio_file)[0], atol=0.001) -def test_text_to_speech_compress_audio(tmp_path): - text2speech = TextToSpeech( - model_name_or_path="espnet/kan-bayashi_ljspeech_vits", - transformers_params={"seed": 777, "always_fix_seed": True}, - ) - expected_audio_file = SAMPLES_PATH / "audio" / "answer.wav" - audio_file = text2speech.text_to_audio_file( - text="answer", generated_audio_dir=tmp_path / "test_audio", audio_format="mp3" - ) - assert os.path.exists(audio_file) - assert audio_file.suffix == ".mp3" - # FIXME find a way to make sure the compressed audio is similar enough to the wav version. - # At a manual inspection, the code seems to be working well. 
+ def test_answer_to_speech(self, tmp_path): + text_answer = Answer( + answer="answer", + type="extractive", + context="the context for this answer is here", + offsets_in_document=[Span(31, 37)], + offsets_in_context=[Span(21, 27)], + meta={"some_meta": "some_value"}, + ) + expected_audio_answer = SAMPLES_PATH / "audio" / "answer.wav" + expected_audio_context = SAMPLES_PATH / "audio" / "the context for this answer is here.wav" + answer2speech = AnswerToSpeech( + generated_audio_dir=tmp_path / "test_audio", + audio_params={"audio_naming_function": lambda text: text}, + transformers_params={"seed": 777, "always_fix_seed": True}, + ) + results, _ = answer2speech.run(answers=[text_answer]) -def test_text_to_speech_naming_function(tmp_path): - text2speech = TextToSpeech( - model_name_or_path="espnet/kan-bayashi_ljspeech_vits", - transformers_params={"seed": 777, "always_fix_seed": True}, - ) - expected_audio_file = SAMPLES_PATH / "audio" / "answer.wav" - audio_file = text2speech.text_to_audio_file( - text="answer", generated_audio_dir=tmp_path / "test_audio", audio_naming_function=lambda text: text - ) - assert os.path.exists(audio_file) - assert audio_file.name == expected_audio_file.name - assert np.allclose(sf.read(expected_audio_file)[0], sf.read(audio_file)[0], atol=0.001) + audio_answer: SpeechAnswer = results["answers"][0] + assert isinstance(audio_answer, SpeechAnswer) + assert audio_answer.type == "generative" + assert audio_answer.answer_audio.name == expected_audio_answer.name + assert audio_answer.context_audio.name == expected_audio_context.name + assert audio_answer.answer == "answer" + assert audio_answer.context == "the context for this answer is here" + assert audio_answer.offsets_in_document == [Span(31, 37)] + assert audio_answer.offsets_in_context == [Span(21, 27)] + assert audio_answer.meta["some_meta"] == "some_value" + assert audio_answer.meta["audio_format"] == "wav" + assert np.allclose(sf.read(audio_answer.answer_audio)[0], 
sf.read(expected_audio_answer)[0], atol=0.001) + assert np.allclose(sf.read(audio_answer.context_audio)[0], sf.read(expected_audio_context)[0], atol=0.001) -def test_answer_to_speech(tmp_path): - text_answer = Answer( - answer="answer", - type="extractive", - context="the context for this answer is here", - offsets_in_document=[Span(31, 37)], - offsets_in_context=[Span(21, 27)], - meta={"some_meta": "some_value"}, - ) - expected_audio_answer = SAMPLES_PATH / "audio" / "answer.wav" - expected_audio_context = SAMPLES_PATH / "audio" / "the context for this answer is here.wav" + def test_document_to_speech(self, tmp_path): + text_doc = Document( + content="this is the content of the document", content_type="text", meta={"name": "test_document.txt"} + ) + expected_audio_content = SAMPLES_PATH / "audio" / "this is the content of the document.wav" - answer2speech = AnswerToSpeech( - generated_audio_dir=tmp_path / "test_audio", - audio_params={"audio_naming_function": lambda text: text}, - transformers_params={"seed": 777, "always_fix_seed": True}, - ) - results, _ = answer2speech.run(answers=[text_answer]) + doc2speech = DocumentToSpeech( + generated_audio_dir=tmp_path / "test_audio", + audio_params={"audio_naming_function": lambda text: text}, + transformers_params={"seed": 777, "always_fix_seed": True}, + ) + results, _ = doc2speech.run(documents=[text_doc]) - audio_answer: SpeechAnswer = results["answers"][0] - assert isinstance(audio_answer, SpeechAnswer) - assert audio_answer.type == "generative" - assert audio_answer.answer_audio.name == expected_audio_answer.name - assert audio_answer.context_audio.name == expected_audio_context.name - assert audio_answer.answer == "answer" - assert audio_answer.context == "the context for this answer is here" - assert audio_answer.offsets_in_document == [Span(31, 37)] - assert audio_answer.offsets_in_context == [Span(21, 27)] - assert audio_answer.meta["some_meta"] == "some_value" - assert audio_answer.meta["audio_format"] == 
"wav" + audio_doc: SpeechDocument = results["documents"][0] + assert isinstance(audio_doc, SpeechDocument) + assert audio_doc.content_type == "audio" + assert audio_doc.content_audio.name == expected_audio_content.name + assert audio_doc.content == "this is the content of the document" + assert audio_doc.meta["name"] == "test_document.txt" + assert audio_doc.meta["audio_format"] == "wav" - assert np.allclose(sf.read(audio_answer.answer_audio)[0], sf.read(expected_audio_answer)[0], atol=0.001) - assert np.allclose(sf.read(audio_answer.context_audio)[0], sf.read(expected_audio_context)[0], atol=0.001) - - -def test_document_to_speech(tmp_path): - text_doc = Document( - content="this is the content of the document", content_type="text", meta={"name": "test_document.txt"} - ) - expected_audio_content = SAMPLES_PATH / "audio" / "this is the content of the document.wav" - - doc2speech = DocumentToSpeech( - generated_audio_dir=tmp_path / "test_audio", - audio_params={"audio_naming_function": lambda text: text}, - transformers_params={"seed": 777, "always_fix_seed": True}, - ) - results, _ = doc2speech.run(documents=[text_doc]) - - audio_doc: SpeechDocument = results["documents"][0] - assert isinstance(audio_doc, SpeechDocument) - assert audio_doc.content_type == "audio" - assert audio_doc.content_audio.name == expected_audio_content.name - assert audio_doc.content == "this is the content of the document" - assert audio_doc.meta["name"] == "test_document.txt" - assert audio_doc.meta["audio_format"] == "wav" - - assert np.allclose(sf.read(audio_doc.content_audio)[0], sf.read(expected_audio_content)[0], atol=0.001) + assert np.allclose(sf.read(audio_doc.content_audio)[0], sf.read(expected_audio_content)[0], atol=0.001)