refactor: remove Inferencer multiprocessing (#3283)
This commit is contained in: parent b49bce97aa, commit 6cb4e93965
@@ -1,9 +1,7 @@
-from typing import List, Optional, Dict, Union, Generator, Set, Any
+from typing import List, Optional, Dict, Union, Set, Any

 import os
 import logging
-import multiprocessing as mp
-from functools import partial
 from tqdm import tqdm
 import torch
 from torch.utils.data.sampler import SequentialSampler
@@ -12,13 +10,7 @@ from torch.utils.data import Dataset
 from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
 from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import (
-    grouper,
-    initialize_device_settings,
-    set_all_seeds,
-    calc_chunksize,
-    log_ascii_workers,
-)
+from haystack.modeling.utils import initialize_device_settings, set_all_seeds
 from haystack.modeling.data_handler.inputs import QAInput
 from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
 from haystack.modeling.model.predictions import QAPred
@@ -70,6 +62,9 @@ class Inferencer:
            `multiprocessing.Pool` again! To do so call
            :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
            done using this class. The garbage collector will not do this for you!
+           .. deprecated:: 1.10
+               This parameter has no effect; it will be removed as Inferencer multiprocessing
+               has been deprecated.
        :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
        :param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
            A list containing torch device objects and/or strings is supported (For example
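The deprecation note added above means `num_processes` is still accepted but no longer does anything. A minimal sketch of what a caller sees after this change (the import path is assumed; model name and task settings are taken from the test fixture at the end of this diff):

from haystack.modeling.infer import Inferencer  # assumed module path for the file in this diff

# num_processes is kept for backwards compatibility, but no multiprocessing.Pool
# is created anymore; all inference runs in the calling process.
inferencer = Inferencer.load(
    "deepset/bert-base-cased-squad2",
    task_type="question_answering",
    batch_size=16,
    num_processes=4,   # deprecated since 1.10: has no effect
    gpu=False,
)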
@@ -113,8 +108,6 @@ class Inferencer:
         model.connect_heads_with_processor(processor.tasks, require_labels=False)
         set_all_seeds(42)

-        self._set_multiprocessing_pool(num_processes)
-
     @classmethod
     def load(
         cls,
@@ -166,6 +159,9 @@ class Inferencer:
            `multiprocessing.Pool` again! To do so call
            :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
            done using this class. The garbage collector will not do this for you!
+           .. deprecated:: 1.10
+               This parameter has no effect; it will be removed as Inferencer multiprocessing
+               has been deprecated.
        :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
        :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
        :param use_fast: (Optional, True by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
@@ -259,48 +255,6 @@ class Inferencer:
             devices=devices,
         )

-    def _set_multiprocessing_pool(self, num_processes: Optional[int]) -> None:
-        """
-        Initialize a multiprocessing.Pool for instances of Inferencer.
-
-        :param num_processes: the number of processes for `multiprocessing.Pool`.
-            Set to value of 1 (or 0) to disable multiprocessing.
-            Set to None to let Inferencer use all CPU cores minus one.
-            If you want to debug the Language Model, you might need to disable multiprocessing!
-            **Warning!** If you use multiprocessing you have to close the
-            `multiprocessing.Pool` again! To do so call
-            :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
-            done using this class. The garbage collector will not do this for you!
-        :return: None
-        """
-        self.process_pool = None
-        if num_processes == 0 or num_processes == 1:  # disable multiprocessing
-            self.process_pool = None
-        else:
-            if num_processes is None:  # use all CPU cores
-                if mp.cpu_count() > 3:
-                    num_processes = mp.cpu_count() - 1
-                else:
-                    num_processes = mp.cpu_count()
-            self.process_pool = mp.Pool(processes=num_processes)
-            logger.info("Got ya %s parallel workers to do inference ...", num_processes)
-            log_ascii_workers(n=num_processes, logger=logger)
-
-    def close_multiprocessing_pool(self, join: bool = False):
-        """Close the `multiprocessing.Pool` again.
-
-        If you use multiprocessing you have to close the `multiprocessing.Pool` again!
-        To do so call this function after you are done using this class.
-        The garbage collector will not do this for you!
-
-        :param join: wait for the worker processes to exit
-        """
-        if self.process_pool is not None:
-            self.process_pool.close()
-            if join:
-                self.process_pool.join()
-            self.process_pool = None
-
     def save(self, path: str):
         self.model.save(path)
         self.processor.save(path)
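With `_set_multiprocessing_pool` and `close_multiprocessing_pool` removed, the cleanup contract described in the deleted docstrings disappears as well. A before/after sketch under the assumptions above (the file name is hypothetical; the "before" calls only illustrate the old contract and no longer exist after this commit):

# Before #3283 (sketch): the caller owned the pool and had to close it explicitly.
inferencer = Inferencer.load("deepset/bert-base-cased-squad2", task_type="question_answering", num_processes=4, gpu=False)
try:
    answers = inferencer.inference_from_file(file="my_squad_style_file.json")  # hypothetical file
finally:
    inferencer.close_multiprocessing_pool(join=True)  # removed by this commit

# After #3283 (sketch): no pool is ever created, so there is nothing to close.
inferencer = Inferencer.load("deepset/bert-base-cased-squad2", task_type="question_answering", gpu=False)
answers = inferencer.inference_from_file(file="my_squad_style_file.json")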
@@ -313,6 +267,9 @@ class Inferencer:

        :param file: path of the input file for Inference
        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+           .. deprecated:: 1.10
+               This parameter has no effect; it will be removed as Inferencer multiprocessing
+               has been deprecated.
        :return: list of predictions
        """
        dicts = self.processor.file_to_dicts(file)
@@ -333,8 +290,11 @@ class Inferencer:
            One dict per sample.
        :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
            object where applicable, else it returns PredObj.to_json()
-       :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+       :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
            (only relevant if you do multiprocessing)
+           .. deprecated:: 1.10
+               This parameter has no effect; it will be removed as Inferencer multiprocessing
+               has been deprecated.
        :return: list of predictions
        """
        # whether to aggregate predictions across different samples (e.g. for QA on long texts)
@@ -346,26 +306,8 @@ class Inferencer:
         if len(self.model.prediction_heads) > 0:
             aggregate_preds = hasattr(self.model.prediction_heads[0], "aggregate_preds")

-        if self.process_pool is None:  # multiprocessing disabled (helpful for debugging or using in web frameworks)
-            predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
-            return predictions
-        else:  # use multiprocessing for inference
-            # Calculate values of multiprocessing_chunksize and num_processes if not supplied in the parameters.
-
-            if multiprocessing_chunksize is None:
-                _chunk_size, _ = calc_chunksize(len(dicts))
-                multiprocessing_chunksize = _chunk_size
-
-            predictions = self._inference_with_multiprocessing(
-                dicts, return_json, aggregate_preds, multiprocessing_chunksize
-            )
-
-            self.processor.log_problematic(self.problematic_sample_ids)
-            # cast the generator to a list if it isnt already a list.
-            if type(predictions) != list:
-                return list(predictions)
-            else:
-                return predictions
+        predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
+        return predictions

     def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: bool, aggregate_preds: bool) -> List:
         """
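After this hunk, `inference_from_dicts` always takes the single-process path, so callers get a plain list back instead of sometimes a generator. A sketch of a QA call (the `qas`/`context` dict layout follows the FARM-style QA input format and is an assumption here, as is the sample text):

# `inferencer` as created in the load() sketch above
qa_dicts = [
    {
        "qas": ["Who wrote the annual report?"],                       # hypothetical question
        "context": "The annual report was written by the data team.",  # hypothetical passage
    }
]

# multiprocessing_chunksize is still accepted but deprecated and ignored.
result = inferencer.inference_from_dicts(dicts=qa_dicts, return_json=True)
print(result[0])  # a dict (PredObj.to_json()) because return_json=True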
@@ -399,69 +341,6 @@ class Inferencer:

         return preds_all

-    def _inference_with_multiprocessing(
-        self,
-        dicts: Union[List[Dict], Generator[Dict, None, None]],
-        return_json: bool,
-        aggregate_preds: bool,
-        multiprocessing_chunksize: int,
-    ) -> Generator[Dict, None, None]:
-        """
-        Implementation of inference. This method is a generator that yields the results.
-
-        :param dicts: Samples to run inference on provided as a list of dicts or a generator object that yield dicts.
-        :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
-            object where applicable, else it returns PredObj.to_json()
-        :param aggregate_preds: whether to aggregate predictions across different samples (e.g. for QA on long texts)
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
-        :return: generator object that yield predictions
-        """
-
-        # We group the input dicts into chunks and feed each chunk to a different process
-        # in the pool, where it gets converted to a pytorch dataset
-        if self.process_pool is not None:
-            results = self.process_pool.imap(
-                partial(self._create_datasets_chunkwise, processor=self.processor),
-                grouper(iterable=dicts, n=multiprocessing_chunksize),
-                1,
-            )
-
-            # Once a process spits out a preprocessed chunk. we feed this dataset directly to the model.
-            # So we don't need to wait until all preprocessing has finished before getting first predictions.
-            for dataset, tensor_names, problematic_sample_ids, baskets in results:
-                self.problematic_sample_ids.update(problematic_sample_ids)
-                if dataset is None:
-                    logger.error(
-                        f"Part of the dataset could not be converted! \n"
-                        f"BE AWARE: The order of predictions will not conform with the input order!"
-                    )
-                else:
-                    # TODO change format of formatted_preds in QA (list of dicts)
-                    if aggregate_preds:
-                        predictions = self._get_predictions_and_aggregate(dataset, tensor_names, baskets)
-                    else:
-                        predictions = self._get_predictions(dataset, tensor_names, baskets)
-
-                    if return_json:
-                        # TODO this try catch should be removed when all tasks return prediction objects
-                        try:
-                            predictions = [x.to_json() for x in predictions]
-                        except AttributeError:
-                            pass
-                    yield from predictions
-
-    @classmethod
-    def _create_datasets_chunkwise(cls, chunk, processor: Processor):
-        """Convert ONE chunk of data (i.e. dictionaries) into ONE pytorch dataset.
-        This is usually executed in one of many parallel processes.
-        The resulting datasets of the processes are merged together afterwards"""
-        dicts = [d[1] for d in chunk]
-        indices = [d[0] for d in chunk]
-        dataset, tensor_names, problematic_sample_ids, baskets = processor.dataset_from_dicts(
-            dicts, indices, return_baskets=True
-        )
-        return dataset, tensor_names, problematic_sample_ids, baskets
-
     def _get_predictions(self, dataset: Dataset, tensor_names: List, baskets):
         """
         Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
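The deleted generator above chunked the input dicts with `grouper` and preprocessed each chunk in a worker process via `multiprocessing.Pool.imap`. For reference, a generic, self-contained sketch of that pattern using only the standard library (stand-in names, not haystack API); this is the behaviour the commit drops in favour of in-process preprocessing:

import multiprocessing as mp
from functools import partial
from itertools import islice

def grouper(iterable, n):
    # yield successive chunks of size n (simplified stand-in for haystack's grouper)
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk

def preprocess_chunk(chunk, scale=1):
    # stand-in for Processor.dataset_from_dicts: turn raw dicts into model-ready data
    return [d["value"] * scale for d in chunk]

if __name__ == "__main__":
    dicts = [{"value": i} for i in range(10)]
    with mp.Pool(processes=2) as pool:
        # imap yields each preprocessed chunk as soon as a worker finishes it,
        # so prediction on early chunks can start before all preprocessing is done.
        for processed in pool.imap(partial(preprocess_chunk, scale=2), grouper(dicts, n=3), 1):
            print(processed)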
@@ -592,6 +471,13 @@ class QAInferencer(Inferencer):
     def inference_from_dicts(
         self, dicts: List[dict], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         return Inferencer.inference_from_dicts(
             self, dicts, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )
@@ -599,6 +485,13 @@ class QAInferencer(Inferencer):
     def inference_from_file(
         self, file: str, multiprocessing_chunksize: Optional[int] = None, return_json=True
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         return Inferencer.inference_from_file(
             self, file, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )
@@ -606,6 +499,13 @@ class QAInferencer(Inferencer):
     def inference_from_objects(
         self, objects: List[QAInput], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         dicts = [o.to_dict() for o in objects]
         # TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects,
         # then we can and should use inference from (input) objects!

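The three `QAInferencer` entry points above only gain the deprecation note; their behaviour is otherwise unchanged. A sketch of calling one of them with the now-ignored parameter (assuming `qa_inferencer` was obtained via `Inferencer.load(..., task_type="question_answering")`, as in the test fixture below, and reusing the hypothetical `qa_dicts` from the earlier sketch):

preds = qa_inferencer.inference_from_dicts(
    dicts=qa_dicts,
    return_json=False,             # return QAPred objects instead of their to_json() form
    multiprocessing_chunksize=64,  # accepted but ignored since 1.10
)
for pred in preds:
    print(pred.to_json())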
@@ -1115,22 +1115,15 @@ def adaptive_model_qa(num_processes):
     """
     PyTest Fixture for a Question Answering Inferencer based on PyTorch.
     """
-    try:
-        model = Inferencer.load(
-            "deepset/bert-base-cased-squad2",
-            task_type="question_answering",
-            batch_size=16,
-            num_processes=num_processes,
-            gpu=False,
-        )
-        yield model
-    finally:
-        if num_processes != 0:
-            # close the pool
-            # we pass join=True to wait for all sub processes to close
-            # this is because below we want to test if all sub-processes
-            # have exited
-            model.close_multiprocessing_pool(join=True)
+
+    model = Inferencer.load(
+        "deepset/bert-base-cased-squad2",
+        task_type="question_answering",
+        batch_size=16,
+        num_processes=num_processes,
+        gpu=False,
+    )
+    yield model

     # check if all workers (sub processes) are closed
     current_process = psutil.Process()
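The fixture no longer needs the try/finally cleanup because no worker pool is spawned; the trailing context lines suggest the surrounding test still uses psutil to verify that. A hedged sketch of such a check (only `current_process = psutil.Process()` appears in the diff; the assertion itself is an assumption):

import psutil

def assert_no_leftover_workers():
    # After inference, the test process should not have any lingering
    # child processes left over from an Inferencer worker pool.
    current_process = psutil.Process()
    children = current_process.children(recursive=True)
    assert len(children) == 0, f"unexpected child processes: {children}"

Note that a real test would scope such an assertion carefully, since pytest itself can spawn helper processes.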