fix: Replace multiprocessing tokenization with batched fast tokenization (#3089)

* Replace multiprocessing tokenization with batched fast tokenization

* Replace deprecated tokenization method invocations
Vladimir Blagojevic 2022-08-31 07:33:39 -04:00 committed by GitHub
parent e7771dc18e
commit 66f3f42a46
7 changed files with 48 additions and 107 deletions
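
The change applied across all seven files below is the same in spirit: preprocessing now tokenizes dicts in fixed-size batches with Hugging Face fast tokenizers called directly, instead of fanning chunks out to a multiprocessing pool and invoking the deprecated encode_plus/batch_encode_plus methods. A minimal sketch of the tokenizer API shift (the model name and texts are illustrative, not taken from this commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # fast tokenizer by default

    texts = ["first example sentence", "second, slightly longer example sentence"]

    # Deprecated style removed by this commit:
    #   tokenizer.encode_plus(texts[0], ...)
    #   tokenizer.batch_encode_plus(texts, ...)

    # New style: call the tokenizer object itself; single strings and batches
    # share the same __call__ interface.
    batch = tokenizer(texts, add_special_tokens=True, truncation=True, return_offsets_mapping=True)

    print(batch["input_ids"][0])        # token ids of the first text
    print(batch.encodings[0].word_ids)  # fast-tokenizer word indices (replaces the deprecated .words)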


@@ -4,14 +4,11 @@ import hashlib
import json
import logging
import random
from contextlib import ExitStack
from functools import partial
from itertools import groupby
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torch.multiprocessing as mp
from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
@@ -19,7 +16,6 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.processor import Processor
from haystack.utils.experiment_tracking import Tracker as tracker
from haystack.modeling.utils import log_ascii_workers, grouper, calc_chunksize
from haystack.modeling.visual import TRACTOR_SMALL
if TYPE_CHECKING:
@@ -41,7 +37,7 @@ class DataSilo:
eval_batch_size: Optional[int] = None,
distributed: bool = False,
automatic_loading: bool = True,
max_multiprocessing_chunksize: int = 2000,
max_multiprocessing_chunksize: int = 512,
max_processes: int = 128,
multiprocessing_strategy: Optional[str] = None,
caching: bool = False,
@@ -59,9 +55,13 @@ class DataSilo:
values are rather large that might cause memory issues.
:param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
It can be set to 1 to disable the use of multiprocessing or make debugging easier.
.. deprecated:: 1.9
Multiprocessing has been removed in 1.9. This parameter will be ignored.
:param multiprocessing_strategy: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system depending on your OS.
If your system has low limits for the number of open file descriptors, and you can't raise them,
you should use the file_system strategy.
.. deprecated:: 1.9
Multiprocessing has been removed in 1.9. This parameter will be ignored.
:param caching: save the processed datasets on disk to save time/compute if the same train data is used to run
multiple experiments. Each cache has a checksum based on the train_filename of the Processor
and the batch size.
@@ -103,25 +103,6 @@ class DataSilo:
# later or load from dicts instead of file
self._load_data()
@classmethod
def _dataset_from_chunk(cls, chunk: List[Tuple[int, Dict]], processor: Processor):
"""
Creating a dataset for a chunk (= subset) of dicts. In multiprocessing:
* we read in all dicts from a file
* split all dicts into chunks
* feed *one chunk* to *one process*
=> the *one chunk* gets converted to *one dataset* (that's what we do here)
* all datasets get collected and concatenated
:param chunk: Instead of only having a list of dicts here we also supply an index (ascending int) for each.
=> [(0, dict), (1, dict) ...]
:param processor: Haystack basics Processor (e.g. SquadProcessor)
:return: PyTorch Dataset
"""
dicts = [d[1] for d in chunk]
indices = [x[0] for x in chunk]
dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts=dicts, indices=indices)
return dataset, tensor_names, problematic_sample_ids
def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[List[Dict]] = None):
if not filename and not dicts:
raise ValueError("You must either supply `filename` or `dicts`")
@@ -136,61 +117,21 @@ class DataSilo:
random.shuffle(dicts)
num_dicts = len(dicts)
multiprocessing_chunk_size, num_cpus_used = calc_chunksize(
num_dicts=num_dicts, max_processes=self.max_processes, max_chunksize=self.max_multiprocessing_chunksize
)
datasets = []
problematic_ids_all = set()
batch_size = self.max_multiprocessing_chunksize
for i in tqdm(range(0, num_dicts, batch_size), desc="Preprocessing dataset", unit=" Dicts"):
processing_batch = dicts[i : i + batch_size]
dataset, tensor_names, problematic_sample_ids = self.processor.dataset_from_dicts(
dicts=processing_batch, indices=list(range(len(processing_batch))) # TODO remove indices
)
datasets.append(dataset)
problematic_ids_all.update(problematic_sample_ids)
with ExitStack() as stack:
if self.max_processes > 1: # use multiprocessing only when max_processes > 1
if self.multiprocessing_strategy:
if self.multiprocessing_strategy in mp.get_all_sharing_strategies():
mp.set_sharing_strategy(self.multiprocessing_strategy)
else:
logger.warning(
f"{self.multiprocessing_strategy} is unavailable, "
f"falling back to default multiprocessing sharing strategy of your OS."
)
p = stack.enter_context(mp.Pool(processes=num_cpus_used))
logger.info(
f"Got ya {num_cpus_used} parallel workers to convert {num_dicts} dictionaries "
f"to pytorch datasets (chunksize = {multiprocessing_chunk_size})..."
)
log_ascii_workers(num_cpus_used, logger)
results = p.imap(
partial(self._dataset_from_chunk, processor=self.processor),
grouper(dicts, multiprocessing_chunk_size),
chunksize=1,
)
else:
logger.info(
f"Multiprocessing disabled, using a single worker to convert {num_dicts}"
f"dictionaries to pytorch datasets."
)
# temporary fix
results = map(partial(self._dataset_from_chunk, processor=self.processor), grouper(dicts, 1)) # type: ignore
datasets = []
problematic_ids_all = set()
desc = f"Preprocessing Dataset"
if filename:
desc += f" {filename}"
with tqdm(total=len(dicts), unit=" Dicts", desc=desc) as pbar:
for dataset, tensor_names, problematic_samples in results:
datasets.append(dataset)
# update progress bar (last step can have less dicts than actual chunk_size)
pbar.update(min(multiprocessing_chunk_size, pbar.total - pbar.n))
problematic_ids_all.update(problematic_samples)
self.processor.log_problematic(problematic_ids_all)
# _dataset_from_chunk can return a None in cases where downsampling has occurred
datasets = [d for d in datasets if d]
concat_datasets = ConcatDataset(datasets) # type: Dataset
return concat_datasets, tensor_names
self.processor.log_problematic(problematic_ids_all)
datasets = [d for d in datasets if d]
concat_datasets = ConcatDataset(datasets) # type: Dataset
return concat_datasets, tensor_names
def _load_data(
self,
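
Condensed, standalone sketch of the new single-process batching loop shown in the hunk above (`dicts`, `processor`, and the batch size stand in for the real DataSilo attributes; the tqdm progress bar is omitted):

    from torch.utils.data import ConcatDataset

    def preprocess_in_batches(dicts, processor, batch_size=512):
        datasets = []
        problematic_ids_all = set()
        tensor_names = None
        for i in range(0, len(dicts), batch_size):
            batch = dicts[i : i + batch_size]
            dataset, tensor_names, problematic_ids = processor.dataset_from_dicts(
                dicts=batch, indices=list(range(len(batch)))
            )
            if dataset:  # dataset_from_dicts may return None, e.g. after downsampling
                datasets.append(dataset)
            problematic_ids_all.update(problematic_ids)
        processor.log_problematic(problematic_ids_all)
        return ConcatDataset(datasets), tensor_names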


@@ -48,7 +48,7 @@ def sample_to_features_text(sample, tasks, max_seq_len, tokenizer):
tokens_a = sample.tokenized["tokens"]
tokens_b = sample.tokenized.get("tokens_b", None)
inputs = tokenizer.encode_plus(
inputs = tokenizer(
tokens_a,
tokens_b,
add_special_tokens=True,
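
A hedged sketch of the direct-call pattern used by sample_to_features_text above, encoding an (optionally pre-tokenized) text pair; the model name, is_split_into_words, and truncation settings are illustrative assumptions rather than this function's actual arguments:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed model

    tokens_a = ["a", "pre", "tokenized", "first", "sequence"]
    tokens_b = None  # or a second pre-tokenized sequence for sentence-pair tasks

    inputs = tokenizer(
        tokens_a,
        tokens_b,
        add_special_tokens=True,
        is_split_into_words=True,  # assumption: inputs are word lists, not raw strings
        truncation=True,
        max_length=64,
    )
    print(inputs["input_ids"])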


@@ -1092,8 +1092,8 @@ class TextSimilarityProcessor(Processor):
query = self._normalize_question(basket.raw["query"])
# featurize the query
query_inputs = self.query_tokenizer.encode_plus(
text=query,
query_inputs = self.query_tokenizer(
query,
max_length=self.max_seq_len_query,
add_special_tokens=True,
truncation=True,
@@ -1157,7 +1157,7 @@ class TextSimilarityProcessor(Processor):
# assign empty string tuples if hard_negative passages less than num_hard_negatives
all_ctx += [("", "")] * ((self.num_positives + self.num_hard_negatives) - len(all_ctx))
ctx_inputs = self.passage_tokenizer.batch_encode_plus(
ctx_inputs = self.passage_tokenizer(
all_ctx,
add_special_tokens=True,
truncation=True,
@@ -1568,8 +1568,8 @@ class TableTextSimilarityProcessor(Processor):
query = self._normalize_question(basket.raw["query"])
# featurize the query
query_inputs = self.query_tokenizer.encode_plus(
text=query,
query_inputs = self.query_tokenizer(
query,
max_length=self.max_seq_len_query,
add_special_tokens=True,
truncation=True,
@@ -1660,7 +1660,7 @@ class TableTextSimilarityProcessor(Processor):
# assign empty string tuples if hard_negative passages less than num_hard_negatives
all_ctx += [("", "")] * ((self.num_positives + self.num_hard_negatives) - len(all_ctx))
inputs = self.passage_tokenizer.batch_encode_plus(
inputs = self.passage_tokenizer(
all_ctx,
add_special_tokens=True,
truncation=True,
@@ -1858,7 +1858,7 @@ class TextClassificationProcessor(Processor):
self.baskets = []
# Tokenize in batches
texts = [x["text"] for x in dicts]
tokenized_batch = self.tokenizer.batch_encode_plus(
tokenized_batch = self.tokenizer(
texts,
return_offsets_mapping=True,
return_special_tokens_mask=True,
@@ -2093,7 +2093,7 @@ class UnlabeledTextProcessor(Processor):
if return_baskets:
raise NotImplementedError("return_baskets is not supported by UnlabeledTextProcessor")
texts = [dict_["text"] for dict_ in dicts]
tokens = self.tokenizer.batch_encode_plus(
tokens = self.tokenizer(
texts,
add_special_tokens=True,
return_tensors="pt",
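
Hedged sketch of the passage-side batch call used by the similarity processors above: a list of (title, passage) tuples is encoded in one direct tokenizer call instead of batch_encode_plus. The model name, max_length, and padding choice are illustrative:

    from transformers import AutoTokenizer

    passage_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    all_ctx = [
        ("Haystack", "Haystack is an open source framework for building search systems."),
        ("", ""),  # empty padding pair, as used when hard negatives are missing
    ]

    ctx_inputs = passage_tokenizer(
        all_ctx,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=256,
        return_tensors="pt",
    )
    print(ctx_inputs["input_ids"].shape)  # (num_passages, 256)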


@@ -100,7 +100,7 @@ def tokenize_batch_question_answering(
baskets = []
# # Tokenize texts in batch mode
texts = [d["context"] for d in pre_baskets]
tokenized_docs_batch = tokenizer.batch_encode_plus(
tokenized_docs_batch = tokenizer(
texts, return_offsets_mapping=True, return_special_tokens_mask=True, add_special_tokens=False, verbose=False
)
@@ -108,24 +108,24 @@ def tokenize_batch_question_answering(
tokenids_batch = tokenized_docs_batch["input_ids"]
offsets_batch = []
for o in tokenized_docs_batch["offset_mapping"]:
offsets_batch.append(np.array([x[0] for x in o]))
offsets_batch.append(np.asarray([x[0] for x in o], dtype="int16"))
start_of_words_batch = []
for e in tokenized_docs_batch.encodings:
start_of_words_batch.append(_get_start_of_word_QA(e.words))
start_of_words_batch.append(_get_start_of_word_QA(e.word_ids))
for i_doc, d in enumerate(pre_baskets):
document_text = d["context"]
# # Tokenize questions one by one
for i_q, q in enumerate(d["qas"]):
question_text = q["question"]
tokenized_q = tokenizer.encode_plus(
tokenized_q = tokenizer(
question_text, return_offsets_mapping=True, return_special_tokens_mask=True, add_special_tokens=False
)
# Extract relevant data
question_tokenids = tokenized_q["input_ids"]
question_offsets = [x[0] for x in tokenized_q["offset_mapping"]]
question_sow = _get_start_of_word_QA(tokenized_q.encodings[0].words)
question_sow = _get_start_of_word_QA(tokenized_q.encodings[0].word_ids)
external_id = q["id"]
# The internal_id depends on unique ids created for each process before forking
@@ -150,7 +150,7 @@ def tokenize_batch_question_answering(
def _get_start_of_word_QA(word_ids):
return [1] + list(np.ediff1d(np.array(word_ids)))
return [1] + list(np.ediff1d(np.asarray(word_ids, dtype="int16")))
def truncate_sequences(
@@ -241,7 +241,7 @@ def tokenize_with_metadata(text: str, tokenizer: PreTrainedTokenizer) -> Dict[st
# Fast Tokenizers return offsets, so we don't need to calculate them ourselves
if tokenizer.is_fast:
# tokenized = tokenizer(text, return_offsets_mapping=True, return_special_tokens_mask=True)
tokenized = tokenizer.encode_plus(text, return_offsets_mapping=True, return_special_tokens_mask=True)
tokenized = tokenizer(text, return_offsets_mapping=True, return_special_tokens_mask=True)
tokens = tokenized["input_ids"]
offsets = np.array([x[0] for x in tokenized["offset_mapping"]])
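
Hedged sketch of the word_ids-based start-of-word computation performed by _get_start_of_word_QA above; word_ids replaces the deprecated words attribute on fast-tokenizer encodings, and the model name is a placeholder:

    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    encoding = tokenizer("unquestionably fast tokenizers", add_special_tokens=False).encodings[0]

    word_ids = encoding.word_ids  # word index per sub-word token, e.g. [0, 0, 0, 1, 2, ...]
    start_of_word = [1] + list(np.ediff1d(np.asarray(word_ids, dtype="int16")))
    print(start_of_word)          # 1 where a token starts a new word, 0 inside a word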


@@ -165,7 +165,7 @@ class RAGenerator(BaseGenerator):
for i in range(len(texts))
]
contextualized_inputs = self.tokenizer.generator.batch_encode_plus(
contextualized_inputs = self.tokenizer.generator(
rag_input_strings,
max_length=self.model.config.max_combined_length,
return_tensors=return_tensors,


@@ -653,8 +653,8 @@ class RCIReader(BaseReader):
row_reps, column_reps = self._create_row_column_representations(table)
# Get row logits
row_inputs = self.row_tokenizer.batch_encode_plus(
batch_text_or_text_pairs=[(query, row_rep) for row_rep in row_reps],
row_inputs = self.row_tokenizer(
[(query, row_rep) for row_rep in row_reps],
max_length=self.max_seq_len,
return_tensors="pt",
add_special_tokens=True,
@@ -665,8 +665,8 @@ class RCIReader(BaseReader):
row_logits = self.row_model(**row_inputs)[0].detach().cpu().numpy()[:, 1]
# Get column logits
column_inputs = self.column_tokenizer.batch_encode_plus(
batch_text_or_text_pairs=[(query, column_rep) for column_rep in column_reps],
column_inputs = self.column_tokenizer(
[(query, column_rep) for column_rep in column_reps],
max_length=self.max_seq_len,
return_tensors="pt",
add_special_tokens=True,
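
Hedged sketch of the RCIReader-style call above: the query is paired with every row (or column) representation and encoded in a single batched call. Model name, max_length, and padding are placeholders, not the reader's configuration:

    from transformers import AutoTokenizer

    row_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    query = "Which country hosted the 2016 Olympics?"
    row_reps = ["Rio de Janeiro Brazil 2016", "London United Kingdom 2012"]

    row_inputs = row_tokenizer(
        [(query, row_rep) for row_rep in row_reps],
        max_length=256,
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
    )
    print(row_inputs["input_ids"].shape)  # (num_rows, 256)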


@@ -119,13 +119,13 @@ def test_save_load(tmp_path, model_name: str):
text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"
tokenizer.add_tokens(new_tokens=["neverseentokens"])
original_encoding = tokenizer.encode_plus(text)
original_encoding = tokenizer(text)
save_dir = tmp_path / "saved_tokenizer"
tokenizer.save_pretrained(save_dir)
tokenizer_loaded = get_tokenizer(pretrained_model_name_or_path=save_dir)
new_encoding = tokenizer_loaded.encode_plus(text)
new_encoding = tokenizer_loaded(text)
assert original_encoding == new_encoding
@@ -168,7 +168,7 @@ def test_tokenization_on_edge_cases_full_sequence_tokenization(model_name: str,
words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
words = [x[0] for x in words_and_spans]
encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split())) # remove multiple whitespaces
assert encoded.tokens == expected_tokenization
@@ -188,7 +188,7 @@ def test_tokenization_on_edge_cases_full_sequence_tokenization_roberta_exception
words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
words = [x[0] for x in words_and_spans]
encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split())) # remove multiple whitespaces
assert encoded.tokens == expected_tokenization
@@ -218,7 +218,7 @@ def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str,
words = [x[0] for x in words_and_spans]
word_spans = [x[1] for x in words_and_spans]
encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
# subword-tokens have special chars depending on model type. To align with original text we get rid of them
tokens = [token.replace(marker, "") for token in encoded.tokens]
@@ -248,7 +248,7 @@ def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str,
def test_detokenization_for_bert(edge_case):
tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False)
encoded = tokenizer.encode_plus(edge_case, add_special_tokens=False).encodings[0]
encoded = tokenizer(edge_case, add_special_tokens=False).encodings[0]
detokenized = " ".join(encoded.tokens)
detokenized = re.sub(r"(^|\s+)(##)", "", detokenized)
@@ -264,7 +264,7 @@ def test_encode_plus_for_bert():
tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False)
text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"
encoded_batch = tokenizer.encode_plus(text)
encoded_batch = tokenizer(text)
encoded = encoded_batch.encodings[0]
words = np.array(encoded.words)
@@ -316,7 +316,7 @@ def test_tokenize_custom_vocab_bert():
tokenized = tokenizer.tokenize(text)
encoded = tokenizer.encode_plus(text, add_special_tokens=False).encodings[0]
encoded = tokenizer(text, add_special_tokens=False).encodings[0]
offsets = [x[0] for x in encoded.offsets]
start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0)
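
A quick illustrative check of the equivalence the updated tests rely on: calling a fast tokenizer directly yields the same encoding as the deprecated encode_plus (the model name is a placeholder for the tokenizers parametrized in the tests):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    assert tokenizer(text)["input_ids"] == tokenizer.encode_plus(text)["input_ids"]
    assert tokenizer(text).encodings[0].tokens == tokenizer.encode_plus(text).encodings[0].tokens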