From e266cf6e29f78df751d9dbe7a505886579233aa5 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Wed, 21 Dec 2022 18:08:41 +0100 Subject: [PATCH] fix: Make `InferenceProcessor` thread safe (#3709) * Make TextClassificationProcessor thread-safe by removing self.baskets * Add print statement for debugging * Remove print statement for debugging * Fix mypy --- haystack/modeling/data_handler/processor.py | 27 ++++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/haystack/modeling/data_handler/processor.py b/haystack/modeling/data_handler/processor.py index 5667daaff..90fae0317 100644 --- a/haystack/modeling/data_handler/processor.py +++ b/haystack/modeling/data_handler/processor.py @@ -21,6 +21,7 @@ from torch.utils.data import TensorDataset import transformers from transformers import PreTrainedTokenizer, AutoTokenizer +from haystack.errors import HaystackError from haystack.modeling.model.feature_extraction import ( tokenize_batch_question_answering, tokenize_with_metadata, @@ -100,7 +101,6 @@ class Processor(ABC): self.data_dir = Path(data_dir) else: self.data_dir = None # type: ignore - self.baskets: List = [] self._log_params() self.problematic_sample_ids: set = set() @@ -477,7 +477,7 @@ class SquadProcessor(Processor): # Logging if indices: if 0 in indices: - self._log_samples(n_samples=1, baskets=self.baskets) + self._log_samples(n_samples=1, baskets=baskets) # During inference we need to keep the information contained in baskets. if return_baskets: @@ -1854,7 +1854,7 @@ class TextClassificationProcessor(Processor): def dataset_from_dicts( self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False ): - self.baskets = [] + baskets = [] # Tokenize in batches texts = [x["text"] for x in dicts] tokenized_batch = self.tokenizer( @@ -1889,21 +1889,21 @@ class TextClassificationProcessor(Processor): label_dict = self.convert_labels(dictionary) feat_dict.update(label_dict) - # Add Basket to self.baskets + # Add Basket to baskets curr_sample = Sample(id="", clear_text=dictionary, tokenized=tokenized, features=[feat_dict]) curr_basket = SampleBasket(id_internal=None, raw=dictionary, id_external=None, samples=[curr_sample]) - self.baskets.append(curr_basket) + baskets.append(curr_basket) if indices and 0 not in indices: pass else: - self._log_samples(n_samples=1, baskets=self.baskets) + self._log_samples(n_samples=1, baskets=baskets) # TODO populate problematic ids problematic_ids: set = set() - dataset, tensornames = self._create_dataset() + dataset, tensornames = self._create_dataset(baskets) if return_baskets: - return dataset, tensornames, problematic_ids, self.baskets + return dataset, tensornames, problematic_ids, baskets else: return dataset, tensornames, problematic_ids @@ -1926,13 +1926,16 @@ class TextClassificationProcessor(Processor): ret[task["label_tensor_name"]] = label_ids return ret - def _create_dataset(self): - # TODO this is the proposed new version to replace the mother function - features_flat = [] + def _create_dataset(self, baskets: List[SampleBasket]): + features_flat: List = [] basket_to_remove = [] - for basket in self.baskets: + for basket in baskets: if self._check_sample_features(basket): + if not isinstance(basket.samples, Iterable): + raise HaystackError("basket.samples must contain a list of samples.") for sample in basket.samples: + if sample.features is None: + raise HaystackError("sample.features must not be None.") features_flat.extend(sample.features) else: # remove the entire basket