| 
									
										
										
										
											2021-10-29 13:52:28 +05:30
										 |  |  | import sys | 
					
						
							| 
									
										
										
										
											2021-10-19 15:22:44 +02:00
										 |  |  | from typing import List | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-03 10:27:06 +01:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-10-30 18:06:02 +01:00
										 |  |  | import pytest | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-25 15:50:23 +02:00
										 |  |  | from haystack.schema import Document | 
					
						
							|  |  |  | from haystack.nodes.answer_generator import Seq2SeqGenerator | 
					
						
							|  |  |  | from haystack.pipelines import TranslationWrapperPipeline, GenerativeQAPipeline | 
					
						
							| 
									
										
										
										
											2020-10-30 18:06:02 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-19 15:22:44 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-03-15 11:17:26 +01:00
										 |  |  | from .conftest import DOCS_WITH_EMBEDDINGS | 
					
						
							| 
									
										
										
										
											2020-10-30 18:06:02 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-22 17:20:23 +01:00
										 |  |  | # Keeping few (retriever,document_store) combination to reduce test time | 
					
						
							| 
									
										
										
										
											2022-02-09 21:29:05 +01:00
										 |  |  | @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Causes OOM on windows github runner") | 
					
						
							| 
									
										
										
										
											2021-12-22 17:20:23 +01:00
										 |  |  | @pytest.mark.slow | 
					
						
							|  |  |  | @pytest.mark.generator | 
					
						
							| 
									
										
										
										
											2022-03-07 19:25:33 +01:00
										 |  |  | @pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True) | 
					
						
							| 
									
										
										
										
											2021-12-22 17:20:23 +01:00
										 |  |  | def test_generator_pipeline_with_translator( | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     document_store, retriever, rag_generator, en_to_de_translator, de_to_en_translator | 
					
						
							| 
									
										
										
										
											2021-12-22 17:20:23 +01:00
										 |  |  | ): | 
					
						
							|  |  |  |     document_store.write_documents(DOCS_WITH_EMBEDDINGS) | 
					
						
							|  |  |  |     query = "Was ist die Hauptstadt der Bundesrepublik Deutschland?" | 
					
						
							|  |  |  |     base_pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator) | 
					
						
							|  |  |  |     pipeline = TranslationWrapperPipeline( | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         input_translator=de_to_en_translator, output_translator=en_to_de_translator, pipeline=base_pipeline | 
					
						
							| 
									
										
										
										
											2021-12-22 17:20:23 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  |     output = pipeline.run(query=query, params={"Generator": {"top_k": 2}, "Retriever": {"top_k": 1}}) | 
					
						
							|  |  |  |     answers = output["answers"] | 
					
						
							|  |  |  |     assert len(answers) == 2 | 
					
						
							|  |  |  |     assert "berlin" in answers[0].answer | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-30 18:06:02 +01:00
										 |  |  | @pytest.mark.slow | 
					
						
							|  |  |  | @pytest.mark.generator | 
					
						
							|  |  |  | def test_rag_token_generator(rag_generator): | 
					
						
							| 
									
										
										
										
											2020-12-03 10:27:06 +01:00
										 |  |  |     query = "What is capital of the Germany?" | 
					
						
							|  |  |  |     generated_docs = rag_generator.predict(query=query, documents=DOCS_WITH_EMBEDDINGS, top_k=1) | 
					
						
							| 
									
										
										
										
											2020-10-30 18:06:02 +01:00
										 |  |  |     answers = generated_docs["answers"] | 
					
						
							|  |  |  |     assert len(answers) == 1 | 
					
						
							| 
									
										
										
										
											2021-11-12 16:44:28 +01:00
										 |  |  |     assert "berlin" in answers[0].answer | 
					
						
							| 
									
										
										
										
											2020-12-03 10:27:06 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.slow | 
					
						
							|  |  |  | @pytest.mark.generator | 
					
						
							| 
									
										
										
										
											2021-10-15 15:37:46 +02:00
										 |  |  | @pytest.mark.parametrize("document_store", ["memory"], indirect=True) | 
					
						
							| 
									
										
										
										
											2021-09-27 10:52:07 +02:00
										 |  |  | @pytest.mark.parametrize("retriever", ["embedding"], indirect=True) | 
					
						
							| 
									
										
										
										
											2020-12-03 10:27:06 +01:00
										 |  |  | def test_generator_pipeline(document_store, retriever, rag_generator): | 
					
						
							|  |  |  |     document_store.write_documents(DOCS_WITH_EMBEDDINGS) | 
					
						
							|  |  |  |     query = "What is capital of the Germany?" | 
					
						
							|  |  |  |     pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator) | 
					
						
							| 
									
										
										
										
											2021-09-10 11:41:16 +02:00
										 |  |  |     output = pipeline.run(query=query, params={"Generator": {"top_k": 2}, "Retriever": {"top_k": 1}}) | 
					
						
							| 
									
										
										
										
											2020-12-03 10:27:06 +01:00
										 |  |  |     answers = output["answers"] | 
					
						
							|  |  |  |     assert len(answers) == 2 | 
					
						
							| 
									
										
										
										
											2021-11-12 16:44:28 +01:00
										 |  |  |     assert "berlin" in answers[0].answer | 
					
						
							| 
									
										
										
										
											2021-02-12 15:58:26 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-09 21:29:05 +01:00
										 |  |  | @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Causes OOM on windows github runner") | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  | @pytest.mark.slow | 
					
						
							|  |  |  | @pytest.mark.generator | 
					
						
							| 
									
										
										
										
											2021-10-15 15:37:46 +02:00
										 |  |  | @pytest.mark.parametrize("document_store", ["memory"], indirect=True) | 
					
						
							| 
									
										
										
										
											2022-03-08 15:11:41 +01:00
										 |  |  | @pytest.mark.parametrize("retriever", ["retribert", "dpr_lfqa"], indirect=True) | 
					
						
							|  |  |  | @pytest.mark.parametrize("lfqa_generator", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True) | 
					
						
							| 
									
										
										
										
											2022-01-10 17:10:32 +00:00
										 |  |  | @pytest.mark.embedding_dim(128) | 
					
						
							| 
									
										
										
										
											2022-03-08 15:11:41 +01:00
										 |  |  | def test_lfqa_pipeline(document_store, retriever, lfqa_generator): | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  |     # reuse existing DOCS but regenerate embeddings with retribert | 
					
						
							|  |  |  |     docs: List[Document] = [] | 
					
						
							|  |  |  |     for idx, d in enumerate(DOCS_WITH_EMBEDDINGS): | 
					
						
							| 
									
										
										
										
											2021-10-13 14:23:23 +02:00
										 |  |  |         docs.append(Document(d.content, str(idx))) | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  |     document_store.write_documents(docs) | 
					
						
							|  |  |  |     document_store.update_embeddings(retriever) | 
					
						
							|  |  |  |     query = "Tell me about Berlin?" | 
					
						
							| 
									
										
										
										
											2022-03-08 15:11:41 +01:00
										 |  |  |     pipeline = GenerativeQAPipeline(generator=lfqa_generator, retriever=retriever) | 
					
						
							| 
									
										
										
										
											2021-09-10 11:41:16 +02:00
										 |  |  |     output = pipeline.run(query=query, params={"top_k": 1}) | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  |     answers = output["answers"] | 
					
						
							| 
									
										
										
										
											2022-03-08 15:11:41 +01:00
										 |  |  |     assert len(answers) == 1, answers | 
					
						
							|  |  |  |     assert "Germany" in answers[0].answer, answers[0].answer | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.slow | 
					
						
							|  |  |  | @pytest.mark.generator | 
					
						
							|  |  |  | @pytest.mark.parametrize("document_store", ["memory"], indirect=True) | 
					
						
							|  |  |  | @pytest.mark.parametrize("retriever", ["retribert"], indirect=True) | 
					
						
							| 
									
										
										
										
											2022-01-10 17:10:32 +00:00
										 |  |  | @pytest.mark.embedding_dim(128) | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  | def test_lfqa_pipeline_unknown_converter(document_store, retriever): | 
					
						
							|  |  |  |     # reuse existing DOCS but regenerate embeddings with retribert | 
					
						
							|  |  |  |     docs: List[Document] = [] | 
					
						
							|  |  |  |     for idx, d in enumerate(DOCS_WITH_EMBEDDINGS): | 
					
						
							| 
									
										
										
										
											2021-10-13 14:23:23 +02:00
										 |  |  |         docs.append(Document(d.content, str(idx))) | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  |     document_store.write_documents(docs) | 
					
						
							|  |  |  |     document_store.update_embeddings(retriever) | 
					
						
							|  |  |  |     seq2seq = Seq2SeqGenerator(model_name_or_path="patrickvonplaten/t5-tiny-random") | 
					
						
							|  |  |  |     query = "Tell me about Berlin?" | 
					
						
							|  |  |  |     pipeline = GenerativeQAPipeline(retriever=retriever, generator=seq2seq) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # raises exception as we don't have converter for "patrickvonplaten/t5-tiny-random" in Seq2SeqGenerator | 
					
						
							| 
									
										
										
										
											2021-10-20 17:57:15 +02:00
										 |  |  |     with pytest.raises(Exception) as exception_info: | 
					
						
							| 
									
										
										
										
											2021-09-10 11:41:16 +02:00
										 |  |  |         output = pipeline.run(query=query, params={"top_k": 1}) | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     assert "doesn't have input converter registered for patrickvonplaten/t5-tiny-random" in str(exception_info.value) | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.slow | 
					
						
							|  |  |  | @pytest.mark.generator | 
					
						
							|  |  |  | @pytest.mark.parametrize("document_store", ["memory"], indirect=True) | 
					
						
							|  |  |  | @pytest.mark.parametrize("retriever", ["retribert"], indirect=True) | 
					
						
							| 
									
										
										
										
											2022-01-10 17:10:32 +00:00
										 |  |  | @pytest.mark.embedding_dim(128) | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  | def test_lfqa_pipeline_invalid_converter(document_store, retriever): | 
					
						
							|  |  |  |     # reuse existing DOCS but regenerate embeddings with retribert | 
					
						
							|  |  |  |     docs: List[Document] = [] | 
					
						
							|  |  |  |     for idx, d in enumerate(DOCS_WITH_EMBEDDINGS): | 
					
						
							| 
									
										
										
										
											2021-10-13 14:23:23 +02:00
										 |  |  |         docs.append(Document(d.content, str(idx))) | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  |     document_store.write_documents(docs) | 
					
						
							|  |  |  |     document_store.update_embeddings(retriever) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class _InvalidConverter: | 
					
						
							|  |  |  |         def __call__(self, some_invalid_para: str, another_invalid_param: str) -> None: | 
					
						
							|  |  |  |             pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     seq2seq = Seq2SeqGenerator( | 
					
						
							|  |  |  |         model_name_or_path="patrickvonplaten/t5-tiny-random", input_converter=_InvalidConverter() | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-06-14 17:53:43 +02:00
										 |  |  |     query = "This query will fail due to InvalidConverter used" | 
					
						
							|  |  |  |     pipeline = GenerativeQAPipeline(retriever=retriever, generator=seq2seq) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # raises exception as we are using invalid method signature in _InvalidConverter | 
					
						
							| 
									
										
										
										
											2021-10-20 17:57:15 +02:00
										 |  |  |     with pytest.raises(Exception) as exception_info: | 
					
						
							| 
									
										
										
										
											2021-09-10 11:41:16 +02:00
										 |  |  |         output = pipeline.run(query=query, params={"top_k": 1}) | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     assert "does not have a valid __call__ method signature" in str(exception_info.value) |