mirror of https://github.com/deepset-ai/haystack.git
synced 2025-07-17 05:51:22 +00:00
Redo changing of init
This commit is contained in:
parent f681026a56
commit d15448f60a
@@ -23,6 +23,49 @@ class FARMReader:
     - fine-tune the model on QA data via train()
     """
 
+    def __init__(
+        self,
+        model_name_or_path,
+        context_window_size=30,
+        batch_size=50,
+        use_gpu=True,
+        no_ans_boost=None,
+        n_candidates_per_paragraph=1):
+        """
+        :param model_name_or_path: directory of a saved model or the name of a public model:
+                                   - 'bert-base-cased'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'distilbert-base-uncased-distilled-squad'
+                                   ....
+                                   See https://huggingface.co/models for full list of available models.
+        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
+        :param batch_size: Number of samples the model receives in one batch for inference
+                           Memory consumption is much lower in inference mode. Recommendation: increase the batch size to a value so only a single batch is used.
+        :param use_gpu: Whether to use GPU (if available)
+        :param no_ans_boost: How much the no_answer logit is boosted/increased.
+                             Possible values: None (default) = disable returning "no answer" predictions
+                                              Negative = lower chance of "no answer" being predicted
+                                              Positive = increase chance of "no answer"
+        :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
+                                           Note: - This is not the number of "final answers" you will receive
+                                                   (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
+                                                 - FARM includes no_answer in the sorted list of predictions
+
+
+        """
+
+        if no_ans_boost is None:
+            no_ans_boost = 0
+            self.return_no_answers = False
+        else:
+            self.return_no_answers = True
+        self.n_candidates_per_paragraph = n_candidates_per_paragraph
+        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
+        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
+        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
+        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1  # including possible no_answer
+
     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
               use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
               max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
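For orientation, here is a minimal usage sketch of the constructor added above. It assumes only the FARMReader class from this diff; the variable name `reader` and the chosen model are illustrative, and the values shown simply restate the documented defaults:

    # Hypothetical usage sketch, not part of the commit.
    reader = FARMReader(
        model_name_or_path="deepset/bert-base-cased-squad2",  # public model named in the docstring
        context_window_size=30,         # characters of context shown around an answer span
        batch_size=50,                  # samples per inference batch
        use_gpu=True,
        no_ans_boost=None,              # None disables "no answer" predictions entirely
        n_candidates_per_paragraph=1,   # candidates per passage, not final answers (see top_k)
    )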
@@ -104,49 +147,6 @@ class FARMReader:
         self.inferencer.model = trainer.train()
         self.save(save_dir)
 
-    def __init__(
-        self,
-        model_name_or_path,
-        context_window_size=30,
-        batch_size=50,
-        use_gpu=True,
-        no_ans_boost=None,
-        n_candidates_per_paragraph=1):
-        """
-        :param model_name_or_path: directory of a saved model or the name of a public model:
-                                   - 'bert-base-cased'
-                                   - 'deepset/bert-base-cased-squad2'
-                                   - 'deepset/bert-base-cased-squad2'
-                                   - 'distilbert-base-uncased-distilled-squad'
-                                   ....
-                                   See https://huggingface.co/models for full list of available models.
-        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
-        :param batch_size: Number of samples the model receives in one batch for inference
-                           Memory consumption is much lower in inference mode. Recommendation: increase the batch size to a value so only a single batch is used.
-        :param use_gpu: Whether to use GPU (if available)
-        :param no_ans_boost: How much the no_answer logit is boosted/increased.
-                             Possible values: None (default) = disable returning "no answer" predictions
-                                              Negative = lower chance of "no answer" being predicted
-                                              Positive = increase chance of "no answer"
-        :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
-                                           Note: - This is not the number of "final answers" you will receive
-                                                   (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
-                                                 - FARM includes no_answer in the sorted list of predictions
-
-
-        """
-
-        if no_ans_boost is None:
-            no_ans_boost = 0
-            self.return_no_answers = False
-        else:
-            self.return_no_answers = True
-        self.n_candidates_per_paragraph = n_candidates_per_paragraph
-        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
-        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
-        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
-        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1  # including possible no_answer
-
     def save(self, directory):
         logger.info(f"Saving reader model to {directory}")
         self.inferencer.model.save(directory)
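As a follow-up, a save-and-reload sketch under the same assumptions: save() is the method shown in the hunk above, the docstring states that model_name_or_path also accepts the directory of a saved model, and the path used here is hypothetical:

    # Hypothetical round trip, not part of the commit: persist the reader, then
    # reload it by pointing model_name_or_path at the saved directory.
    reader.save("saved_models/qa_reader")
    reader = FARMReader(model_name_or_path="saved_models/qa_reader")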