mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-25 16:15:35 +00:00
FARMReader.train() uses values from class
This commit is contained in:
parent
9897c40c41
commit
d41dcca813
@ -87,10 +87,12 @@ class FARMReader:
|
||||
except:
|
||||
logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.")
|
||||
self.max_processes = max_processes
|
||||
self.max_seq_len = max_seq_len
|
||||
self.use_gpu = use_gpu
|
||||
|
||||
def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
|
||||
use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
|
||||
max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
|
||||
use_gpu=None, batch_size=10, n_epochs=2, learning_rate=1e-5,
|
||||
max_seq_len=None, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
|
||||
"""
|
||||
Fine-tune a model on a QA dataset. Options:
|
||||
- Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
|
||||
@ -120,6 +122,14 @@ class FARMReader:
|
||||
dev_split = None
|
||||
|
||||
set_all_seeds(seed=42)
|
||||
|
||||
# For these variables, by default, we use the value set when initializing the FARMReader.
|
||||
# This can also be manually set when train() is called if you want a different value at train vs inference
|
||||
if use_gpu is None:
|
||||
use_gpu = self.use_gpu
|
||||
if max_seq_len is None:
|
||||
max_seq_len = self.max_seq_len
|
||||
|
||||
device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
|
||||
|
||||
if not save_dir:
|
||||
|
Loading…
x
Reference in New Issue
Block a user