Enable dynamic parameter updates for the FARMReader (#650)

This commit is contained in:
Tanay Soni 2020-12-07 14:07:20 +01:00 committed by GitHub
parent e6ada08d0e
commit 4152ad8426
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 9 deletions

View File

@ -62,6 +62,10 @@ class Pipeline:
input_edge_name = "output_1"
self.graph.add_edge(input_node_name, name, label=input_edge_name)
def get_node(self, name: str):
    """Return the component instance registered in the pipeline graph under *name*."""
    return self.graph.nodes[name]["component"]
def run(self, **kwargs):
has_next_node = True
current_node_id = self.root_node_id
@ -120,11 +124,17 @@ class Pipeline:
class BaseStandardPipeline:
pipeline: Pipeline
def add_node(self, component, name: str, inputs: List[str]):
self.pipeline.add_node(component=component, name=name, inputs=inputs) # type: ignore
self.pipeline.add_node(component=component, name=name, inputs=inputs)
def get_node(self, name: str):
    """Delegate to the wrapped Pipeline and return the component named *name*."""
    return self.pipeline.get_node(name)
def draw(self, path: Path = Path("pipeline.png")):
self.pipeline.draw(path) # type: ignore
self.pipeline.draw(path)
class ExtractiveQAPipeline(BaseStandardPipeline):

View File

@ -44,7 +44,8 @@ class FARMReader(BaseReader):
context_window_size: int = 150,
batch_size: int = 50,
use_gpu: bool = True,
no_ans_boost: Optional[float] = None,
no_ans_boost: float = 0.0,
return_no_answer: bool = False,
top_k_per_candidate: int = 3,
top_k_per_sample: int = 1,
num_processes: Optional[int] = None,
@ -63,9 +64,10 @@ class FARMReader(BaseReader):
to a value so only a single batch is used.
:param use_gpu: Whether to use GPU (if available)
:param no_ans_boost: How much the no_answer logit is boosted/increased.
If set to None (default), disables returning "no answer" predictions.
If set to 0 (default), the no_answer logit is not changed.
If a negative number, there is a lower chance of "no_answer" being predicted.
If a positive number, there is an increased chance of "no_answer"
:param return_no_answer: Whether to include no_answer predictions in the results.
:param top_k_per_candidate: How many answers to extract for each candidate doc that is coming from the retriever (might be a long text).
Note that this is not the number of "final answers" you will receive
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
@ -85,11 +87,7 @@ class FARMReader(BaseReader):
"""
if no_ans_boost is None:
no_ans_boost = 0
self.return_no_answers = False
else:
self.return_no_answers = True
self.return_no_answers = return_no_answer
self.top_k_per_candidate = top_k_per_candidate
self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
task_type="question_answering", max_seq_len=max_seq_len,
@ -233,6 +231,29 @@ class FARMReader(BaseReader):
self.inferencer.model = trainer.train()
self.save(Path(save_dir))
def update_parameters(
    self,
    context_window_size: Optional[int] = None,
    no_ans_boost: Optional[float] = None,
    return_no_answer: Optional[bool] = None,
    max_seq_len: Optional[int] = None,
    doc_stride: Optional[int] = None,
):
    """
    Hot-update parameters of an already loaded Reader; any argument left as None is not touched.

    Note: this may not be safe while concurrent requests are being processed.
    """
    # Settings that live on the QA prediction head of the underlying FARM model.
    if context_window_size is not None:
        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
    if no_ans_boost is not None:
        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
    if return_no_answer is not None:
        self.return_no_answers = return_no_answer
    # Settings that live on the inferencer's processor; max_seq_len is mirrored on the reader itself.
    if max_seq_len is not None:
        self.max_seq_len = max_seq_len
        self.inferencer.processor.max_seq_len = max_seq_len
    if doc_stride is not None:
        self.inferencer.processor.doc_stride = doc_stride
def save(self, directory: Path):
"""
Saves the Reader model so that it can be reused at a later point in time.

View File

@ -199,6 +199,7 @@ def no_answer_reader(request):
use_gpu=False,
top_k_per_sample=5,
no_ans_boost=0,
return_no_answer=True,
num_processes=0
)
if request.param == "transformers":

View File

@ -115,3 +115,51 @@ def test_top_k(reader, test_docs_xs, top_k):
reader.inferencer.model.prediction_heads[0].n_best_per_sample = old_top_k_per_sample
except:
print("WARNING: Could not set `top_k_per_sample` in FARM. Please update FARM version.")
def test_farm_reader_update_params(test_docs_xs):
    """End-to-end check that FARMReader.update_parameters hot-updates inference settings in place."""
    reader = FARMReader(
        model_name_or_path="deepset/roberta-base-squad2",
        use_gpu=False,
        no_ans_boost=0,
        num_processes=0
    )
    docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
    # original reader: baseline prediction answers the question
    prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
    assert len(prediction["answers"]) == 3
    assert prediction["answers"][0]["answer"] == "Carla"
    # update no_ans_boost: a large positive boost should push "no answer" to the top
    reader.update_parameters(
        context_window_size=100, no_ans_boost=100, return_no_answer=True, max_seq_len=384, doc_stride=128
    )
    prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
    assert len(prediction["answers"]) == 3
    assert prediction["answers"][0]["answer"] is None
    # update no_ans_boost: with boost reset and return_no_answer off, no None answers appear
    reader.update_parameters(
        context_window_size=100, no_ans_boost=0, return_no_answer=False, max_seq_len=384, doc_stride=128
    )
    prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
    assert len(prediction["answers"]) == 3
    assert None not in [ans["answer"] for ans in prediction["answers"]]
    # update context_window_size: returned answer context is truncated to the new window
    reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=384, doc_stride=128)
    prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
    assert len(prediction["answers"]) == 3
    assert len(prediction["answers"][0]["context"]) == 6
    # update doc_stride with invalid value (stride larger than usable sequence) must raise on predict
    with pytest.raises(Exception):
        reader.update_parameters(context_window_size=100, no_ans_boost=-10, max_seq_len=384, doc_stride=999)
        reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
    # update max_seq_len with invalid value (shorter than doc_stride) must raise on predict
    with pytest.raises(Exception):
        reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=99, doc_stride=128)
        reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)