From 5e62e54875551b5f8eace8f52255fc7720512c0e Mon Sep 17 00:00:00 2001 From: Tanay Soni Date: Mon, 30 Nov 2020 17:50:04 +0100 Subject: [PATCH] Rename question parameter to query (#614) --- haystack/document_store/elasticsearch.py | 6 ++-- haystack/finder.py | 23 ++++++++++--- haystack/pipeline.py | 10 +++--- haystack/reader/base.py | 9 +++-- haystack/reader/farm.py | 42 ++++++++++++------------ haystack/reader/transformers.py | 21 ++++++------ haystack/retriever/base.py | 12 +++---- haystack/retriever/sparse.py | 4 +-- haystack/utils.py | 6 ++-- test/conftest.py | 4 +-- test/test_finder.py | 2 +- test/test_pipeline.py | 8 ++--- test/test_reader.py | 10 +++--- 13 files changed, 85 insertions(+), 72 deletions(-) diff --git a/haystack/document_store/elasticsearch.py b/haystack/document_store/elasticsearch.py index 54fe65969..ab8fe7720 100644 --- a/haystack/document_store/elasticsearch.py +++ b/haystack/document_store/elasticsearch.py @@ -455,10 +455,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore): body["query"]["bool"]["filter"] = filter_clause # Retrieval via custom query - elif custom_query: # substitute placeholder for question and filters for the custom_query template string + elif custom_query: # substitute placeholder for query and filters for the custom_query template string template = Template(custom_query) - # replace all "${question}" placeholder(s) with query - substitutions = {"question": query} + # replace all "${query}" placeholder(s) with query + substitutions = {"query": query} # For each filter we got passed, we'll try to find & replace the corresponding placeholder in the template # Example: filters={"years":[2018]} => replaces {$years} in custom_query with '[2018]' if filters: diff --git a/haystack/finder.py b/haystack/finder.py index 31a62ca9e..84af6af7b 100644 --- a/haystack/finder.py +++ b/haystack/finder.py @@ -28,8 +28,14 @@ class Finder: :param reader: Reader instance :param retriever: Retriever instance """ - logger.warning("The 'Finder' class will be deprecated in the next Haystack release in favour of the new" - "`Pipeline` class.") + logger.warning( + """DEPRECATION WARNINGS: + 1. The 'Finder' class will be deprecated in the next Haystack release in + favour of a new `Pipeline` class that supports building custom search pipelines using Haystack components + including Retriever, Readers, and Generators. + For more details, please refer to the issue: https://github.com/deepset-ai/haystack/issues/544 + 2. The `question` parameter in search requests & results is renamed to `query`.""" + ) self.retriever = retriever self.reader = reader if self.reader is None and self.retriever is None: @@ -48,6 +54,14 @@ class Finder: :return: """ + logger.warning( + """DEPRECATION WARNINGS: + 1. The 'Finder' class will be deprecated in the next Haystack release in + favour of a new `Pipeline` class that supports building custom search pipelines using Haystack components + including Retriever, Readers, and Generators. + For more details, please refer to the issue: https://github.com/deepset-ai/haystack/issues/544 + 2. 
The `question` parameter in search requests & results is renamed to `query`.""" + ) if self.retriever is None or self.reader is None: raise AttributeError("Finder.get_answers requires self.retriever AND self.reader") @@ -65,9 +79,10 @@ class Finder: len_chars = sum([len(d.text) for d in documents]) logger.info(f"Reader is looking for detailed answer in {len_chars} chars ...") - results = self.reader.predict(question=question, + results = self.reader.predict(query=question, documents=documents, top_k=top_k_reader) # type: Dict[str, Any] + results["question"] = results["query"] # Add corresponding document_name and more meta data, if an answer contains the document_id for ans in results["answers"]: @@ -364,7 +379,7 @@ class Finder: self.reader.return_no_answers = True reader_start_time = time.time() predictions = self.reader.predict_batch(questions_with_correct_doc, - top_k_per_question=top_k_reader, batch_size=batch_size) + top_k=top_k_reader, batch_size=batch_size) reader_total_time = time.time() - reader_start_time for pred in predictions: diff --git a/haystack/pipeline.py b/haystack/pipeline.py index 0e68a2b71..b45b184bd 100644 --- a/haystack/pipeline.py +++ b/haystack/pipeline.py @@ -127,8 +127,8 @@ class ExtractiveQAPipeline: self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"]) self.pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"]) - def run(self, question, top_k_retriever=5, top_k_reader=5): - output = self.pipeline.run(question=question, + def run(self, query, top_k_retriever=5, top_k_reader=5): + output = self.pipeline.run(query=query, top_k_retriever=top_k_retriever, top_k_reader=top_k_reader) return output @@ -150,8 +150,8 @@ class DocumentSearchPipeline: self.pipeline = Pipeline() self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"]) - def run(self, question, top_k_retriever=5): - output = self.pipeline.run(question=question, top_k_retriever=top_k_retriever) + def run(self, query, top_k_retriever=5): + output = self.pipeline.run(query=query, top_k_retriever=top_k_retriever) document_dicts = [doc.to_dict() for doc in output["documents"]] output["documents"] = document_dicts return output @@ -183,7 +183,7 @@ class JoinDocuments: for i, _ in inputs: documents.extend(i["documents"]) output = { - "question": inputs[0][0]["question"], + "query": inputs[0][0]["query"], "documents": documents } return output, "output_1" diff --git a/haystack/reader/base.py b/haystack/reader/base.py index 1b2d54382..49fdb7132 100644 --- a/haystack/reader/base.py +++ b/haystack/reader/base.py @@ -12,12 +12,11 @@ class BaseReader(ABC): outgoing_edges = 1 @abstractmethod - def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None): + def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None): pass @abstractmethod - def predict_batch(self, question_doc_list: List[dict], top_k_per_question: Optional[int] = None, - batch_size: Optional[int] = None): + def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None): pass @staticmethod @@ -47,9 +46,9 @@ class BaseReader(ABC): "meta": None,} return no_ans_prediction, max_no_ans_gap - def run(self, question: str, documents: List[Document], top_k: Optional[int] = None): + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): if documents: - results = self.predict(question=question, documents=documents, top_k=top_k) + results = self.predict(query=query, 
documents=documents, top_k=top_k) else: results = {"answers": []} diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index 5be14a920..2956fd9f3 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -243,16 +243,16 @@ class FARMReader(BaseReader): self.inferencer.model.save(directory) self.inferencer.processor.save(directory) - def predict_batch(self, question_doc_list: List[dict], top_k_per_question: int = None, batch_size: int = None): + def predict_batch(self, query_doc_list: List[dict], top_k: int = None, batch_size: int = None): """ - Use loaded QA model to find answers for a list of questions in each question's supplied list of Document. + Use loaded QA model to find answers for a list of queries in each query's supplied list of Document. Returns list of dictionaries containing answers sorted by (desc.) probability - :param question_doc_list: List of dictionaries containing questions with their retrieved documents - :param top_k_per_question: The maximum number of answers to return for each question + :param query_doc_list: List of dictionaries containing queries with their retrieved documents + :param top_k: The maximum number of answers to return for each query :param batch_size: Number of samples the model receives in one batch for inference - :return: List of dictionaries containing question and answers + :return: List of dictionaries containing query and answers """ # convert input to FARM format @@ -261,20 +261,20 @@ class FARMReader(BaseReader): labels = [] # build input objects for inference_from_objects - for question_with_docs in question_doc_list: - documents = question_with_docs["docs"] - question = question_with_docs["question"] - labels.append(question) + for query_with_docs in query_doc_list: + documents = query_with_docs["docs"] + query = query_with_docs["question"] + labels.append(query) number_of_docs.append(len(documents)) for doc in documents: cur = QAInput(doc_text=doc.text, - questions=Question(text=question.question, + questions=Question(text=query.question, uid=doc.id)) inputs.append(cur) self.inferencer.batch_size = batch_size - # make predictions on all document-question pairs + # make predictions on all document-query pairs predictions = self.inferencer.inference_from_objects( objects=inputs, return_json=False, multiprocessing_chunksize=1 ) @@ -290,11 +290,11 @@ class FARMReader(BaseReader): result = [] for idx, group in enumerate(grouped_predictions): - answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k_per_question) - question = group[0].question + answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k) + query = group[0].question cur_label = labels[idx] result.append({ - "question": question, + "query": query, "no_ans_gap": max_no_ans_gap, "answers": answers, "label": cur_label @@ -302,15 +302,15 @@ class FARMReader(BaseReader): return result - def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None): + def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None): """ - Use loaded QA model to find answers for a question in the supplied list of Document. + Use loaded QA model to find answers for a query in the supplied list of Document. Returns dictionaries containing answers sorted by (desc.) probability. 
Example: ```python |{ - | 'question': 'Who is the father of Arya Stark?', + | 'query': 'Who is the father of Arya Stark?', | 'answers':[ | {'answer': 'Eddard,', | 'context': " She travels with her father, Eddard, to King's Landing when he is ", @@ -324,17 +324,17 @@ class FARMReader(BaseReader): |} ``` - :param question: Question string + :param query: Query string :param documents: List of Document in which to search for the answer :param top_k: The maximum number of answers to return - :return: Dict containing question and answers + :return: Dict containing query and answers """ # convert input to FARM format inputs = [] for doc in documents: cur = QAInput(doc_text=doc.text, - questions=Question(text=question, + questions=Question(text=query, uid=doc.id)) inputs.append(cur) @@ -345,7 +345,7 @@ class FARMReader(BaseReader): ) # assemble answers from all the different documents & format them. answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k) - result = {"question": question, + result = {"query": query, "no_ans_gap": max_no_ans_gap, "answers": answers} diff --git a/haystack/reader/transformers.py b/haystack/reader/transformers.py index 9548ec82c..61abcb03c 100644 --- a/haystack/reader/transformers.py +++ b/haystack/reader/transformers.py @@ -62,16 +62,16 @@ class TransformersReader(BaseReader): # TODO context_window_size behaviour different from behavior in FARMReader - def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None): + def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None): """ - Use loaded QA model to find answers for a question in the supplied list of Document. + Use loaded QA model to find answers for a query in the supplied list of Document. Returns dictionaries containing answers sorted by (desc.) probability. 
Example: - + ```python |{ - | 'question': 'Who is the father of Arya Stark?', + | 'query': 'Who is the father of Arya Stark?', | 'answers':[ | {'answer': 'Eddard,', | 'context': " She travels with her father, Eddard, to King's Landing when he is ", @@ -85,10 +85,10 @@ class TransformersReader(BaseReader): |} ``` - :param question: Question string + :param query: Query string :param documents: List of Document in which to search for the answer :param top_k: The maximum number of answers to return - :return: Dict containing question and answers + :return: Dict containing query and answers """ # get top-answers for each candidate passage @@ -96,8 +96,8 @@ class TransformersReader(BaseReader): no_ans_gaps = [] best_overall_score = 0 for doc in documents: - query = {"context": doc.text, "question": question} - predictions = self.model(query, + transformers_query = {"context": doc.text, "question": query} + predictions = self.model(transformers_query, topk=self.top_k_per_candidate, handle_impossible_answer=self.return_no_answers, max_seq_len=self.max_seq_len, @@ -146,12 +146,11 @@ class TransformersReader(BaseReader): ) answers = answers[:top_k] - results = {"question": question, + results = {"query": query, "answers": answers} return results - def predict_batch(self, question_doc_list: List[dict], top_k_per_question: Optional[int] = None, - batch_size: Optional[int] = None): + def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None): raise NotImplementedError("Batch prediction not yet available in TransformersReader.") diff --git a/haystack/retriever/base.py b/haystack/retriever/base.py index b0d83dfb2..261d0fac4 100644 --- a/haystack/retriever/base.py +++ b/haystack/retriever/base.py @@ -52,7 +52,7 @@ class BaseRetriever(ABC): ) -> dict: """ Performs evaluation on the Retriever. - Retriever is evaluated based on whether it finds the correct document given the question string and at which + Retriever is evaluated based on whether it finds the correct document given the query string and at which position in the ranking of documents the correct document is. | Returns a dict containing the following metrics: @@ -68,7 +68,7 @@ class BaseRetriever(ABC): :param label_index: Index/Table in DocumentStore where labeled questions are stored :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored - :param top_k: How many documents to return per question + :param top_k: How many documents to return per query :param open_domain: If ``True``, retrieval will be evaluated by checking if the answer string to a question is contained in the retrieved docs (common approach in open-domain QA). 
If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids @@ -170,17 +170,17 @@ class BaseRetriever(ABC): def run( self, - question: str, + query: str, filters: Optional[dict] = None, top_k_retriever: Optional[int] = None, top_k_reader: Optional[int] = None, ): if top_k_retriever: - documents = self.retrieve(query=question, filters=filters, top_k=top_k_retriever) + documents = self.retrieve(query=query, filters=filters, top_k=top_k_retriever) else: - documents = self.retrieve(query=question, filters=filters) + documents = self.retrieve(query=query, filters=filters) output = { - "question": question, + "query": query, "documents": documents, "top_k": top_k_reader } diff --git a/haystack/retriever/sparse.py b/haystack/retriever/sparse.py index 186f70f96..fc40278b7 100644 --- a/haystack/retriever/sparse.py +++ b/haystack/retriever/sparse.py @@ -18,7 +18,7 @@ class ElasticsearchRetriever(BaseRetriever): def __init__(self, document_store: ElasticsearchDocumentStore, custom_query: str = None): """ :param document_store: an instance of a DocumentStore to retrieve documents from. - :param custom_query: query string as per Elasticsearch DSL with a mandatory question placeholder($question). + :param custom_query: query string as per Elasticsearch DSL with a mandatory query placeholder(query). Optionally, ES `filter` clause can be added where the values of `terms` are placeholders that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..) @@ -32,7 +32,7 @@ class ElasticsearchRetriever(BaseRetriever): | "query": { | "bool": { | "should": [{"multi_match": { - | "query": "${question}", // mandatory $question placeholder + | "query": "${query}", // mandatory query placeholder | "type": "most_fields", | "fields": ["text", "title"]}}], | "filter": [ // optional custom filters diff --git a/haystack/utils.py b/haystack/utils.py index 6a7e483b9..9f161a231 100644 --- a/haystack/utils.py +++ b/haystack/utils.py @@ -40,11 +40,11 @@ def export_answers_to_csv(agg_results: list, output_file): if isinstance(agg_results, dict): agg_results = [agg_results] - assert "question" in agg_results[0], f"Wrong format used for {agg_results[0]}" + assert "query" in agg_results[0], f"Wrong format used for {agg_results[0]}" assert "answers" in agg_results[0], f"Wrong format used for {agg_results[0]}" data = {} # type: Dict[str, List[Any]] - data["question"] = [] + data["query"] = [] data["prediction"] = [] data["prediction_rank"] = [] data["prediction_context"] = [] @@ -52,7 +52,7 @@ def export_answers_to_csv(agg_results: list, output_file): for res in agg_results: for i in range(len(res["answers"])): temp = res["answers"][i] - data["question"].append(res["question"]) + data["query"].append(res["query"]) data["prediction"].append(temp["answer"]) data["prediction_rank"].append(i + 1) data["prediction_context"].append(temp["context"]) diff --git a/test/conftest.py b/test/conftest.py index af6888eb0..e12639a17 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -231,14 +231,14 @@ def no_answer_reader(request, transformers_roberta, farm_roberta): @pytest.fixture() def prediction(reader, test_docs_xs): docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs] - prediction = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5) + prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5) return prediction @pytest.fixture() def no_answer_prediction(no_answer_reader, test_docs_xs): docs = 
[Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs] - prediction = no_answer_reader.predict(question="What is the meaning of life?", documents=docs, top_k=5) + prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5) return prediction diff --git a/test/test_finder.py b/test/test_finder.py index 2c842f47d..4dbeb18e7 100644 --- a/test/test_finder.py +++ b/test/test_finder.py @@ -10,7 +10,7 @@ def test_finder_get_answers(reader, retriever_with_docs, document_store_with_doc prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3) assert prediction is not None - assert prediction["question"] == "Who lives in Berlin?" + assert prediction["query"] == "Who lives in Berlin?" assert prediction["answers"][0]["answer"] == "Carla" assert prediction["answers"][0]["probability"] <= 1 assert prediction["answers"][0]["probability"] >= 0 diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 9e0aa1c3d..0136a2157 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -25,9 +25,9 @@ def test_graph_creation(reader, retriever_with_docs, document_store_with_docs): @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True) def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_docs): pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs) - prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3) + prediction = pipeline.run(query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3) assert prediction is not None - assert prediction["question"] == "Who lives in Berlin?" + assert prediction["query"] == "Who lives in Berlin?" assert prediction["answers"][0]["answer"] == "Carla" assert prediction["answers"][0]["probability"] <= 1 assert prediction["answers"][0]["probability"] >= 0 @@ -41,7 +41,7 @@ def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_ @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True) def test_extractive_qa_offsets(reader, retriever_with_docs, document_store_with_docs): pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs) - prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5) + prediction = pipeline.run(query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5) assert prediction["answers"][0]["offset_start"] == 11 assert prediction["answers"][0]["offset_end"] == 16 @@ -56,7 +56,7 @@ def test_extractive_qa_offsets(reader, retriever_with_docs, document_store_with_ def test_extractive_qa_answers_single_result(reader, retriever_with_docs, document_store_with_docs): pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs) query = "testing finder" - prediction = pipeline.run(question=query, top_k_retriever=1, top_k_reader=1) + prediction = pipeline.run(query=query, top_k_retriever=1, top_k_reader=1) assert prediction is not None assert len(prediction["answers"]) == 1 diff --git a/test/test_reader.py b/test/test_reader.py index 819875773..14e68dd49 100644 --- a/test/test_reader.py +++ b/test/test_reader.py @@ -14,7 +14,7 @@ def test_reader_basic(reader): def test_output(prediction): assert prediction is not None - assert prediction["question"] == "Who lives in Berlin?" + assert prediction["query"] == "Who lives in Berlin?" 
assert prediction["answers"][0]["answer"] == "Carla" assert prediction["answers"][0]["offset_start"] == 11 assert prediction["answers"][0]["offset_end"] == 16 @@ -27,7 +27,7 @@ def test_output(prediction): @pytest.mark.slow def test_no_answer_output(no_answer_prediction): assert no_answer_prediction is not None - assert no_answer_prediction["question"] == "What is the meaning of life?" + assert no_answer_prediction["query"] == "What is the meaning of life?" assert math.isclose(no_answer_prediction["no_ans_gap"], -13.048564434051514, rel_tol=0.0001) assert no_answer_prediction["answers"][0]["answer"] is None assert no_answer_prediction["answers"][0]["offset_start"] == 0 @@ -48,7 +48,7 @@ def test_no_answer_output(no_answer_prediction): @pytest.mark.slow def test_prediction_attributes(prediction): # TODO FARM's prediction also has no_ans_gap - attributes_gold = ["question", "answers"] + attributes_gold = ["query", "answers"] for ag in attributes_gold: assert ag in prediction @@ -73,7 +73,7 @@ def test_context_window_size(reader, test_docs_xs, window_size): old_window_size = reader.inferencer.model.prediction_heads[0].context_window_size reader.inferencer.model.prediction_heads[0].context_window_size = window_size - prediction = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5) + prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5) for answer in prediction["answers"]: # If the extracted answer is larger than the context window, the context window is expanded. # If the extracted answer is odd in length, the resulting context window is one less than context_window_size @@ -106,7 +106,7 @@ def test_top_k(reader, test_docs_xs, top_k): except: print("WARNING: Could not set `top_k_per_sample` in FARM. Please update FARM version.") - prediction = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=top_k) + prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=top_k) assert len(prediction["answers"]) == top_k reader.top_k_per_candidate = old_top_k_per_candidate