mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-03 11:19:57 +00:00
Enable dynamic parameter updates for the FARMReader (#650)
This commit is contained in:
parent
e6ada08d0e
commit
4152ad8426
@ -62,6 +62,10 @@ class Pipeline:
|
||||
input_edge_name = "output_1"
|
||||
self.graph.add_edge(input_node_name, name, label=input_edge_name)
|
||||
|
||||
def get_node(self, name: str):
    """Return the component that was registered under *name* in the pipeline graph."""
    node_attrs = self.graph.nodes[name]
    return node_attrs["component"]
|
||||
|
||||
def run(self, **kwargs):
|
||||
has_next_node = True
|
||||
current_node_id = self.root_node_id
|
||||
@ -120,11 +124,17 @@ class Pipeline:
|
||||
|
||||
|
||||
class BaseStandardPipeline:
|
||||
pipeline: Pipeline
|
||||
|
||||
def add_node(self, component, name: str, inputs: List[str]):
|
||||
self.pipeline.add_node(component=component, name=name, inputs=inputs) # type: ignore
|
||||
self.pipeline.add_node(component=component, name=name, inputs=inputs)
|
||||
|
||||
def get_node(self, name: str):
|
||||
component = self.pipeline.get_node(name)
|
||||
return component
|
||||
|
||||
def draw(self, path: Path = Path("pipeline.png")):
|
||||
self.pipeline.draw(path) # type: ignore
|
||||
self.pipeline.draw(path)
|
||||
|
||||
|
||||
class ExtractiveQAPipeline(BaseStandardPipeline):
|
||||
|
||||
@ -44,7 +44,8 @@ class FARMReader(BaseReader):
|
||||
context_window_size: int = 150,
|
||||
batch_size: int = 50,
|
||||
use_gpu: bool = True,
|
||||
no_ans_boost: Optional[float] = None,
|
||||
no_ans_boost: float = 0.0,
|
||||
return_no_answer: bool = False,
|
||||
top_k_per_candidate: int = 3,
|
||||
top_k_per_sample: int = 1,
|
||||
num_processes: Optional[int] = None,
|
||||
@ -63,9 +64,10 @@ class FARMReader(BaseReader):
|
||||
to a value so only a single batch is used.
|
||||
:param use_gpu: Whether to use GPU (if available)
|
||||
:param no_ans_boost: How much the no_answer logit is boosted/increased.
|
||||
If set to None (default), disables returning "no answer" predictions.
|
||||
If set to 0 (default), the no_answer logit is not changed.
|
||||
If a negative number, there is a lower chance of "no_answer" being predicted.
|
||||
If a positive number, there is an increased chance of "no_answer"
|
||||
:param return_no_answer: Whether to include no_answer predictions in the results.
|
||||
:param top_k_per_candidate: How many answers to extract for each candidate doc that is coming from the retriever (might be a long text).
|
||||
Note that this is not the number of "final answers" you will receive
|
||||
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
|
||||
@ -85,11 +87,7 @@ class FARMReader(BaseReader):
|
||||
|
||||
"""
|
||||
|
||||
if no_ans_boost is None:
|
||||
no_ans_boost = 0
|
||||
self.return_no_answers = False
|
||||
else:
|
||||
self.return_no_answers = True
|
||||
self.return_no_answers = return_no_answer
|
||||
self.top_k_per_candidate = top_k_per_candidate
|
||||
self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
|
||||
task_type="question_answering", max_seq_len=max_seq_len,
|
||||
@ -233,6 +231,29 @@ class FARMReader(BaseReader):
|
||||
self.inferencer.model = trainer.train()
|
||||
self.save(Path(save_dir))
|
||||
|
||||
def update_parameters(
    self,
    context_window_size: Optional[int] = None,
    no_ans_boost: Optional[float] = None,
    return_no_answer: Optional[bool] = None,
    max_seq_len: Optional[int] = None,
    doc_stride: Optional[int] = None,
):
    """
    Hot-update parameters of an already loaded Reader.

    Only arguments that are not ``None`` are applied; everything else keeps
    its current value. NOTE(review): this mutates the shared inferencer /
    processor in place, so it may not be safe while concurrent requests are
    being processed.
    """
    prediction_head = self.inferencer.model.prediction_heads[0]
    processor = self.inferencer.processor

    if no_ans_boost is not None:
        prediction_head.no_ans_boost = no_ans_boost
    if return_no_answer is not None:
        self.return_no_answers = return_no_answer
    if doc_stride is not None:
        processor.doc_stride = doc_stride
    if context_window_size is not None:
        prediction_head.context_window_size = context_window_size
    if max_seq_len is not None:
        # keep the reader-level copy in sync with the processor
        processor.max_seq_len = max_seq_len
        self.max_seq_len = max_seq_len
|
||||
|
||||
def save(self, directory: Path):
|
||||
"""
|
||||
Saves the Reader model so that it can be reused at a later point in time.
|
||||
|
||||
@ -199,6 +199,7 @@ def no_answer_reader(request):
|
||||
use_gpu=False,
|
||||
top_k_per_sample=5,
|
||||
no_ans_boost=0,
|
||||
return_no_answer=True,
|
||||
num_processes=0
|
||||
)
|
||||
if request.param == "transformers":
|
||||
|
||||
@ -115,3 +115,51 @@ def test_top_k(reader, test_docs_xs, top_k):
|
||||
reader.inferencer.model.prediction_heads[0].n_best_per_sample = old_top_k_per_sample
|
||||
except:
|
||||
print("WARNING: Could not set `top_k_per_sample` in FARM. Please update FARM version.")
|
||||
|
||||
|
||||
def test_farm_reader_update_params(test_docs_xs):
    reader = FARMReader(
        model_name_or_path="deepset/roberta-base-squad2",
        use_gpu=False,
        no_ans_boost=0,
        num_processes=0,
    )

    documents = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]

    def ask():
        return reader.predict(query="Who lives in Berlin?", documents=documents, top_k=3)

    # baseline: reader behaves as originally loaded
    result = ask()
    assert len(result["answers"]) == 3
    assert result["answers"][0]["answer"] == "Carla"

    # boost no_answer so it dominates the ranking
    reader.update_parameters(
        context_window_size=100, no_ans_boost=100, return_no_answer=True, max_seq_len=384, doc_stride=128
    )
    result = ask()
    assert len(result["answers"]) == 3
    assert result["answers"][0]["answer"] is None

    # switch no_answer predictions back off
    reader.update_parameters(
        context_window_size=100, no_ans_boost=0, return_no_answer=False, max_seq_len=384, doc_stride=128
    )
    result = ask()
    assert len(result["answers"]) == 3
    assert None not in [ans["answer"] for ans in result["answers"]]

    # shrink the context window around answers
    reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=384, doc_stride=128)
    result = ask()
    assert len(result["answers"]) == 3
    assert len(result["answers"][0]["context"]) == 6

    # doc_stride exceeding max_seq_len must fail at prediction time
    with pytest.raises(Exception):
        reader.update_parameters(context_window_size=100, no_ans_boost=-10, max_seq_len=384, doc_stride=999)
        ask()

    # max_seq_len smaller than doc_stride must fail at prediction time
    with pytest.raises(Exception):
        reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=99, doc_stride=128)
        ask()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user