Fix #1927 - RuntimeError when loading data using data_silo due to many open file descriptors from multiprocessing (#1928)

* fix #1687

* fix RuntimeError: received 0 items of ancdata

* Add an arg multiprocessing_strategy to DataSilo and DPR.train()

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Alon Eirew 2022-01-04 14:29:26 +02:00 committed by GitHub
parent 381fc302cb
commit 7a4fa42fda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 2 deletions

View File

@ -366,7 +366,7 @@ Embeddings of documents / passages shape (batch_size, embedding_dim)
#### train
```python
| train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, max_samples: int = None, max_processes: int = 128, dev_split: float = 0, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_positives: int = 1, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, use_amp: str = None, optimizer_name: str = "AdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr", query_encoder_save_dir: str = "query_encoder", passage_encoder_save_dir: str = "passage_encoder")
| train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, max_samples: int = None, max_processes: int = 128, multiprocessing_strategy: str = 'file_descriptor', dev_split: float = 0, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_positives: int = 1, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, use_amp: str = None, optimizer_name: str = "AdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr", query_encoder_save_dir: str = "query_encoder", passage_encoder_save_dir: str = "passage_encoder")
```
train a DensePassageRetrieval model
@ -380,6 +380,9 @@ train a DensePassageRetrieval model
- `max_samples`: maximum number of input samples to convert. Can be used for debugging a smaller dataset.
- `max_processes`: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
It can be set to 1 to disable the use of multiprocessing or make debugging easier.
- `multiprocessing_strategy`: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system.
If your system has low limits for the number of open file descriptors, and you cant raise them,
you should use the file_system strategy.
- `dev_split`: The proportion of the train set that will sliced. Only works if dev_filename is set to None
- `batch_size`: total number of samples in 1 batch of data
- `embed_title`: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage

View File

@ -44,6 +44,7 @@ class DataSilo:
automatic_loading: bool = True,
max_multiprocessing_chunksize: int = 2000,
max_processes: int = 128,
multiprocessing_strategy: str = 'file_descriptor',
caching: bool = False,
cache_path: Path = Path("cache/data_silo"),
):
@ -59,6 +60,9 @@ class DataSilo:
values are rather large that might cause memory issues.
:param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
It can be set to 1 to disable the use of multiprocessing or make debugging easier.
:multiprocessing_strategy: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system.
If your system has low limits for the number of open file descriptors, and you cant raise them,
you should use the file_system strategy.
:param caching: save the processed datasets on disk to save time/compute if the same train data is used to run
multiple experiments. Each cache has a checksum based on the train_filename of the Processor
and the batch size.
@ -70,6 +74,7 @@ class DataSilo:
self.batch_size = batch_size
self.class_weights = None
self.max_processes = max_processes
self.multiprocessing_strategy = multiprocessing_strategy
self.max_multiprocessing_chunksize = max_multiprocessing_chunksize
self.caching = caching
self.cache_path = cache_path
@ -138,6 +143,14 @@ class DataSilo:
with ExitStack() as stack:
if self.max_processes > 1: # use multiprocessing only when max_processes > 1
if self.multiprocessing_strategy and self.multiprocessing_strategy in ['file_descriptor', 'file_system']:
mp.set_sharing_strategy(self.multiprocessing_strategy)
else:
logger.warning(
f"{self.multiprocessing_strategy} is an invalid strategy, "
f"falling back to default (file_descriptor) strategy."
)
p = stack.enter_context(mp.Pool(processes=num_cpus_used))
logger.info(

View File

@ -308,6 +308,7 @@ class DensePassageRetriever(BaseRetriever):
test_filename: str = None,
max_samples: int = None,
max_processes: int = 128,
multiprocessing_strategy: str = 'file_descriptor',
dev_split: float = 0,
batch_size: int = 2,
embed_title: bool = True,
@ -337,6 +338,9 @@ class DensePassageRetriever(BaseRetriever):
:param max_samples: maximum number of input samples to convert. Can be used for debugging a smaller dataset.
:param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
It can be set to 1 to disable the use of multiprocessing or make debugging easier.
:param multiprocessing_strategy: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system.
If your system has low limits for the number of open file descriptors, and you cant raise them,
you should use the file_system strategy.
:param dev_split: The proportion of the train set that will sliced. Only works if dev_filename is set to None
:param batch_size: total number of samples in 1 batch of data
:param embed_title: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
@ -377,7 +381,13 @@ class DensePassageRetriever(BaseRetriever):
else:
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False, max_processes=max_processes)
data_silo = DataSilo(
processor=self.processor,
batch_size=batch_size,
distributed=False,
max_processes=max_processes,
multiprocessing_strategy=multiprocessing_strategy
)
# 5. Create an optimizer
self.model, optimizer, lr_schedule = initialize_optimizer(