diff --git a/haystack/modeling/data_handler/data_silo.py b/haystack/modeling/data_handler/data_silo.py index f7237b8d2..91271a7c0 100644 --- a/haystack/modeling/data_handler/data_silo.py +++ b/haystack/modeling/data_handler/data_silo.py @@ -4,14 +4,11 @@ import hashlib import json import logging import random -from contextlib import ExitStack -from functools import partial from itertools import groupby from pathlib import Path import numpy as np from tqdm import tqdm import torch -import torch.multiprocessing as mp from torch.utils.data import ConcatDataset, Dataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, SequentialSampler @@ -19,7 +16,6 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler from haystack.modeling.data_handler.dataloader import NamedDataLoader from haystack.modeling.data_handler.processor import Processor from haystack.utils.experiment_tracking import Tracker as tracker -from haystack.modeling.utils import log_ascii_workers, grouper, calc_chunksize from haystack.modeling.visual import TRACTOR_SMALL if TYPE_CHECKING: @@ -41,7 +37,7 @@ class DataSilo: eval_batch_size: Optional[int] = None, distributed: bool = False, automatic_loading: bool = True, - max_multiprocessing_chunksize: int = 2000, + max_multiprocessing_chunksize: int = 512, max_processes: int = 128, multiprocessing_strategy: Optional[str] = None, caching: bool = False, @@ -59,9 +55,13 @@ class DataSilo: values are rather large that might cause memory issues. :param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo. It can be set to 1 to disable the use of multiprocessing or make debugging easier. + .. deprecated:: 1.9 + Multiprocessing has been removed in 1.9. This parameter will be ignored. :multiprocessing_strategy: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system depending on your OS. If your system has low limits for the number of open file descriptors, and you can’t raise them, you should use the file_system strategy. + .. deprecated:: 1.9 + Multiprocessing has been removed in 1.9. This parameter will be ignored. :param caching: save the processed datasets on disk to save time/compute if the same train data is used to run multiple experiments. Each cache has a checksum based on the train_filename of the Processor and the batch size. @@ -103,25 +103,6 @@ class DataSilo: # later or load from dicts instead of file self._load_data() - @classmethod - def _dataset_from_chunk(cls, chunk: List[Tuple[int, Dict]], processor: Processor): - """ - Creating a dataset for a chunk (= subset) of dicts. In multiprocessing: - * we read in all dicts from a file - * split all dicts into chunks - * feed *one chunk* to *one process* - => the *one chunk* gets converted to *one dataset* (that's what we do here) - * all datasets get collected and concatenated - :param chunk: Instead of only having a list of dicts here we also supply an index (ascending int) for each. - => [(0, dict), (1, dict) ...] - :param processor: Haystack basics Processor (e.g. 
SquadProcessor) - :return: PyTorch Dataset - """ - dicts = [d[1] for d in chunk] - indices = [x[0] for x in chunk] - dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts=dicts, indices=indices) - return dataset, tensor_names, problematic_sample_ids - def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[List[Dict]] = None): if not filename and not dicts: raise ValueError("You must either supply `filename` or `dicts`") @@ -136,61 +117,21 @@ class DataSilo: random.shuffle(dicts) num_dicts = len(dicts) - multiprocessing_chunk_size, num_cpus_used = calc_chunksize( - num_dicts=num_dicts, max_processes=self.max_processes, max_chunksize=self.max_multiprocessing_chunksize - ) + datasets = [] + problematic_ids_all = set() + batch_size = self.max_multiprocessing_chunksize + for i in tqdm(range(0, num_dicts, batch_size), desc="Preprocessing dataset", unit=" Dicts"): + processing_batch = dicts[i : i + batch_size] + dataset, tensor_names, problematic_sample_ids = self.processor.dataset_from_dicts( + dicts=processing_batch, indices=list(range(len(processing_batch))) # TODO remove indices + ) + datasets.append(dataset) + problematic_ids_all.update(problematic_sample_ids) - with ExitStack() as stack: - if self.max_processes > 1: # use multiprocessing only when max_processes > 1 - if self.multiprocessing_strategy: - if self.multiprocessing_strategy in mp.get_all_sharing_strategies(): - mp.set_sharing_strategy(self.multiprocessing_strategy) - else: - logger.warning( - f"{self.multiprocessing_strategy} is unavailable, " - f"falling back to default multiprocessing sharing strategy of your OS." - ) - - p = stack.enter_context(mp.Pool(processes=num_cpus_used)) - - logger.info( - f"Got ya {num_cpus_used} parallel workers to convert {num_dicts} dictionaries " - f"to pytorch datasets (chunksize = {multiprocessing_chunk_size})..." - ) - log_ascii_workers(num_cpus_used, logger) - - results = p.imap( - partial(self._dataset_from_chunk, processor=self.processor), - grouper(dicts, multiprocessing_chunk_size), - chunksize=1, - ) - else: - logger.info( - f"Multiprocessing disabled, using a single worker to convert {num_dicts}" - f"dictionaries to pytorch datasets." 
- ) - - # temporary fix - results = map(partial(self._dataset_from_chunk, processor=self.processor), grouper(dicts, 1)) # type: ignore - - datasets = [] - problematic_ids_all = set() - - desc = f"Preprocessing Dataset" - if filename: - desc += f" {filename}" - with tqdm(total=len(dicts), unit=" Dicts", desc=desc) as pbar: - for dataset, tensor_names, problematic_samples in results: - datasets.append(dataset) - # update progress bar (last step can have less dicts than actual chunk_size) - pbar.update(min(multiprocessing_chunk_size, pbar.total - pbar.n)) - problematic_ids_all.update(problematic_samples) - - self.processor.log_problematic(problematic_ids_all) - # _dataset_from_chunk can return a None in cases where downsampling has occurred - datasets = [d for d in datasets if d] - concat_datasets = ConcatDataset(datasets) # type: Dataset - return concat_datasets, tensor_names + self.processor.log_problematic(problematic_ids_all) + datasets = [d for d in datasets if d] + concat_datasets = ConcatDataset(datasets) # type: Dataset + return concat_datasets, tensor_names def _load_data( self, diff --git a/haystack/modeling/data_handler/input_features.py b/haystack/modeling/data_handler/input_features.py index 0d4a390d2..0696b281c 100644 --- a/haystack/modeling/data_handler/input_features.py +++ b/haystack/modeling/data_handler/input_features.py @@ -48,7 +48,7 @@ def sample_to_features_text(sample, tasks, max_seq_len, tokenizer): tokens_a = sample.tokenized["tokens"] tokens_b = sample.tokenized.get("tokens_b", None) - inputs = tokenizer.encode_plus( + inputs = tokenizer( tokens_a, tokens_b, add_special_tokens=True, diff --git a/haystack/modeling/data_handler/processor.py b/haystack/modeling/data_handler/processor.py index b44bcdeef..7bdf153a5 100644 --- a/haystack/modeling/data_handler/processor.py +++ b/haystack/modeling/data_handler/processor.py @@ -1092,8 +1092,8 @@ class TextSimilarityProcessor(Processor): query = self._normalize_question(basket.raw["query"]) # featurize the query - query_inputs = self.query_tokenizer.encode_plus( - text=query, + query_inputs = self.query_tokenizer( + query, max_length=self.max_seq_len_query, add_special_tokens=True, truncation=True, @@ -1157,7 +1157,7 @@ class TextSimilarityProcessor(Processor): # assign empty string tuples if hard_negative passages less than num_hard_negatives all_ctx += [("", "")] * ((self.num_positives + self.num_hard_negatives) - len(all_ctx)) - ctx_inputs = self.passage_tokenizer.batch_encode_plus( + ctx_inputs = self.passage_tokenizer( all_ctx, add_special_tokens=True, truncation=True, @@ -1568,8 +1568,8 @@ class TableTextSimilarityProcessor(Processor): query = self._normalize_question(basket.raw["query"]) # featurize the query - query_inputs = self.query_tokenizer.encode_plus( - text=query, + query_inputs = self.query_tokenizer( + query, max_length=self.max_seq_len_query, add_special_tokens=True, truncation=True, @@ -1660,7 +1660,7 @@ class TableTextSimilarityProcessor(Processor): # assign empty string tuples if hard_negative passages less than num_hard_negatives all_ctx += [("", "")] * ((self.num_positives + self.num_hard_negatives) - len(all_ctx)) - inputs = self.passage_tokenizer.batch_encode_plus( + inputs = self.passage_tokenizer( all_ctx, add_special_tokens=True, truncation=True, @@ -1858,7 +1858,7 @@ class TextClassificationProcessor(Processor): self.baskets = [] # Tokenize in batches texts = [x["text"] for x in dicts] - tokenized_batch = self.tokenizer.batch_encode_plus( + tokenized_batch = self.tokenizer( texts, 
return_offsets_mapping=True, return_special_tokens_mask=True, @@ -2093,7 +2093,7 @@ class UnlabeledTextProcessor(Processor): if return_baskets: raise NotImplementedError("return_baskets is not supported by UnlabeledTextProcessor") texts = [dict_["text"] for dict_ in dicts] - tokens = self.tokenizer.batch_encode_plus( + tokens = self.tokenizer( texts, add_special_tokens=True, return_tensors="pt", diff --git a/haystack/modeling/model/tokenization.py b/haystack/modeling/model/tokenization.py index 6c6db86c0..7ad2afc13 100644 --- a/haystack/modeling/model/tokenization.py +++ b/haystack/modeling/model/tokenization.py @@ -100,7 +100,7 @@ def tokenize_batch_question_answering( baskets = [] # # Tokenize texts in batch mode texts = [d["context"] for d in pre_baskets] - tokenized_docs_batch = tokenizer.batch_encode_plus( + tokenized_docs_batch = tokenizer( texts, return_offsets_mapping=True, return_special_tokens_mask=True, add_special_tokens=False, verbose=False ) @@ -108,24 +108,24 @@ def tokenize_batch_question_answering( tokenids_batch = tokenized_docs_batch["input_ids"] offsets_batch = [] for o in tokenized_docs_batch["offset_mapping"]: - offsets_batch.append(np.array([x[0] for x in o])) + offsets_batch.append(np.asarray([x[0] for x in o], dtype="int16")) start_of_words_batch = [] for e in tokenized_docs_batch.encodings: - start_of_words_batch.append(_get_start_of_word_QA(e.words)) + start_of_words_batch.append(_get_start_of_word_QA(e.word_ids)) for i_doc, d in enumerate(pre_baskets): document_text = d["context"] # # Tokenize questions one by one for i_q, q in enumerate(d["qas"]): question_text = q["question"] - tokenized_q = tokenizer.encode_plus( + tokenized_q = tokenizer( question_text, return_offsets_mapping=True, return_special_tokens_mask=True, add_special_tokens=False ) # Extract relevant data question_tokenids = tokenized_q["input_ids"] question_offsets = [x[0] for x in tokenized_q["offset_mapping"]] - question_sow = _get_start_of_word_QA(tokenized_q.encodings[0].words) + question_sow = _get_start_of_word_QA(tokenized_q.encodings[0].word_ids) external_id = q["id"] # The internal_id depends on unique ids created for each process before forking @@ -150,7 +150,7 @@ def tokenize_batch_question_answering( def _get_start_of_word_QA(word_ids): - return [1] + list(np.ediff1d(np.array(word_ids))) + return [1] + list(np.ediff1d(np.asarray(word_ids, dtype="int16"))) def truncate_sequences( @@ -241,7 +241,7 @@ def tokenize_with_metadata(text: str, tokenizer: PreTrainedTokenizer) -> Dict[st # Fast Tokenizers return offsets, so we don't need to calculate them ourselves if tokenizer.is_fast: # tokenized = tokenizer(text, return_offsets_mapping=True, return_special_tokens_mask=True) - tokenized = tokenizer.encode_plus(text, return_offsets_mapping=True, return_special_tokens_mask=True) + tokenized = tokenizer(text, return_offsets_mapping=True, return_special_tokens_mask=True) tokens = tokenized["input_ids"] offsets = np.array([x[0] for x in tokenized["offset_mapping"]]) diff --git a/haystack/nodes/answer_generator/transformers.py b/haystack/nodes/answer_generator/transformers.py index 5387c058b..779da0385 100644 --- a/haystack/nodes/answer_generator/transformers.py +++ b/haystack/nodes/answer_generator/transformers.py @@ -165,7 +165,7 @@ class RAGenerator(BaseGenerator): for i in range(len(texts)) ] - contextualized_inputs = self.tokenizer.generator.batch_encode_plus( + contextualized_inputs = self.tokenizer.generator( rag_input_strings, max_length=self.model.config.max_combined_length, 
return_tensors=return_tensors, diff --git a/haystack/nodes/reader/table.py b/haystack/nodes/reader/table.py index 28c3d52fe..13bdcac14 100644 --- a/haystack/nodes/reader/table.py +++ b/haystack/nodes/reader/table.py @@ -653,8 +653,8 @@ class RCIReader(BaseReader): row_reps, column_reps = self._create_row_column_representations(table) # Get row logits - row_inputs = self.row_tokenizer.batch_encode_plus( - batch_text_or_text_pairs=[(query, row_rep) for row_rep in row_reps], + row_inputs = self.row_tokenizer( + [(query, row_rep) for row_rep in row_reps], max_length=self.max_seq_len, return_tensors="pt", add_special_tokens=True, @@ -665,8 +665,8 @@ class RCIReader(BaseReader): row_logits = self.row_model(**row_inputs)[0].detach().cpu().numpy()[:, 1] # Get column logits - column_inputs = self.column_tokenizer.batch_encode_plus( - batch_text_or_text_pairs=[(query, column_rep) for column_rep in column_reps], + column_inputs = self.column_tokenizer( + [(query, column_rep) for column_rep in column_reps], max_length=self.max_seq_len, return_tensors="pt", add_special_tokens=True, diff --git a/test/modeling/test_tokenization.py b/test/modeling/test_tokenization.py index 5758eeede..e755b7799 100644 --- a/test/modeling/test_tokenization.py +++ b/test/modeling/test_tokenization.py @@ -119,13 +119,13 @@ def test_save_load(tmp_path, model_name: str): text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars" tokenizer.add_tokens(new_tokens=["neverseentokens"]) - original_encoding = tokenizer.encode_plus(text) + original_encoding = tokenizer(text) save_dir = tmp_path / "saved_tokenizer" tokenizer.save_pretrained(save_dir) tokenizer_loaded = get_tokenizer(pretrained_model_name_or_path=save_dir) - new_encoding = tokenizer_loaded.encode_plus(text) + new_encoding = tokenizer_loaded(text) assert original_encoding == new_encoding @@ -168,7 +168,7 @@ def test_tokenization_on_edge_cases_full_sequence_tokenization(model_name: str, words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case) words = [x[0] for x in words_and_spans] - encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0] + encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0] expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split())) # remove multiple whitespaces assert encoded.tokens == expected_tokenization @@ -188,7 +188,7 @@ def test_tokenization_on_edge_cases_full_sequence_tokenization_roberta_exception words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case) words = [x[0] for x in words_and_spans] - encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0] + encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0] expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split())) # remove multiple whitespaces assert encoded.tokens == expected_tokenization @@ -218,7 +218,7 @@ def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str, words = [x[0] for x in words_and_spans] word_spans = [x[1] for x in words_and_spans] - encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0] + encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0] # subword-tokens have special chars depending on model type. 
To align with original text we get rid of them tokens = [token.replace(marker, "") for token in encoded.tokens] @@ -248,7 +248,7 @@ def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str, def test_detokenization_for_bert(edge_case): tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False) - encoded = tokenizer.encode_plus(edge_case, add_special_tokens=False).encodings[0] + encoded = tokenizer(edge_case, add_special_tokens=False).encodings[0] detokenized = " ".join(encoded.tokens) detokenized = re.sub(r"(^|\s+)(##)", "", detokenized) @@ -264,7 +264,7 @@ def test_encode_plus_for_bert(): tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False) text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars" - encoded_batch = tokenizer.encode_plus(text) + encoded_batch = tokenizer(text) encoded = encoded_batch.encodings[0] words = np.array(encoded.words) @@ -316,7 +316,7 @@ def test_tokenize_custom_vocab_bert(): tokenized = tokenizer.tokenize(text) - encoded = tokenizer.encode_plus(text, add_special_tokens=False).encodings[0] + encoded = tokenizer(text, add_special_tokens=False).encodings[0] offsets = [x[0] for x in encoded.offsets] start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0)
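
For context on the recurring change in this patch: `encode_plus` and `batch_encode_plus` are deprecated in recent versions of Hugging Face `transformers`, and calling the tokenizer object directly (its `__call__` method) covers both the single-text and batched cases with the same keyword arguments. Below is a minimal sketch of that equivalence, not part of the patch itself; the model name is an arbitrary example and any fast (Rust-backed) tokenizer would do.

from transformers import AutoTokenizer

# "bert-base-uncased" is only an illustrative choice, not taken from this diff.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Single text: tokenizer(...) replaces tokenizer.encode_plus(...).
single = tokenizer(
    "Some Text with neverseentokens plus !215?#.",
    return_offsets_mapping=True,  # offset mappings require a fast tokenizer
    return_special_tokens_mask=True,
)

# Batch of texts: tokenizer(list_of_texts) replaces tokenizer.batch_encode_plus(...).
batch = tokenizer(
    ["first passage", "second passage"],
    add_special_tokens=True,
    truncation=True,
    padding="max_length",
    max_length=32,
    return_tensors="pt",
)

print(single["input_ids"])        # list of token ids for the single text
print(batch["input_ids"].shape)   # torch.Size([2, 32])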