mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 01:39:45 +00:00 
			
		
		
		
	fix: Fix predict_batch in TransformersReader for single nested Document list (#3748)
				
					
				
			* Fix restoring of list structure * Add tests
This commit is contained in:
		
							parent
							
								
									136928714c
								
							
						
					
					
						commit
						594d2a10f8
					
				| @ -153,7 +153,7 @@ class TransformersReader(BaseReader): | |||||||
|             max_seq_len=self.max_seq_len, |             max_seq_len=self.max_seq_len, | ||||||
|             doc_stride=self.doc_stride, |             doc_stride=self.doc_stride, | ||||||
|         ) |         ) | ||||||
|         # Transformers gives different output dependiing on top_k_per_candidate and number of inputs |         # Transformers gives different output depending on top_k_per_candidate and number of inputs | ||||||
|         if isinstance(predictions, dict): |         if isinstance(predictions, dict): | ||||||
|             predictions = [[predictions]] |             predictions = [[predictions]] | ||||||
|         elif len(inputs) == 1: |         elif len(inputs) == 1: | ||||||
| @ -224,7 +224,7 @@ class TransformersReader(BaseReader): | |||||||
|         # Transformers flattens lists of length 1. This restores the original list structure. |         # Transformers flattens lists of length 1. This restores the original list structure. | ||||||
|         if isinstance(predictions, dict): |         if isinstance(predictions, dict): | ||||||
|             predictions = [[predictions]] |             predictions = [[predictions]] | ||||||
|         elif len(number_of_docs) == 1: |         elif len(inputs) == 1: | ||||||
|             predictions = [predictions] |             predictions = [predictions] | ||||||
|         else: |         else: | ||||||
|             predictions = [p if isinstance(p, list) else [p] for p in predictions] |             predictions = [p if isinstance(p, list) else [p] for p in predictions] | ||||||
|  | |||||||
| @ -107,6 +107,18 @@ def test_output_batch_multiple_queries_multiple_doc_lists(reader, docs): | |||||||
|     assert len(prediction["answers"][0]) == 5  # top-k of 5 for collection of docs |     assert len(prediction["answers"][0]) == 5  # top-k of 5 for collection of docs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_output_batch_single_query_single_nested_doc_list(reader, docs): | ||||||
|  |     prediction = reader.predict_batch(queries=["Who lives in Berlin?"], documents=[docs], top_k=5) | ||||||
|  |     assert prediction is not None | ||||||
|  |     assert prediction["queries"] == ["Who lives in Berlin?"] | ||||||
|  |     # Expected output: List of lists answers | ||||||
|  |     assert isinstance(prediction["answers"], list) | ||||||
|  |     assert isinstance(prediction["answers"][0], list) | ||||||
|  |     assert isinstance(prediction["answers"][0][0], Answer) | ||||||
|  |     assert len(prediction["answers"]) == 1  # Predictions for 1 collections of documents | ||||||
|  |     assert len(prediction["answers"][0]) == 5  # top-k of 5 for collection of docs | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @pytest.mark.integration | @pytest.mark.integration | ||||||
| def test_no_answer_output(no_answer_reader, docs): | def test_no_answer_output(no_answer_reader, docs): | ||||||
|     no_answer_prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5) |     no_answer_prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5) | ||||||
|  | |||||||
| @ -118,7 +118,7 @@ EVAL_LABELS = [ | |||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True) | @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True) | ||||||
| @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) | @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) | ||||||
| @pytest.mark.parametrize("reader", ["farm"], indirect=True) | @pytest.mark.parametrize("reader", ["farm", "transformers"], indirect=True) | ||||||
| def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path): | def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path): | ||||||
|     labels = EVAL_LABELS[:1] |     labels = EVAL_LABELS[:1] | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 bogdankostic
						bogdankostic