From 4152ad8426cf5132cabf0792548c7c0c4affbae9 Mon Sep 17 00:00:00 2001 From: Tanay Soni Date: Mon, 7 Dec 2020 14:07:20 +0100 Subject: [PATCH] Enable dynamic parameter updates for the FARMReader (#650) --- haystack/pipeline.py | 14 ++++++++++-- haystack/reader/farm.py | 35 ++++++++++++++++++++++++------ test/conftest.py | 1 + test/test_reader.py | 48 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+), 9 deletions(-) diff --git a/haystack/pipeline.py b/haystack/pipeline.py index ad213d3ff..5a06e7f33 100644 --- a/haystack/pipeline.py +++ b/haystack/pipeline.py @@ -62,6 +62,10 @@ class Pipeline: input_edge_name = "output_1" self.graph.add_edge(input_node_name, name, label=input_edge_name) + def get_node(self, name: str): + component = self.graph.nodes[name]["component"] + return component + def run(self, **kwargs): has_next_node = True current_node_id = self.root_node_id @@ -120,11 +124,17 @@ class Pipeline: class BaseStandardPipeline: + pipeline: Pipeline + def add_node(self, component, name: str, inputs: List[str]): - self.pipeline.add_node(component=component, name=name, inputs=inputs) # type: ignore + self.pipeline.add_node(component=component, name=name, inputs=inputs) + + def get_node(self, name: str): + component = self.pipeline.get_node(name) + return component def draw(self, path: Path = Path("pipeline.png")): - self.pipeline.draw(path) # type: ignore + self.pipeline.draw(path) class ExtractiveQAPipeline(BaseStandardPipeline): diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index 2956fd9f3..61400ffb8 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -44,7 +44,8 @@ class FARMReader(BaseReader): context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, - no_ans_boost: Optional[float] = None, + no_ans_boost: float = 0.0, + return_no_answer: bool = False, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, @@ -63,9 +64,10 @@ class 
FARMReader(BaseReader): to a value so only a single batch is used. :param use_gpu: Whether to use GPU (if available) :param no_ans_boost: How much the no_answer logit is boosted/increased. - If set to None (default), disables returning "no answer" predictions. + If set to 0 (default), the no_answer logit is not changed. If a negative number, there is a lower chance of "no_answer" being predicted. If a positive number, there is an increased chance of "no_answer" + :param return_no_answer: Whether to include no_answer predictions in the results. :param top_k_per_candidate: How many answers to extract for each candidate doc that is coming from the retriever (might be a long text). Note that this is not the number of "final answers" you will receive (see `top_k` in FARMReader.predict() or Finder.get_answers() for that) @@ -85,11 +87,7 @@ class FARMReader(BaseReader): """ - if no_ans_boost is None: - no_ans_boost = 0 - self.return_no_answers = False - else: - self.return_no_answers = True + self.return_no_answers = return_no_answer self.top_k_per_candidate = top_k_per_candidate self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering", max_seq_len=max_seq_len, @@ -233,6 +231,29 @@ class FARMReader(BaseReader): self.inferencer.model = trainer.train() self.save(Path(save_dir)) + def update_parameters( + self, + context_window_size: Optional[int] = None, + no_ans_boost: Optional[float] = None, + return_no_answer: Optional[bool] = None, + max_seq_len: Optional[int] = None, + doc_stride: Optional[int] = None, + ): + """ + Hot update parameters of a loaded Reader. It may not be safe when processing concurrent requests. 
+ """ + if no_ans_boost is not None: + self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost + if return_no_answer is not None: + self.return_no_answers = return_no_answer + if doc_stride is not None: + self.inferencer.processor.doc_stride = doc_stride + if context_window_size is not None: + self.inferencer.model.prediction_heads[0].context_window_size = context_window_size + if max_seq_len is not None: + self.inferencer.processor.max_seq_len = max_seq_len + self.max_seq_len = max_seq_len + def save(self, directory: Path): """ Saves the Reader model so that it can be reused at a later point in time. diff --git a/test/conftest.py b/test/conftest.py index 17bdce1f1..67ba2bdf3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -199,6 +199,7 @@ def no_answer_reader(request): use_gpu=False, top_k_per_sample=5, no_ans_boost=0, + return_no_answer=True, num_processes=0 ) if request.param == "transformers": diff --git a/test/test_reader.py b/test/test_reader.py index 14e68dd49..e39e10099 100644 --- a/test/test_reader.py +++ b/test/test_reader.py @@ -115,3 +115,51 @@ def test_top_k(reader, test_docs_xs, top_k): reader.inferencer.model.prediction_heads[0].n_best_per_sample = old_top_k_per_sample except: print("WARNING: Could not set `top_k_per_sample` in FARM. 
Please update FARM version.") + + +def test_farm_reader_update_params(test_docs_xs): + reader = FARMReader( + model_name_or_path="deepset/roberta-base-squad2", + use_gpu=False, + no_ans_boost=0, + num_processes=0 + ) + + docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs] + + # original reader + prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3) + assert len(prediction["answers"]) == 3 + assert prediction["answers"][0]["answer"] == "Carla" + + # update no_ans_boost + reader.update_parameters( + context_window_size=100, no_ans_boost=100, return_no_answer=True, max_seq_len=384, doc_stride=128 + ) + prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3) + assert len(prediction["answers"]) == 3 + assert prediction["answers"][0]["answer"] is None + + # update no_ans_boost + reader.update_parameters( + context_window_size=100, no_ans_boost=0, return_no_answer=False, max_seq_len=384, doc_stride=128 + ) + prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3) + assert len(prediction["answers"]) == 3 + assert None not in [ans["answer"] for ans in prediction["answers"]] + + # update context_window_size + reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=384, doc_stride=128) + prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3) + assert len(prediction["answers"]) == 3 + assert len(prediction["answers"][0]["context"]) == 6 + + # update doc_stride with invalid value + with pytest.raises(Exception): + reader.update_parameters(context_window_size=100, no_ans_boost=-10, max_seq_len=384, doc_stride=999) + reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3) + + # update max_seq_len with invalid value + with pytest.raises(Exception): + reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=99, doc_stride=128) + reader.predict(query="Who lives in Berlin?", 
documents=docs, top_k=3)