Fix for allocate memory exception by specifying max_processes (#910)

This commit is contained in:
Moshe Berchansky 2021-03-19 19:11:25 +02:00 committed by GitHub
parent f954f0db38
commit 47dc069afe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -270,6 +270,7 @@ class DensePassageRetriever(BaseRetriever):
train_filename: str,
dev_filename: str = None,
test_filename: str = None,
max_processes: int = 128,
dev_split: float = 0,
batch_size: int = 2,
embed_title: bool = True,
@ -295,6 +296,8 @@ class DensePassageRetriever(BaseRetriever):
:param train_filename: training filename
:param dev_filename: development set filename, file to be used by model in eval step of training
:param test_filename: test set filename, file to be used by model in test step after training
:param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
It can be set to 1 to disable the use of multiprocessing or make debugging easier.
:param dev_split: The proportion of the train set that will be sliced. Only works if dev_filename is set to None
:param batch_size: total number of samples in 1 batch of data
:param embed_title: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
@ -333,7 +336,7 @@ class DensePassageRetriever(BaseRetriever):
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False)
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False, max_processes=max_processes)
# 5. Create an optimizer
self.model, optimizer, lr_schedule = initialize_optimizer(