mirror of https://github.com/deepset-ai/haystack.git, synced 2025-09-25 08:04:49 +00:00
fix: Make InferenceProcessor thread safe (#3709)

* Make TextClassificationProcessor thread-safe by removing self.baskets
* Add print statement for debugging
* Remove print statement for debugging
* Fix mypy
This commit is contained in:
parent 756e0114e6
commit e266cf6e29
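In short, the race comes from accumulating per-call results on the shared instance (self.baskets): two threads calling dataset_from_dicts on the same processor interleave writes to one list. The fix accumulates in a local variable and passes it explicitly to the helper methods. A minimal sketch of the pattern, using hypothetical classes rather than Haystack's actual ones:

# Minimal sketch of the race this commit fixes (hypothetical classes,
# not the actual Haystack implementation).

class UnsafeProcessor:
    def __init__(self):
        self.baskets = []  # shared instance attribute

    def dataset_from_dicts(self, dicts):
        self.baskets = []  # a concurrent call can reset this list mid-flight
        for d in dicts:
            self.baskets.append(d)
        return self._create_dataset()  # reads shared state

    def _create_dataset(self):
        return list(self.baskets)  # may contain another thread's items


class SafeProcessor:
    def dataset_from_dicts(self, dicts):
        baskets = []  # local to this call: no cross-thread interference
        for d in dicts:
            baskets.append(d)
        return self._create_dataset(baskets)  # state passed explicitly

    def _create_dataset(self, baskets):
        return list(baskets)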
@@ -21,6 +21,7 @@ from torch.utils.data import TensorDataset
 import transformers
 from transformers import PreTrainedTokenizer, AutoTokenizer

+from haystack.errors import HaystackError
 from haystack.modeling.model.feature_extraction import (
     tokenize_batch_question_answering,
     tokenize_with_metadata,
@@ -100,7 +101,6 @@ class Processor(ABC):
             self.data_dir = Path(data_dir)
         else:
             self.data_dir = None  # type: ignore
-        self.baskets: List = []

         self._log_params()
         self.problematic_sample_ids: set = set()
@@ -477,7 +477,7 @@ class SquadProcessor(Processor):
         # Logging
         if indices:
             if 0 in indices:
-                self._log_samples(n_samples=1, baskets=self.baskets)
+                self._log_samples(n_samples=1, baskets=baskets)

         # During inference we need to keep the information contained in baskets.
         if return_baskets:
@@ -1854,7 +1854,7 @@ class TextClassificationProcessor(Processor):
     def dataset_from_dicts(
         self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
     ):
-        self.baskets = []
+        baskets = []
         # Tokenize in batches
         texts = [x["text"] for x in dicts]
         tokenized_batch = self.tokenizer(
@@ -1889,21 +1889,21 @@ class TextClassificationProcessor(Processor):
                 label_dict = self.convert_labels(dictionary)
                 feat_dict.update(label_dict)

-            # Add Basket to self.baskets
+            # Add Basket to baskets
             curr_sample = Sample(id="", clear_text=dictionary, tokenized=tokenized, features=[feat_dict])
             curr_basket = SampleBasket(id_internal=None, raw=dictionary, id_external=None, samples=[curr_sample])
-            self.baskets.append(curr_basket)
+            baskets.append(curr_basket)

         if indices and 0 not in indices:
             pass
         else:
-            self._log_samples(n_samples=1, baskets=self.baskets)
+            self._log_samples(n_samples=1, baskets=baskets)

         # TODO populate problematic ids
         problematic_ids: set = set()
-        dataset, tensornames = self._create_dataset()
+        dataset, tensornames = self._create_dataset(baskets)
         if return_baskets:
-            return dataset, tensornames, problematic_ids, self.baskets
+            return dataset, tensornames, problematic_ids, baskets
         else:
             return dataset, tensornames, problematic_ids

@@ -1926,13 +1926,16 @@ class TextClassificationProcessor(Processor):
             ret[task["label_tensor_name"]] = label_ids
         return ret

-    def _create_dataset(self):
-        # TODO this is the proposed new version to replace the mother function
-        features_flat = []
+    def _create_dataset(self, baskets: List[SampleBasket]):
+        features_flat: List = []
         basket_to_remove = []
-        for basket in self.baskets:
+        for basket in baskets:
             if self._check_sample_features(basket):
+                if not isinstance(basket.samples, Iterable):
+                    raise HaystackError("basket.samples must contain a list of samples.")
                 for sample in basket.samples:
+                    if sample.features is None:
+                        raise HaystackError("sample.features must not be None.")
                     features_flat.extend(sample.features)
             else:
                 # remove the entire basket
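With the baskets list threaded through dataset_from_dicts, _log_samples, and _create_dataset as an explicit argument (the tail of the last hunk is truncated in this view), concurrent calls no longer touch shared mutable state. A usage sketch against the hypothetical SafeProcessor above, not Haystack's real API:

import threading

# Two threads calling the same processor get independent results,
# because each call owns its own local baskets list.
proc = SafeProcessor()
results = {}

def run(key, dicts):
    results[key] = proc.dataset_from_dicts(dicts)

t1 = threading.Thread(target=run, args=("a", [{"text": "x"}]))
t2 = threading.Thread(target=run, args=("b", [{"text": "y"}]))
t1.start(); t2.start()
t1.join(); t2.join()
assert results["a"] == [{"text": "x"}]
assert results["b"] == [{"text": "y"}]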