Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-25 08:04:49 +00:00)
fix: Make InferenceProcessor thread safe (#3709)
* Make TextClassificationProcessor thread-safe by removing self.baskets
* Add print statement for debugging
* Remove print statement for debugging
* Fix mypy
This commit is contained in:
parent 756e0114e6
commit e266cf6e29
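For context on the root cause: the processor kept per-call state in self.baskets, so two threads running inference through one shared processor instance would mutate the same list. A minimal, self-contained sketch of that failure mode follows; the classes below are illustrative stand-ins, not the Haystack code.

import threading

# Hypothetical stand-ins (not the Haystack classes) illustrating the bug this
# commit fixes: state stored on the instance is shared by all threads, while a
# local variable is private to each call.

class UnsafeProcessor:
    def __init__(self):
        self.baskets = []

    def dataset_from_dicts(self, dicts):
        self.baskets = []  # a concurrent call resets this mid-iteration
        for d in dicts:
            self.baskets.append(d)
        return list(self.baskets)

class SafeProcessor:
    def dataset_from_dicts(self, dicts):
        baskets = []  # local: each call owns its own list
        for d in dicts:
            baskets.append(d)
        return baskets

def run_concurrently(processor):
    results = {}
    def worker(name):
        results[name] = processor.dataset_from_dicts([name] * 10_000)
    threads = [threading.Thread(target=worker, args=(n,)) for n in ("a", "b")]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # Unsafe: thread "a" can come back with "b" items mixed in, or too few
    # items, because both threads mutate the same self.baskets.
    # Safe: each result contains exactly its own 10,000 items.
    return {n: (len(r), set(r)) for n, r in results.items()}

print(run_concurrently(UnsafeProcessor()))
print(run_concurrently(SafeProcessor()))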
@@ -21,6 +21,7 @@ from torch.utils.data import TensorDataset
 import transformers
 from transformers import PreTrainedTokenizer, AutoTokenizer
 
+from haystack.errors import HaystackError
 from haystack.modeling.model.feature_extraction import (
     tokenize_batch_question_answering,
     tokenize_with_metadata,
@@ -100,7 +101,6 @@ class Processor(ABC):
             self.data_dir = Path(data_dir)
         else:
             self.data_dir = None  # type: ignore
-        self.baskets: List = []
 
         self._log_params()
         self.problematic_sample_ids: set = set()
@@ -477,7 +477,7 @@ class SquadProcessor(Processor):
         # Logging
         if indices:
             if 0 in indices:
-                self._log_samples(n_samples=1, baskets=self.baskets)
+                self._log_samples(n_samples=1, baskets=baskets)
 
         # During inference we need to keep the information contained in baskets.
         if return_baskets:
@@ -1854,7 +1854,7 @@ class TextClassificationProcessor(Processor):
     def dataset_from_dicts(
         self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
     ):
-        self.baskets = []
+        baskets = []
         # Tokenize in batches
         texts = [x["text"] for x in dicts]
         tokenized_batch = self.tokenizer(
@@ -1889,21 +1889,21 @@ class TextClassificationProcessor(Processor):
             label_dict = self.convert_labels(dictionary)
             feat_dict.update(label_dict)
 
-            # Add Basket to self.baskets
+            # Add Basket to baskets
             curr_sample = Sample(id="", clear_text=dictionary, tokenized=tokenized, features=[feat_dict])
             curr_basket = SampleBasket(id_internal=None, raw=dictionary, id_external=None, samples=[curr_sample])
-            self.baskets.append(curr_basket)
+            baskets.append(curr_basket)
 
         if indices and 0 not in indices:
             pass
         else:
-            self._log_samples(n_samples=1, baskets=self.baskets)
+            self._log_samples(n_samples=1, baskets=baskets)
 
         # TODO populate problematic ids
         problematic_ids: set = set()
-        dataset, tensornames = self._create_dataset()
+        dataset, tensornames = self._create_dataset(baskets)
         if return_baskets:
-            return dataset, tensornames, problematic_ids, self.baskets
+            return dataset, tensornames, problematic_ids, baskets
         else:
             return dataset, tensornames, problematic_ids
 
@@ -1926,13 +1926,16 @@ class TextClassificationProcessor(Processor):
             ret[task["label_tensor_name"]] = label_ids
         return ret
 
-    def _create_dataset(self):
-        # TODO this is the proposed new version to replace the mother function
-        features_flat = []
+    def _create_dataset(self, baskets: List[SampleBasket]):
+        features_flat: List = []
         basket_to_remove = []
-        for basket in self.baskets:
+        for basket in baskets:
             if self._check_sample_features(basket):
+                if not isinstance(basket.samples, Iterable):
+                    raise HaystackError("basket.samples must contain a list of samples.")
                 for sample in basket.samples:
+                    if sample.features is None:
+                        raise HaystackError("sample.features must not be None.")
                     features_flat.extend(sample.features)
             else:
                 # remove the entire basket
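After this change, a single processor instance can serve concurrent inference requests, since each call builds and returns its own local baskets. A hedged usage sketch follows; it assumes `processor` is an already-configured TextClassificationProcessor and that each input dict carries a "text" field, as in the diff above.

from concurrent.futures import ThreadPoolExecutor

# Assumption: `processor` is a configured TextClassificationProcessor
# (tokenizer, tasks, and labels set up elsewhere; setup not shown).
def encode(batch):
    # Each call now builds a local `baskets` list, so concurrent calls on the
    # same processor no longer overwrite each other's state.
    dataset, tensor_names, problematic_ids, baskets = processor.dataset_from_dicts(
        batch, return_baskets=True
    )
    return dataset, baskets

batches = [[{"text": "first request"}], [{"text": "second request"}]]
with ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(encode, batches))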