mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 10:19:23 +00:00
Fix multi-gpu training via DataParallel (#234)
This commit is contained in:
parent
5c1a5fe61d
commit
c9d3146fae
@ -12,5 +12,6 @@ logging.getLogger('farm.utils').setLevel(logging.INFO)
|
||||
logging.getLogger('farm.infer').setLevel(logging.INFO)
|
||||
logging.getLogger('transformers').setLevel(logging.WARNING)
|
||||
logging.getLogger('farm.eval').setLevel(logging.INFO)
|
||||
logging.getLogger('farm.modeling.optimization').setLevel(logging.INFO)
|
||||
|
||||
|
||||
|
||||
@ -10,10 +10,12 @@ from farm.data_handler.inputs import QAInput, Question
|
||||
from farm.infer import QAInferencer
|
||||
from farm.modeling.optimization import initialize_optimizer
|
||||
from farm.modeling.predictions import QAPred, QACandidate
|
||||
from farm.modeling.adaptive_model import BaseAdaptiveModel
|
||||
from farm.train import Trainer
|
||||
from farm.eval import Evaluator
|
||||
from farm.utils import set_all_seeds, initialize_device_settings
|
||||
from scipy.special import expit
|
||||
import shutil
|
||||
|
||||
from haystack.database.base import Document
|
||||
from haystack.database.elasticsearch import ElasticsearchDocumentStore
|
||||
@ -177,9 +179,17 @@ class FARMReader(BaseReader):
|
||||
# and calculates a few descriptive statistics of our datasets
|
||||
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)
|
||||
|
||||
# Quick-fix until this is fixed upstream in FARM:
|
||||
# We must avoid applying DataParallel twice (once when loading the inferencer,
|
||||
# once when calling initalize_optimizer)
|
||||
self.inferencer.model.save("tmp_model")
|
||||
model = BaseAdaptiveModel.load(load_dir="tmp_model", device=device, strict=True)
|
||||
shutil.rmtree('tmp_model')
|
||||
|
||||
# 3. Create an optimizer and pass the already initialized model
|
||||
model, optimizer, lr_schedule = initialize_optimizer(
|
||||
model=self.inferencer.model,
|
||||
model=model,
|
||||
# model=self.inferencer.model,
|
||||
learning_rate=learning_rate,
|
||||
schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
|
||||
n_batches=len(data_silo.loaders["train"]),
|
||||
@ -197,6 +207,8 @@ class FARMReader(BaseReader):
|
||||
evaluate_every=evaluate_every,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
# 5. Let it grow!
|
||||
self.inferencer.model = trainer.train()
|
||||
self.save(Path(save_dir))
|
||||
|
||||
@ -90,10 +90,10 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"reader = FARMReader(model_name_or_path=\"distilbert-base-uncased-distilled-squad\", use_gpu=False)\n",
|
||||
"reader = FARMReader(model_name_or_path=\"distilbert-base-uncased-distilled-squad\", use_gpu=True)\n",
|
||||
"train_data = \"data/squad20\"\n",
|
||||
"# train_data = \"PATH/TO_YOUR/TRAIN_DATA\" \n",
|
||||
"reader.train(data_dir=train_data, train_filename=\"dev-v2.0.json\", use_gpu=False, n_epochs=1, save_dir=\"my_model\")"
|
||||
"reader.train(data_dir=train_data, train_filename=\"dev-v2.0.json\", use_gpu=True, n_epochs=1, save_dir=\"my_model\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -34,10 +34,10 @@ from haystack.reader.farm import FARMReader
|
||||
|
||||
#**Recommendation: Run training on a GPU. To do so change the `use_gpu` arguments below to `True`
|
||||
|
||||
reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)
|
||||
reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=True)
|
||||
train_data = "data/squad20"
|
||||
# train_data = "PATH/TO_YOUR/TRAIN_DATA"
|
||||
reader.train(data_dir=train_data, train_filename="dev-v2.0.json", use_gpu=False, n_epochs=1, save_dir="my_model")
|
||||
reader.train(data_dir=train_data, train_filename="dev-v2.0.json", use_gpu=True, n_epochs=1, save_dir="my_model")
|
||||
|
||||
# Saving the model happens automatically at the end of training into the `save_dir` you specified
|
||||
# However, you could also save a reader manually again via:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user