diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md index de88f5a1c..1d0f963da 100644 --- a/docs/_src/api/api/pipelines.md +++ b/docs/_src/api/api/pipelines.md @@ -5,7 +5,7 @@ ## Pipeline Objects ```python -class Pipeline() +class Pipeline(ABC) ``` Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components. @@ -131,7 +131,7 @@ variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an ## BaseStandardPipeline Objects ```python -class BaseStandardPipeline() +class BaseStandardPipeline(ABC) ``` @@ -316,6 +316,32 @@ Initialize a Pipeline for finding similar FAQs using semantic document search. - `retriever`: Retriever instance + +## TranslationWrapperPipeline Objects + +```python +class TranslationWrapperPipeline(BaseStandardPipeline) +``` + +Takes an existing search pipeline and adds one "input translation node" after the Query and one +"output translation" node just before returning the results + + +#### \_\_init\_\_ + +```python + | __init__(input_translator: BaseTranslator, output_translator: BaseTranslator, pipeline: BaseStandardPipeline) +``` + +Wrap a given `pipeline` with the `input_translator` and `output_translator`. + +**Arguments**: + +- `input_translator`: A Translator node that shall translate the input query from language A to B +- `output_translator`: A Translator node that shall translate the pipeline results from language B to A +- `pipeline`: The pipeline object (e.g. ExtractiveQAPipeline) you want to "wrap". +Note that pipelines with split or merge nodes are currently not supported. + ## JoinDocuments Objects diff --git a/docs/_src/usage/usage/translator.md b/docs/_src/usage/usage/translator.md new file mode 100644 index 000000000..d7b71dadf --- /dev/null +++ b/docs/_src/usage/usage/translator.md @@ -0,0 +1,58 @@ + + +# Translator + +Texts come in different languages. This is not different for search and there are plenty of options to deal with it. +One of them is actually to translate the incoming query, the documents or the search results. + +Let's imagine you have an English corpus of technical docs, but the mother tongue of many of your users is French. +You can use a Translator node in your pipeline to +1. Translate the incoming query from French to English +2. Search in your English corpus for the right document / answer +3. Translate the results back from English to French + +
+ +**Example (Stand-alone Translator)** + +You can use the Translator component directly to translate your query or document(s): +```python +DOCS = [ + Document( + text="""Heinz von Foerster was an Austrian American scientist + combining physics and philosophy, and widely attributed + as the originator of Second-order cybernetics.""" + ) + ] +translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr") +res = translator.translate(documents=DOCS, query=None) +``` + +**Example (Wrapping another Pipeline)** + +You can also wrap one of your existing pipelines and "add" the translation nodes at the beginning and at the end of your pipeline. +For example, lets translate the incoming query to from French to English, then do our document retrieval and then translate the results back from English to French: + +```python +from haystack.pipeline import TranslationWrapperPipeline, DocumentSearchPipeline +from haystack.translator import TransformersTranslator + +pipeline = DocumentSearchPipeline(retriever=my_dpr_retriever) + +in_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-fr-en") +out_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr") + +pipeline_with_translation = TranslationWrapperPipeline(input_translator=in_translator, + output_translator=out_translator, + pipeline=pipeline) +``` + + +
diff --git a/haystack/pipeline.py b/haystack/pipeline.py index f05eb7a6c..a45394669 100644 --- a/haystack/pipeline.py +++ b/haystack/pipeline.py @@ -1,3 +1,4 @@ +from abc import ABC import os from copy import deepcopy from pathlib import Path @@ -13,9 +14,10 @@ from haystack.generator.base import BaseGenerator from haystack.reader.base import BaseReader from haystack.retriever.base import BaseRetriever from haystack.summarizer.base import BaseSummarizer +from haystack.translator.base import BaseTranslator -class Pipeline: +class Pipeline(ABC): """ Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components. @@ -45,7 +47,7 @@ class Pipeline: In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output must be specified explicitly as "QueryClassifier.output_2". """ - self.graph.add_node(name, component=component) + self.graph.add_node(name, component=component, inputs=inputs) for i in inputs: if "." in i: @@ -93,7 +95,7 @@ class Pipeline: while has_next_node: output_dict, stream_id = self.graph.nodes[current_node_id]["component"].run(**input_dict) input_dict = output_dict - next_nodes = self._get_next_nodes(current_node_id, stream_id) + next_nodes = self.get_next_nodes(current_node_id, stream_id) if len(next_nodes) > 1: join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0] @@ -114,7 +116,7 @@ class Pipeline: return output_dict - def _get_next_nodes(self, node_id: str, stream_id: str): + def get_next_nodes(self, node_id: str, stream_id: str): current_node_edges = self.graph.edges(node_id, data=True) next_nodes = [ next_node @@ -259,7 +261,7 @@ class Pipeline: definition["params"][param_name] = value -class BaseStandardPipeline: +class BaseStandardPipeline(ABC): pipeline: Pipeline def add_node(self, component, name: str, inputs: List[str]): @@ -451,6 +453,52 @@ class FAQPipeline(BaseStandardPipeline): return results +class TranslationWrapperPipeline(BaseStandardPipeline): + + """ + Takes an existing search pipeline and adds one "input translation node" after the Query and one + "output translation" node just before returning the results + """ + + def __init__( + self, + input_translator: BaseTranslator, + output_translator: BaseTranslator, + pipeline: BaseStandardPipeline + ): + """ + Wrap a given `pipeline` with the `input_translator` and `output_translator`. + + :param input_translator: A Translator node that shall translate the input query from language A to B + :param output_translator: A Translator node that shall translate the pipeline results from language B to A + :param pipeline: The pipeline object (e.g. ExtractiveQAPipeline) you want to "wrap". + Note that pipelines with split or merge nodes are currently not supported. + """ + + self.pipeline = Pipeline() + self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"]) + + graph = pipeline.pipeline.graph + previous_node_name = ["InputTranslator"] + # Traverse in BFS + for node in graph.nodes: + if node == "Query": + continue + + # TODO: Do not work properly for Join Node and Answer format + if graph.nodes[node]["inputs"] and len(graph.nodes[node]["inputs"]) > 1: + raise AttributeError("Split and merge nodes are not supported currently") + + self.pipeline.add_node(name=node, component=graph.nodes[node]["component"], inputs=previous_node_name) + previous_node_name = [node] + + self.pipeline.add_node(component=output_translator, name="OutputTranslator", inputs=previous_node_name) + + def run(self, **kwargs): + output = self.pipeline.run(**kwargs) + return output + + class QueryNode: outgoing_edges = 1 diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py index 90a619481..71674bd71 100644 --- a/haystack/retriever/dense.py +++ b/haystack/retriever/dense.py @@ -6,8 +6,6 @@ from pathlib import Path from tqdm import tqdm from haystack.document_store.base import BaseDocumentStore -from haystack.document_store.elasticsearch import ElasticsearchDocumentStore -from haystack.document_store.memory import InMemoryDocumentStore from haystack import Document from haystack.retriever.base import BaseRetriever diff --git a/haystack/summarizer/transformers.py b/haystack/summarizer/transformers.py index 8f6f9d986..6a548c24d 100644 --- a/haystack/summarizer/transformers.py +++ b/haystack/summarizer/transformers.py @@ -1,9 +1,8 @@ import logging -from typing import Any, Dict, List, Optional +from typing import List, Optional from transformers import pipeline from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM -from transformers import AutoTokenizer from haystack import Document from haystack.summarizer.base import BaseSummarizer diff --git a/haystack/translator/__init__.py b/haystack/translator/__init__.py new file mode 100644 index 000000000..c4c6ee277 --- /dev/null +++ b/haystack/translator/__init__.py @@ -0,0 +1 @@ +from haystack.translator.transformers import TransformersTranslator diff --git a/haystack/translator/base.py b/haystack/translator/base.py new file mode 100644 index 000000000..d6c75efce --- /dev/null +++ b/haystack/translator/base.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Mapping, Optional, Union + +from haystack import Document + + +class BaseTranslator(ABC): + """ + Abstract class for a Translator component that translates either a query or a doc from language A to language B. + """ + + outgoing_edges = 1 + + @abstractmethod + def translate( + self, + query: Optional[str] = None, + documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None, + dict_key: Optional[str] = None, + **kwargs + ) -> Union[str, List[Document], List[str], List[Dict[str, Any]]]: + """ + Translate the passed query or a list of documents from language A to B. + """ + pass + + def run( + self, + query: Optional[str] = None, + documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None, + answers: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + dict_key: Optional[str] = None, + **kwargs + ): + """Method that gets executed when this class is used as a Node in a Haystack Pipeline""" + + results: Dict = { + **kwargs + } + + # This will cover input query stage + if query: + results["query"] = self.translate(query=query) + # This will cover retriever and summarizer + if documents: + dict_key = dict_key or "text" + results["documents"] = self.translate(documents=documents, dict_key=dict_key) + + if answers: + dict_key = dict_key or "answer" + if isinstance(answers, Mapping): + # This will cover reader + results["answers"] = self.translate(documents=answers["answers"], dict_key=dict_key) + else: + # This will cover generator + results["answers"] = self.translate(documents=answers, dict_key=dict_key) + + return results, "output_1" \ No newline at end of file diff --git a/haystack/translator/transformers.py b/haystack/translator/transformers.py new file mode 100644 index 000000000..de90a760c --- /dev/null +++ b/haystack/translator/transformers.py @@ -0,0 +1,127 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + +from haystack import Document +from haystack.translator.base import BaseTranslator + +logger = logging.getLogger(__name__) + + +class TransformersTranslator(BaseTranslator): + """ + Translator component based on Seq2Seq models from Huggingface's transformers library. + Exemplary use cases: + - Translate a query from Language A to B (e.g. if you only have good models + documents in language B) + - Translate a document from Language A to B (e.g. if you want to return results in the native language of the user) + + We currently recommend using OPUS models (see __init__() for details) + + **Example:** + + ```python + | DOCS = [ + | Document(text="Heinz von Foerster was an Austrian American scientist combining physics and philosophy, + | and widely attributed as the originator of Second-order cybernetics.") + | ] + | translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-de") + | res = translator.translate(documents=DOCS, query=None) + ``` + """ + def __init__( + self, + model_name_or_path: str, + tokenizer_name: Optional[str] = None, + max_seq_len: Optional[int] = None, + clean_up_tokenization_spaces: Optional[bool] = True + ): + """ Initialize the translator with a model that fits your targeted languages. While we support all seq2seq + models from Hugging Face's model hub, we recommend using the OPUS models from Helsiniki NLP. They provide plenty + of different models, usually one model per language pair and translation direction. + They have a pretty standardized naming that should help you find the right model: + - "Helsinki-NLP/opus-mt-en-de" => translating from English to German + - "Helsinki-NLP/opus-mt-de-en" => translating from German to English + - "Helsinki-NLP/opus-mt-fr-en" => translating from French to English + - "Helsinki-NLP/opus-mt-hi-en"=> translating from Hindi to English + ... + + They also have a few multilingual models that support multiple languages at once. + + :param model_name_or_path: Name of the seq2seq model that shall be used for translation. + Can be a remote name from Huggingface's modelhub or a local path. + :param tokenizer_name: Optional tokenizer name. If not supplied, `model_name_or_path` will also be used for the + tokenizer. + :param max_seq_len: The maximum sentence length the model accepts. (Optional) + :param clean_up_tokenization_spaces: Whether or not to clean up the tokenization spaces. (default True) + """ + + self.max_seq_len = max_seq_len + self.clean_up_tokenization_spaces = clean_up_tokenization_spaces + tokenizer_name = tokenizer_name or model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name + ) + self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) + + def translate( + self, + query: Optional[str] = None, + documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None, + dict_key: Optional[str] = None, + **kwargs + ) -> Union[str, List[Document], List[str], List[Dict[str, Any]]]: + """ + Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated. + """ + if not query and not documents: + raise AttributeError("Translator need query or documents to perform translation") + + if query and documents: + raise AttributeError("Translator need either query or documents but not both") + + if documents and len(documents) == 0: + logger.warning("Empty documents list is passed") + return documents + + dict_key = dict_key or "text" + + if isinstance(documents, list): + if isinstance(documents[0], Document): + text_for_translator = [doc.text for doc in documents] # type: ignore + elif isinstance(documents[0], str): + text_for_translator = documents # type: ignore + else: + if not isinstance(documents[0].get(dict_key, None), str): # type: ignore + raise AttributeError(f"Dictionary should have {dict_key} key and it's value should be `str` type") + text_for_translator = [doc[dict_key] for doc in documents] # type: ignore + else: + text_for_translator: List[str] = [query] # type: ignore + + batch = self.tokenizer.prepare_seq2seq_batch( + src_texts=text_for_translator, + return_tensors="pt", + max_length=self.max_seq_len + ) + generated_output = self.model.generate(**batch) + translated_texts = self.tokenizer.batch_decode( + generated_output, + skip_special_tokens=True, + clean_up_tokenization_spaces=self.clean_up_tokenization_spaces + ) + + if query: + return translated_texts[0] + elif documents: + if isinstance(documents, list) and isinstance(documents[0], str): + return [translated_text for translated_text in translated_texts] + + for translated_text, doc in zip(translated_texts, documents): + if isinstance(doc, Document): + doc.text = translated_text + else: + doc[dict_key] = translated_text # type: ignore + + return documents + + raise AttributeError("Translator need query or documents to perform translation") diff --git a/test/conftest.py b/test/conftest.py index 8aefa3945..89fb0d0a8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -23,6 +23,7 @@ from haystack.document_store.sql import SQLDocumentStore from haystack.reader.farm import FARMReader from haystack.reader.transformers import TransformersReader from haystack.summarizer.transformers import TransformersSummarizer +from haystack.translator import TransformersTranslator def _sql_session_rollback(self, attr): @@ -161,6 +162,20 @@ def summarizer(): ) +@pytest.fixture(scope="module") +def en_to_de_translator(): + return TransformersTranslator( + model_name_or_path="Helsinki-NLP/opus-mt-en-de", + ) + + +@pytest.fixture(scope="module") +def de_to_en_translator(): + return TransformersTranslator( + model_name_or_path="Helsinki-NLP/opus-mt-de-en", + ) + + @pytest.fixture(scope="module") def test_docs_xs(): return [ diff --git a/test/test_generator.py b/test/test_generator.py index bf8a6083a..e4715ead6 100644 --- a/test/test_generator.py +++ b/test/test_generator.py @@ -2,7 +2,7 @@ import numpy as np import pytest from haystack import Document -from haystack.pipeline import GenerativeQAPipeline +from haystack.pipeline import TranslationWrapperPipeline, GenerativeQAPipeline DOCS_WITH_EMBEDDINGS = [ Document( @@ -426,3 +426,33 @@ def test_generator_pipeline(document_store, retriever, rag_generator): answers = output["answers"] assert len(answers) == 2 assert "berlin" in answers[0]["answer"] + + +# Keeping few (retriever,document_store) combination to reduce test time +@pytest.mark.slow +@pytest.mark.generator +@pytest.mark.elasticsearch +@pytest.mark.parametrize( + "retriever,document_store", + [("embedding", "memory"), ("elasticsearch", "elasticsearch")], + indirect=True, +) +def test_generator_pipeline_with_translator( + document_store, + retriever, + rag_generator, + en_to_de_translator, + de_to_en_translator +): + document_store.write_documents(DOCS_WITH_EMBEDDINGS) + query = "Was ist die Hauptstadt der Bundesrepublik Deutschland?" + base_pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator) + pipeline = TranslationWrapperPipeline( + input_translator=de_to_en_translator, + output_translator=en_to_de_translator, + pipeline=base_pipeline + ) + output = pipeline.run(query=query, top_k_generator=2, top_k_retriever=1) + answers = output["answers"] + assert len(answers) == 2 + assert "berlin" in answers[0]["answer"] diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 7cad05d9f..68abaf910 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -3,7 +3,8 @@ from pathlib import Path import pytest from haystack.document_store.elasticsearch import ElasticsearchDocumentStore -from haystack.pipeline import JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline +from haystack.pipeline import TranslationWrapperPipeline, JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, \ + DocumentSearchPipeline from haystack.retriever.dense import DensePassageRetriever from haystack.retriever.sparse import ElasticsearchRetriever @@ -137,6 +138,27 @@ def test_document_search_pipeline(retriever, document_store): assert len(output["documents"]) == 1 +@pytest.mark.slow +@pytest.mark.elasticsearch +@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True) +def test_extractive_qa_answers_with_translator(reader, retriever_with_docs, en_to_de_translator, de_to_en_translator): + base_pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs) + pipeline = TranslationWrapperPipeline( + input_translator=de_to_en_translator, + output_translator=en_to_de_translator, + pipeline=base_pipeline + ) + + prediction = pipeline.run(query="Wer lebt in Berlin?", top_k_retriever=10, top_k_reader=3) + assert prediction is not None + assert prediction["query"] == "Wer lebt in Berlin?" + assert "Carla" in prediction["answers"][0]["answer"] + assert prediction["answers"][0]["probability"] <= 1 + assert prediction["answers"][0]["probability"] >= 0 + assert prediction["answers"][0]["meta"]["meta_field"] == "test1" + assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin" + + @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) @pytest.mark.parametrize("reader", ["farm"], indirect=True) def test_join_document_pipeline(document_store_with_docs, reader): diff --git a/test/test_summarizer.py b/test/test_summarizer.py index 0f9a6527e..17747cc1a 100644 --- a/test/test_summarizer.py +++ b/test/test_summarizer.py @@ -1,7 +1,7 @@ import pytest from haystack import Document -from haystack.pipeline import SearchSummarizationPipeline +from haystack.pipeline import TranslationWrapperPipeline, SearchSummarizationPipeline from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever DOCS = [ @@ -94,3 +94,41 @@ def test_summarization_pipeline_one_summary(document_store, retriever, summarize answers = output["answers"] assert len(answers) == 1 assert answers[0]["answer"] in EXPECTED_ONE_SUMMARIES + + +# Keeping few (retriever,document_store) combination to reduce test time +@pytest.mark.slow +@pytest.mark.elasticsearch +@pytest.mark.summarizer +@pytest.mark.parametrize( + "retriever,document_store", + [("embedding", "memory"), ("elasticsearch", "elasticsearch")], + indirect=True, +) +def test_summarization_pipeline_with_translator( + document_store, + retriever, + summarizer, + en_to_de_translator, + de_to_en_translator +): + document_store.write_documents(SPLIT_DOCS) + + if isinstance(retriever, EmbeddingRetriever) or isinstance(retriever, DensePassageRetriever): + document_store.update_embeddings(retriever=retriever) + + query = "Wo steht der Eiffelturm?" + base_pipeline = SearchSummarizationPipeline(retriever=retriever, summarizer=summarizer) + pipeline = TranslationWrapperPipeline( + input_translator=de_to_en_translator, + output_translator=en_to_de_translator, + pipeline=base_pipeline + ) + output = pipeline.run(query=query, top_k_retriever=2, generate_single_summary=True) + # SearchSummarizationPipeline return answers but Summarizer return documents + documents = output["documents"] + assert len(documents) == 1 + assert documents[0].text in [ + "Der Eiffelturm ist ein Wahrzeichen in Paris, Frankreich.", + "Der Eiffelturm, der 1889 in Paris, Frankreich, erbaut wurde, ist das höchste freistehende Bauwerk der Welt." + ] diff --git a/test/test_translator.py b/test/test_translator.py new file mode 100644 index 000000000..803221f2b --- /dev/null +++ b/test/test_translator.py @@ -0,0 +1,46 @@ +from haystack import Document + +import pytest + +EXPECTED_OUTPUT = "Ich lebe in Berlin" +INPUT = "I live in Berlin" + + +def test_translator_with_query(en_to_de_translator): + assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT + + +def test_translator_with_list(en_to_de_translator): + assert en_to_de_translator.translate(documents=[INPUT])[0] == EXPECTED_OUTPUT + + +def test_translator_with_document(en_to_de_translator): + assert en_to_de_translator.translate(documents=[Document(text=INPUT)])[0].text == EXPECTED_OUTPUT + + +def test_translator_with_dictionary(en_to_de_translator): + assert en_to_de_translator.translate(documents=[{"text": INPUT}])[0]["text"] == EXPECTED_OUTPUT + + +def test_translator_with_dictionary_with_dict_key(en_to_de_translator): + assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT + + +def test_translator_with_empty_input(en_to_de_translator): + with pytest.raises(AttributeError): + en_to_de_translator.translate() + + +def test_translator_with_query_and_documents(en_to_de_translator): + with pytest.raises(AttributeError): + en_to_de_translator.translate(query=INPUT, documents=[INPUT]) + + +def test_translator_with_dict_without_text_key(en_to_de_translator): + with pytest.raises(AttributeError): + en_to_de_translator.translate(documents=[{"text1": INPUT}]) + + +def test_translator_with_dict_with_non_string_value(en_to_de_translator): + with pytest.raises(AttributeError): + en_to_de_translator.translate(documents=[{"text": 123}])