mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-10 14:54:10 +00:00
Enable dynamic parameter updates for the FARMReader (#650)
This commit is contained in:
parent
e6ada08d0e
commit
4152ad8426
@ -62,6 +62,10 @@ class Pipeline:
|
|||||||
input_edge_name = "output_1"
|
input_edge_name = "output_1"
|
||||||
self.graph.add_edge(input_node_name, name, label=input_edge_name)
|
self.graph.add_edge(input_node_name, name, label=input_edge_name)
|
||||||
|
|
||||||
|
def get_node(self, name: str):
|
||||||
|
component = self.graph.nodes[name]["component"]
|
||||||
|
return component
|
||||||
|
|
||||||
def run(self, **kwargs):
|
def run(self, **kwargs):
|
||||||
has_next_node = True
|
has_next_node = True
|
||||||
current_node_id = self.root_node_id
|
current_node_id = self.root_node_id
|
||||||
@ -120,11 +124,17 @@ class Pipeline:
|
|||||||
|
|
||||||
|
|
||||||
class BaseStandardPipeline:
|
class BaseStandardPipeline:
|
||||||
|
pipeline: Pipeline
|
||||||
|
|
||||||
def add_node(self, component, name: str, inputs: List[str]):
|
def add_node(self, component, name: str, inputs: List[str]):
|
||||||
self.pipeline.add_node(component=component, name=name, inputs=inputs) # type: ignore
|
self.pipeline.add_node(component=component, name=name, inputs=inputs)
|
||||||
|
|
||||||
|
def get_node(self, name: str):
|
||||||
|
component = self.pipeline.get_node(name)
|
||||||
|
return component
|
||||||
|
|
||||||
def draw(self, path: Path = Path("pipeline.png")):
|
def draw(self, path: Path = Path("pipeline.png")):
|
||||||
self.pipeline.draw(path) # type: ignore
|
self.pipeline.draw(path)
|
||||||
|
|
||||||
|
|
||||||
class ExtractiveQAPipeline(BaseStandardPipeline):
|
class ExtractiveQAPipeline(BaseStandardPipeline):
|
||||||
|
|||||||
@ -44,7 +44,8 @@ class FARMReader(BaseReader):
|
|||||||
context_window_size: int = 150,
|
context_window_size: int = 150,
|
||||||
batch_size: int = 50,
|
batch_size: int = 50,
|
||||||
use_gpu: bool = True,
|
use_gpu: bool = True,
|
||||||
no_ans_boost: Optional[float] = None,
|
no_ans_boost: float = 0.0,
|
||||||
|
return_no_answer: bool = False,
|
||||||
top_k_per_candidate: int = 3,
|
top_k_per_candidate: int = 3,
|
||||||
top_k_per_sample: int = 1,
|
top_k_per_sample: int = 1,
|
||||||
num_processes: Optional[int] = None,
|
num_processes: Optional[int] = None,
|
||||||
@ -63,9 +64,10 @@ class FARMReader(BaseReader):
|
|||||||
to a value so only a single batch is used.
|
to a value so only a single batch is used.
|
||||||
:param use_gpu: Whether to use GPU (if available)
|
:param use_gpu: Whether to use GPU (if available)
|
||||||
:param no_ans_boost: How much the no_answer logit is boosted/increased.
|
:param no_ans_boost: How much the no_answer logit is boosted/increased.
|
||||||
If set to None (default), disables returning "no answer" predictions.
|
If set to 0 (default), the no_answer logit is not changed.
|
||||||
If a negative number, there is a lower chance of "no_answer" being predicted.
|
If a negative number, there is a lower chance of "no_answer" being predicted.
|
||||||
If a positive number, there is an increased chance of "no_answer"
|
If a positive number, there is an increased chance of "no_answer"
|
||||||
|
:param return_no_answer: Whether to include no_answer predictions in the results.
|
||||||
:param top_k_per_candidate: How many answers to extract for each candidate doc that is coming from the retriever (might be a long text).
|
:param top_k_per_candidate: How many answers to extract for each candidate doc that is coming from the retriever (might be a long text).
|
||||||
Note that this is not the number of "final answers" you will receive
|
Note that this is not the number of "final answers" you will receive
|
||||||
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
|
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
|
||||||
@ -85,11 +87,7 @@ class FARMReader(BaseReader):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if no_ans_boost is None:
|
self.return_no_answers = return_no_answer
|
||||||
no_ans_boost = 0
|
|
||||||
self.return_no_answers = False
|
|
||||||
else:
|
|
||||||
self.return_no_answers = True
|
|
||||||
self.top_k_per_candidate = top_k_per_candidate
|
self.top_k_per_candidate = top_k_per_candidate
|
||||||
self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
|
self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
|
||||||
task_type="question_answering", max_seq_len=max_seq_len,
|
task_type="question_answering", max_seq_len=max_seq_len,
|
||||||
@ -233,6 +231,29 @@ class FARMReader(BaseReader):
|
|||||||
self.inferencer.model = trainer.train()
|
self.inferencer.model = trainer.train()
|
||||||
self.save(Path(save_dir))
|
self.save(Path(save_dir))
|
||||||
|
|
||||||
|
def update_parameters(
|
||||||
|
self,
|
||||||
|
context_window_size: Optional[int] = None,
|
||||||
|
no_ans_boost: Optional[float] = None,
|
||||||
|
return_no_answer: Optional[bool] = None,
|
||||||
|
max_seq_len: Optional[int] = None,
|
||||||
|
doc_stride: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Hot update parameters of a loaded Reader. It may not be safe when processing concurrent requests.
|
||||||
|
"""
|
||||||
|
if no_ans_boost is not None:
|
||||||
|
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
|
||||||
|
if return_no_answer is not None:
|
||||||
|
self.return_no_answers = return_no_answer
|
||||||
|
if doc_stride is not None:
|
||||||
|
self.inferencer.processor.doc_stride = doc_stride
|
||||||
|
if context_window_size is not None:
|
||||||
|
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
|
||||||
|
if max_seq_len is not None:
|
||||||
|
self.inferencer.processor.max_seq_len = max_seq_len
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
|
||||||
def save(self, directory: Path):
|
def save(self, directory: Path):
|
||||||
"""
|
"""
|
||||||
Saves the Reader model so that it can be reused at a later point in time.
|
Saves the Reader model so that it can be reused at a later point in time.
|
||||||
|
|||||||
@ -199,6 +199,7 @@ def no_answer_reader(request):
|
|||||||
use_gpu=False,
|
use_gpu=False,
|
||||||
top_k_per_sample=5,
|
top_k_per_sample=5,
|
||||||
no_ans_boost=0,
|
no_ans_boost=0,
|
||||||
|
return_no_answer=True,
|
||||||
num_processes=0
|
num_processes=0
|
||||||
)
|
)
|
||||||
if request.param == "transformers":
|
if request.param == "transformers":
|
||||||
|
|||||||
@ -115,3 +115,51 @@ def test_top_k(reader, test_docs_xs, top_k):
|
|||||||
reader.inferencer.model.prediction_heads[0].n_best_per_sample = old_top_k_per_sample
|
reader.inferencer.model.prediction_heads[0].n_best_per_sample = old_top_k_per_sample
|
||||||
except:
|
except:
|
||||||
print("WARNING: Could not set `top_k_per_sample` in FARM. Please update FARM version.")
|
print("WARNING: Could not set `top_k_per_sample` in FARM. Please update FARM version.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_farm_reader_update_params(test_docs_xs):
|
||||||
|
reader = FARMReader(
|
||||||
|
model_name_or_path="deepset/roberta-base-squad2",
|
||||||
|
use_gpu=False,
|
||||||
|
no_ans_boost=0,
|
||||||
|
num_processes=0
|
||||||
|
)
|
||||||
|
|
||||||
|
docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
|
||||||
|
|
||||||
|
# original reader
|
||||||
|
prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
|
||||||
|
assert len(prediction["answers"]) == 3
|
||||||
|
assert prediction["answers"][0]["answer"] == "Carla"
|
||||||
|
|
||||||
|
# update no_ans_boost
|
||||||
|
reader.update_parameters(
|
||||||
|
context_window_size=100, no_ans_boost=100, return_no_answer=True, max_seq_len=384, doc_stride=128
|
||||||
|
)
|
||||||
|
prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
|
||||||
|
assert len(prediction["answers"]) == 3
|
||||||
|
assert prediction["answers"][0]["answer"] is None
|
||||||
|
|
||||||
|
# update no_ans_boost
|
||||||
|
reader.update_parameters(
|
||||||
|
context_window_size=100, no_ans_boost=0, return_no_answer=False, max_seq_len=384, doc_stride=128
|
||||||
|
)
|
||||||
|
prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
|
||||||
|
assert len(prediction["answers"]) == 3
|
||||||
|
assert None not in [ans["answer"] for ans in prediction["answers"]]
|
||||||
|
|
||||||
|
# update context_window_size
|
||||||
|
reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=384, doc_stride=128)
|
||||||
|
prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
|
||||||
|
assert len(prediction["answers"]) == 3
|
||||||
|
assert len(prediction["answers"][0]["context"]) == 6
|
||||||
|
|
||||||
|
# update doc_stride with invalid value
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
reader.update_parameters(context_window_size=100, no_ans_boost=-10, max_seq_len=384, doc_stride=999)
|
||||||
|
reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
|
||||||
|
|
||||||
|
# update max_seq_len with invalid value
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=99, doc_stride=128)
|
||||||
|
reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user