mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-25 14:59:01 +00:00 
			
		
		
		
	 dde9d59271
			
		
	
	
		dde9d59271
		
			
		
	
	
	
	
		
			
			* fix pip backtracking issue * restrict azure-core version * Remove the trailing comma * Add skip_magic_trailing_comma in pyproject.toml for pydoc compatibility * Pin pydoc-markdown _again_ Co-authored-by: Sara Zan <sarazanzo94@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
		
			
				
	
	
		
			96 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			96 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import pytest
 | |
| 
 | |
| from haystack.schema import Document
 | |
| from haystack.nodes.document_classifier.base import BaseDocumentClassifier
 | |
| 
 | |
| 
 | |
@pytest.mark.slow
def test_document_classifier(document_classifier):
    """Check that the `document_classifier` fixture attaches the expected label to each document.

    The first document's text is repeated 700 times to exercise truncation of long input.
    The second document is a short sentence.
    """
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    results = document_classifier.predict(documents=docs)
    expected_labels = ["joy", "sadness"]
    # Compare the whole label list rather than looping with an index. A per-item loop
    # passes silently when `predict` returns fewer documents than expected, or none at all.
    labels = [doc.to_dict()["meta"]["classification"]["label"] for doc in results]
    assert labels == expected_labels
 | |
| 
 | |
| 
 | |
@pytest.mark.slow
def test_zero_shot_document_classifier(zero_shot_document_classifier):
    """Check that the `zero_shot_document_classifier` fixture labels documents positive/negative.

    The first document's text is repeated 700 times to exercise truncation of long input.
    The second document is a short sentence.
    """
    assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    results = zero_shot_document_classifier.predict(documents=docs)
    expected_labels = ["positive", "negative"]
    # Compare the whole label list rather than looping with an index. A per-item loop
    # passes silently when `predict` returns fewer documents than expected, or none at all.
    labels = [doc.to_dict()["meta"]["classification"]["label"] for doc in results]
    assert labels == expected_labels
 | |
| 
 | |
| 
 | |
@pytest.mark.slow
def test_document_classifier_batch_size(batched_document_classifier):
    """Check that the `batched_document_classifier` fixture produces the same labels as the plain one.

    The first document's text is repeated 700 times to exercise truncation of long input.
    The second document is a short sentence.
    """
    assert isinstance(batched_document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    results = batched_document_classifier.predict(documents=docs)
    expected_labels = ["joy", "sadness"]
    # Compare the whole label list rather than looping with an index. A per-item loop
    # passes silently when `predict` returns fewer documents than expected, or none at all.
    labels = [doc.to_dict()["meta"]["classification"]["label"] for doc in results]
    assert labels == expected_labels
 | |
| 
 | |
| 
 | |
@pytest.mark.slow
def test_document_classifier_as_index_node(indexing_document_classifier):
    """Check the classifier's `run` method as an indexing-pipeline node (`root_node="File"`).

    The documents are plain dicts. Each dict's `class_field` sentiment is the opposite of its
    `content` sentiment. The expected labels follow `class_field`, which presumably means the
    fixture is configured to classify that field (TODO confirm against the fixture definition).
    """
    assert isinstance(indexing_document_classifier, BaseDocumentClassifier)

    docs = [
        {
            "content": """That's good. I like it.""" * 700,  # extra long text to check truncation
            "meta": {"name": "0"},
            "id": "1",
            "class_field": "That's bad.",
        },
        {"content": """That's bad. I like it.""", "meta": {"name": "1"}, "id": "2", "class_field": "That's good."},
    ]
    # The second element of the run() result (the output edge name) is not checked here.
    output, _ = indexing_document_classifier.run(documents=docs, root_node="File")
    expected_labels = ["sadness", "joy"]
    # Compare the whole label list rather than looping with an index. A per-item loop
    # passes silently when fewer documents come back than expected, or none at all.
    labels = [doc["meta"]["classification"]["label"] for doc in output["documents"]]
    assert labels == expected_labels
 | |
| 
 | |
| 
 | |
@pytest.mark.slow
def test_document_classifier_as_query_node(document_classifier):
    """Check the classifier's `run` method as a query-pipeline node (`root_node="Query"`).

    The first document's text is repeated 700 times to exercise truncation of long input.
    The second document is a short sentence.
    """
    assert isinstance(document_classifier, BaseDocumentClassifier)

    docs = [
        Document(
            content="""That's good. I like it.""" * 700,  # extra long text to check truncation
            meta={"name": "0"},
            id="1",
        ),
        Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
    ]
    # The second element of the run() result (the output edge name) is not checked here.
    output, _ = document_classifier.run(documents=docs, root_node="Query")
    expected_labels = ["joy", "sadness"]
    # Compare the whole label list rather than looping with an index. A per-item loop
    # passes silently when fewer documents come back than expected, or none at all.
    labels = [doc.to_dict()["meta"]["classification"]["label"] for doc in output["documents"]]
    assert labels == expected_labels
 |