mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 17:59:27 +00:00 
			
		
		
		
	
		
			
	
	
		
			122 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			122 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import pytest | ||
|  | import logging | ||
|  | 
 | ||
|  | from haystack.schema import Document | ||
|  | from haystack.nodes.doc_language_classifier import ( | ||
|  |     LangdetectDocumentLanguageClassifier, | ||
|  |     TransformersDocumentLanguageClassifier, | ||
|  | ) | ||
|  | 
 | ||
# Languages the classifiers are configured to route on. The expected edges
# below line up with each language's 1-based position in this list.
LANGUAGES_TO_ROUTE = ["en", "es", "it"]

# One row per sample: (text, expected language code, expected output edge).
_SAMPLES = (
    ("My name is Matteo and I live in Rome", "en", "output_1"),
    ("Mi chiamo Matteo e vivo a Roma", "it", "output_3"),
    ("Mi nombre es Matteo y vivo en Roma", "es", "output_2"),
)

# Parallel lists derived from the table above; index i of each list refers to
# the same sample.
DOCUMENTS = [Document(content=text) for text, _language, _edge in _SAMPLES]
EXPECTED_LANGUAGES = [language for _text, language, _edge in _SAMPLES]
EXPECTED_OUTPUT_EDGES = [edge for _text, _language, edge in _SAMPLES]
|  | 
 | ||
|  | 
 | ||
@pytest.fixture(params=["langdetect", "transformers"])
def doclangclassifier(request):
    """Build the document language classifier selected by ``request.param``.

    Both variants are created with ``route_by_language=True`` and
    ``LANGUAGES_TO_ROUTE`` as the languages to route.

    :param request: pytest fixture request; ``request.param`` is either
        ``"langdetect"`` or ``"transformers"``.
    :return: a ``LangdetectDocumentLanguageClassifier`` or a
        ``TransformersDocumentLanguageClassifier`` instance.
    :raises ValueError: if ``request.param`` is not one of the known backends.
    """
    backend = request.param
    if backend == "langdetect":
        return LangdetectDocumentLanguageClassifier(route_by_language=True, languages_to_route=LANGUAGES_TO_ROUTE)
    if backend == "transformers":
        return TransformersDocumentLanguageClassifier(
            route_by_language=True,
            languages_to_route=LANGUAGES_TO_ROUTE,
            model_name_or_path="jb2k/bert-base-multilingual-cased-language-detection",
            # Map the model's raw labels to the language codes used in this module.
            labels_to_languages_mapping={"LABEL_11": "en", "LABEL_22": "it", "LABEL_38": "es"},
        )
    # Fail loudly on a mistyped indirect parameter instead of handing ``None``
    # to the test, which would only surface later as an AttributeError.
    raise ValueError(f"Unknown document language classifier backend: {backend!r}")
|  | 
 | ||
|  | 
 | ||
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_predict(doclangclassifier):
    """``predict`` stores the expected language code in each document's meta."""
    results = doclangclassifier.predict(documents=DOCUMENTS)
    # zip() stops at the shorter input, so without this check an empty or
    # truncated result list would let the test pass without asserting anything.
    assert len(results) == len(EXPECTED_LANGUAGES)
    for doc, expected_language in zip(results, EXPECTED_LANGUAGES):
        assert doc.to_dict()["meta"]["language"] == expected_language
|  | 
 | ||
|  | 
 | ||
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["transformers"], indirect=True)
def test_transformers_doclangclassifier_predict_wo_mapping(doclangclassifier):
    """With an empty label mapping, the raw model labels end up in the meta."""
    doclangclassifier.labels_to_languages_mapping = {}
    expected_labels = ["LABEL_11", "LABEL_22", "LABEL_38"]
    results = doclangclassifier.predict(documents=DOCUMENTS)
    # zip() stops at the shorter input, so make sure every document came back
    # before comparing labels pairwise.
    assert len(results) == len(expected_labels)
    for doc, expected_label in zip(results, expected_labels):
        assert doc.to_dict()["meta"]["language"] == expected_label
|  | 
 | ||
|  | 
 | ||
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_predict_batch(doclangclassifier):
    """``predict_batch`` annotates every document of every input list."""
    results = doclangclassifier.predict_batch(documents=[DOCUMENTS, DOCUMENTS[:2]])
    expected_languages = [EXPECTED_LANGUAGES, EXPECTED_LANGUAGES[:2]]
    # zip() truncates silently at both nesting levels, so check the shape of
    # the output explicitly before comparing element by element.
    assert len(results) == len(expected_languages)
    for lst_docs, lst_expected_languages in zip(results, expected_languages):
        assert len(lst_docs) == len(lst_expected_languages)
        for doc, expected_language in zip(lst_docs, lst_expected_languages):
            assert doc.to_dict()["meta"]["language"] == expected_language
|  | 
 | ||
|  | 
 | ||
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_not_route(doclangclassifier):
    """With routing disabled, ``run`` uses ``output_1`` and still sets the language."""
    doclangclassifier.route_by_language = False
    results, edge = doclangclassifier.run(documents=DOCUMENTS)
    assert edge == "output_1"
    # zip() stops at the shorter input, so guard against missing documents.
    assert len(results["documents"]) == len(EXPECTED_LANGUAGES)
    for doc, expected_language in zip(results["documents"], EXPECTED_LANGUAGES):
        assert doc.to_dict()["meta"]["language"] == expected_language
|  | 
 | ||
|  | 
 | ||
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route(doclangclassifier):
    """Each sample document, run on its own, leaves through its expected edge."""
    cases = zip(DOCUMENTS, EXPECTED_LANGUAGES, EXPECTED_OUTPUT_EDGES)
    for single_doc, language, target_edge in cases:
        output, edge = doclangclassifier.run(documents=[single_doc])
        routed = output["documents"][0]

        assert edge == target_edge
        detected = routed.to_dict()["meta"]["language"]
        assert detected == language
|  | 
 | ||
|  | 
 | ||
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route_fail_on_mixed_languages(doclangclassifier):
    """Routing a single call that contains several languages raises ValueError."""
    expects_mixed_language_error = pytest.raises(ValueError, match="Documents of multiple languages")
    with expects_mixed_language_error:
        doclangclassifier.run(documents=DOCUMENTS)
|  | 
 | ||
|  | 
 | ||
# The transformers backend is not exercised here because current models always
# predict a language.
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect"], indirect=True)
def test_doclangclassifier_run_route_cannot_detect_language(doclangclassifier, caplog):
    """An undetectable document is logged, sent to ``output_1`` and gets language ``None``."""
    undetectable = Document("01234, 56789, ")
    with caplog.at_level(logging.INFO):
        output, edge = doclangclassifier.run(documents=[undetectable])
    # Captured records stay available after the level override is restored.
    captured_log = caplog.text
    assert "The model cannot detect the language of any of the documents." in captured_log
    assert edge == "output_1"
    detected = output["documents"][0].to_dict()["meta"]["language"]
    assert detected is None
|  | 
 | ||
|  | 
 | ||
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_route_fail_on_language_not_in_list(doclangclassifier):
    """Routing a document whose language is not in the routing list raises ValueError.

    The sample sentence is the Portuguese counterpart of the module's other
    samples; Portuguese is not part of ``LANGUAGES_TO_ROUTE``.
    """
    # The ``caplog`` fixture was requested but never read, so it is no longer
    # requested; pytest injects fixtures by name, so nothing else changes.
    doc_other_lang = Document("Meu nome é Matteo e moro em Roma")
    with pytest.raises(ValueError, match="is not in the list of languages to route"):
        doclangclassifier.run(documents=[doc_other_lang])
|  | 
 | ||
|  | 
 | ||
@pytest.mark.integration
@pytest.mark.parametrize("doclangclassifier", ["langdetect", "transformers"], indirect=True)
def test_doclangclassifier_run_batch(doclangclassifier):
    """``run_batch`` splits single-document lists across their expected edges."""
    docs = [[doc] for doc in DOCUMENTS]
    results, split_edge = doclangclassifier.run_batch(documents=docs)
    assert split_edge == "split"
    # Each input list holds one document in a different routed language, so
    # exactly one edge per input list is expected. Without this check an empty
    # or partial mapping would make the loop below pass vacuously.
    assert len(results) == len(docs)
    for edge, result in results.items():
        document = result["documents"][0][0]
        # Recover which sample this is so the matching expectations can be looked up.
        num_document = DOCUMENTS.index(document)
        expected_language = EXPECTED_LANGUAGES[num_document]
        assert edge == EXPECTED_OUTPUT_EDGES[num_document]
        assert document.to_dict()["meta"]["language"] == expected_language