diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index 5bd86c8b0..25ccd7409 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -366,7 +366,7 @@ Embeddings of documents / passages shape (batch_size, embedding_dim)
 #### train
 
 ```python
- | train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, max_samples: int = None, max_processes: int = 128, multiprocessing_strategy: str = 'file_descriptor', dev_split: float = 0, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_positives: int = 1, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, use_amp: str = None, optimizer_name: str = "AdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr", query_encoder_save_dir: str = "query_encoder", passage_encoder_save_dir: str = "passage_encoder")
+ | train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, max_samples: int = None, max_processes: int = 128, multiprocessing_strategy: Optional[str] = None, dev_split: float = 0, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_positives: int = 1, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, use_amp: str = None, optimizer_name: str = "AdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr", query_encoder_save_dir: str = "query_encoder", passage_encoder_save_dir: str = "passage_encoder")
 ```
 
 train a DensePassageRetrieval model
@@ -380,7 +380,7 @@ train a DensePassageRetrieval model
 - `max_samples`: maximum number of input samples to convert. Can be used for debugging a smaller dataset.
 - `max_processes`: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
 It can be set to 1 to disable the use of multiprocessing or make debugging easier.
-- `multiprocessing_strategy`: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system.
+- `multiprocessing_strategy`: Set the multiprocessing sharing strategy; this can be one of file_descriptor/file_system, depending on your OS. If your system has low limits for the number of open file descriptors and you can't raise them, you should use the file_system strategy.
 - `dev_split`: The proportion of the train set that will sliced. Only works if dev_filename is set to None
diff --git a/haystack/modeling/data_handler/data_silo.py b/haystack/modeling/data_handler/data_silo.py
index 623c8b1f6..fdad73d23 100644
--- a/haystack/modeling/data_handler/data_silo.py
+++ b/haystack/modeling/data_handler/data_silo.py
@@ -44,7 +44,7 @@ class DataSilo:
         automatic_loading: bool = True,
         max_multiprocessing_chunksize: int = 2000,
         max_processes: int = 128,
-        multiprocessing_strategy: str = 'file_descriptor',
+        multiprocessing_strategy: Optional[str] = None,
         caching: bool = False,
         cache_path: Path = Path("cache/data_silo"),
     ):
@@ -60,7 +60,7 @@ class DataSilo:
                            values are rather large that might cause memory issues.
        :param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
                              It can be set to 1 to disable the use of multiprocessing or make debugging easier.
-       :multiprocessing_strategy: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system.
+       :param multiprocessing_strategy: Set the multiprocessing sharing strategy; this can be one of file_descriptor/file_system, depending on your OS. If your system has low limits for the number of open file descriptors and you can't raise them, you should use the file_system strategy.
        :param caching: save the processed datasets on disk to save time/compute if the same train data is used to run
@@ -143,13 +143,14 @@ class DataSilo:
 
         with ExitStack() as stack:
             if self.max_processes > 1:  # use multiprocessing only when max_processes > 1
-                if self.multiprocessing_strategy and self.multiprocessing_strategy in ['file_descriptor', 'file_system']:
-                    mp.set_sharing_strategy(self.multiprocessing_strategy)
-                else:
-                    logger.warning(
-                        f"{self.multiprocessing_strategy} is an invalid strategy, "
-                        f"falling back to default (file_descriptor) strategy."
-                    )
+                if self.multiprocessing_strategy:
+                    if self.multiprocessing_strategy in mp.get_all_sharing_strategies():
+                        mp.set_sharing_strategy(self.multiprocessing_strategy)
+                    else:
+                        logger.warning(
+                            f"{self.multiprocessing_strategy} is unavailable, "
+                            f"falling back to the default multiprocessing sharing strategy of your OS."
+                        )
 
                 p = stack.enter_context(mp.Pool(processes=num_cpus_used))
@@ -823,4 +824,4 @@ class DistillationDataSilo(DataSilo):
             "data_silo_type": self.__class__.__name__,
         }
         checksum = get_dict_checksum(payload_dict)
-        return checksum
\ No newline at end of file
+        return checksum
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index 49cbc6423..b829dc595 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -308,7 +308,7 @@ class DensePassageRetriever(BaseRetriever):
         test_filename: str = None,
         max_samples: int = None,
         max_processes: int = 128,
-        multiprocessing_strategy: str = 'file_descriptor',
+        multiprocessing_strategy: Optional[str] = None,
         dev_split: float = 0,
         batch_size: int = 2,
         embed_title: bool = True,
@@ -338,7 +338,7 @@ class DensePassageRetriever(BaseRetriever):
        :param max_samples: maximum number of input samples to convert. Can be used for debugging a smaller dataset.
        :param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
                              It can be set to 1 to disable the use of multiprocessing or make debugging easier.
-       :param multiprocessing_strategy: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system.
+       :param multiprocessing_strategy: Set the multiprocessing sharing strategy; this can be one of file_descriptor/file_system, depending on your OS. If your system has low limits for the number of open file descriptors and you can't raise them, you should use the file_system strategy.
        :param dev_split: The proportion of the train set that will sliced. Only works if dev_filename is set to None
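For context on the new fallback in DataSilo, here is a minimal sketch of the same check done directly against `torch.multiprocessing`; the `requested` value is illustrative and not part of the patch:

```python
import torch.multiprocessing as mp

# Which strategies exist depends on the OS: Linux typically offers both
# "file_descriptor" (its default) and "file_system", while macOS offers
# only "file_system".
print(mp.get_all_sharing_strategies())
print(mp.get_sharing_strategy())

# Mirror of the patched DataSilo logic: only set a strategy the current
# platform actually supports, otherwise keep the OS default untouched.
requested = "file_system"  # illustrative choice, e.g. when `ulimit -n` is low
if requested in mp.get_all_sharing_strategies():
    mp.set_sharing_strategy(requested)
```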
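And a hypothetical training call exercising the new parameter; the document store, data directory, and filename below are placeholders, and omitting `multiprocessing_strategy` (the new `None` default) leaves the platform's default strategy in place:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

retriever = DensePassageRetriever(document_store=InMemoryDocumentStore())

# Pass "file_system" explicitly on systems with low open-file-descriptor
# limits; drop the argument to keep the OS default.
retriever.train(
    data_dir="data/dpr",          # placeholder path
    train_filename="train.json",  # placeholder DPR-format training file
    multiprocessing_strategy="file_system",
)
```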