fix: Make InferenceProcessor thread safe (#3709)

* Make TextClassificationProcessor thread-safe by removing self.baskets

* Add print statement for debugging

* Remove print statement for debugging

* Fix mypy
bogdankostic 2022-12-21 18:08:41 +01:00 committed by GitHub
parent 756e0114e6
commit e266cf6e29
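The bug this commit fixes is easiest to see in isolation: when the baskets list lives on the processor instance, two threads calling dataset_from_dicts() at the same time reset and read the same list. Below is a minimal, self-contained sketch of that failure mode; ToyProcessor is a hypothetical stand-in for the real Processor classes, not Haystack code.

# Hypothetical ToyProcessor illustrating the pre-fix design: baskets are
# stored on the instance, so concurrent calls share (and clobber) one list.
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List


class ToyProcessor:
    def __init__(self) -> None:
        self.baskets: List[Dict] = []  # shared mutable state -> not thread-safe

    def dataset_from_dicts(self, dicts: List[Dict]) -> List[Dict]:
        self.baskets = []  # thread A's reset can discard thread B's work
        for d in dicts:
            self.baskets.append({"raw": d})
        return self.baskets  # may include baskets built by another thread


processor = ToyProcessor()
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [
        pool.submit(processor.dataset_from_dicts, [{"text": f"doc {i}"}] * 1000)
        for i in range(2)
    ]
    sizes = [len(f.result()) for f in futures]

# Each call submitted 1000 dicts, but because both threads share self.baskets
# the observed sizes depend on thread interleaving and need not be 1000.
print(sizes)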


@@ -21,6 +21,7 @@ from torch.utils.data import TensorDataset
 import transformers
 from transformers import PreTrainedTokenizer, AutoTokenizer
 
+from haystack.errors import HaystackError
 from haystack.modeling.model.feature_extraction import (
     tokenize_batch_question_answering,
     tokenize_with_metadata,
@@ -100,7 +101,6 @@ class Processor(ABC):
             self.data_dir = Path(data_dir)
         else:
             self.data_dir = None  # type: ignore
-        self.baskets: List = []
         self._log_params()
         self.problematic_sample_ids: set = set()
@@ -477,7 +477,7 @@ class SquadProcessor(Processor):
         # Logging
         if indices:
             if 0 in indices:
-                self._log_samples(n_samples=1, baskets=self.baskets)
+                self._log_samples(n_samples=1, baskets=baskets)
 
         # During inference we need to keep the information contained in baskets.
         if return_baskets:
@@ -1854,7 +1854,7 @@ class TextClassificationProcessor(Processor):
     def dataset_from_dicts(
         self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
     ):
-        self.baskets = []
+        baskets = []
         # Tokenize in batches
         texts = [x["text"] for x in dicts]
         tokenized_batch = self.tokenizer(
@@ -1889,21 +1889,21 @@ class TextClassificationProcessor(Processor):
             label_dict = self.convert_labels(dictionary)
             feat_dict.update(label_dict)
 
-            # Add Basket to self.baskets
+            # Add Basket to baskets
             curr_sample = Sample(id="", clear_text=dictionary, tokenized=tokenized, features=[feat_dict])
             curr_basket = SampleBasket(id_internal=None, raw=dictionary, id_external=None, samples=[curr_sample])
-            self.baskets.append(curr_basket)
+            baskets.append(curr_basket)
 
         if indices and 0 not in indices:
             pass
         else:
-            self._log_samples(n_samples=1, baskets=self.baskets)
+            self._log_samples(n_samples=1, baskets=baskets)
 
         # TODO populate problematic ids
         problematic_ids: set = set()
-        dataset, tensornames = self._create_dataset()
+        dataset, tensornames = self._create_dataset(baskets)
         if return_baskets:
-            return dataset, tensornames, problematic_ids, self.baskets
+            return dataset, tensornames, problematic_ids, baskets
         else:
             return dataset, tensornames, problematic_ids
@@ -1926,13 +1926,16 @@ class TextClassificationProcessor(Processor):
             ret[task["label_tensor_name"]] = label_ids
         return ret
 
-    def _create_dataset(self):
-        # TODO this is the proposed new version to replace the mother function
-        features_flat = []
+    def _create_dataset(self, baskets: List[SampleBasket]):
+        features_flat: List = []
         basket_to_remove = []
-        for basket in self.baskets:
+        for basket in baskets:
             if self._check_sample_features(basket):
+                if not isinstance(basket.samples, Iterable):
+                    raise HaystackError("basket.samples must contain a list of samples.")
                 for sample in basket.samples:
+                    if sample.features is None:
+                        raise HaystackError("sample.features must not be None.")
                     features_flat.extend(sample.features)
             else:
                 # remove the entire basket
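After the change, baskets is a local variable that is passed explicitly to helpers such as _create_dataset(baskets), so concurrent calls no longer interact through the processor instance. Here is a sketch of that pattern, again with a hypothetical ToyProcessor rather than the real Haystack classes.

# Hypothetical ToyProcessor showing the post-fix pattern: baskets are local
# to each call and handed to helpers as arguments instead of stored on self.
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List


class ToyProcessor:
    def dataset_from_dicts(self, dicts: List[Dict]) -> List[Dict]:
        baskets: List[Dict] = []  # local to this call, invisible to other threads
        for d in dicts:
            baskets.append({"raw": d})
        return self._create_dataset(baskets)

    def _create_dataset(self, baskets: List[Dict]) -> List[Dict]:
        # Helpers receive the baskets they should work on, mirroring the new
        # _create_dataset(baskets) signature in this diff.
        return baskets


processor = ToyProcessor()
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(
        pool.map(processor.dataset_from_dicts, ([{"text": str(i)}] * 50 for i in range(4)))
    )

# Every call returns exactly the 50 baskets it built itself, regardless of
# how the threads interleave.
assert all(len(r) == 50 for r in results)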