Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-17 18:43:58 +00:00)
fix: Replace multiprocessing tokenization with batched fast tokenization (#3089)
* Replace multiprocessing tokenization with batched fast tokenization
* Replace deprecated tokenization method invocations
This commit is contained in:
parent e7771dc18e
commit 66f3f42a46
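The recurring change throughout the diff is the move from the deprecated `encode_plus`/`batch_encode_plus` methods to calling the tokenizer object directly. A minimal sketch of the equivalence, assuming a Hugging Face fast tokenizer (the model name is only an example, not taken from this commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    text = "Haystack makes question answering easy."

    old_style = tokenizer.encode_plus(text, add_special_tokens=True, truncation=True)  # deprecated API
    new_style = tokenizer(text, add_special_tokens=True, truncation=True)              # direct call

    assert old_style == new_style  # same input_ids, token_type_ids, attention_mask

For fast tokenizers the direct call returns the same BatchEncoding, so downstream code that indexes input_ids, attention_mask, or .encodings keeps working unchanged.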
@@ -4,14 +4,11 @@ import hashlib
 import json
 import logging
 import random
-from contextlib import ExitStack
-from functools import partial
 from itertools import groupby
 from pathlib import Path
 import numpy as np
 from tqdm import tqdm
 import torch
-import torch.multiprocessing as mp
 from torch.utils.data import ConcatDataset, Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
@@ -19,7 +16,6 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
 from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.processor import Processor
 from haystack.utils.experiment_tracking import Tracker as tracker
-from haystack.modeling.utils import log_ascii_workers, grouper, calc_chunksize
 from haystack.modeling.visual import TRACTOR_SMALL
 
 if TYPE_CHECKING:
@@ -41,7 +37,7 @@ class DataSilo:
         eval_batch_size: Optional[int] = None,
         distributed: bool = False,
         automatic_loading: bool = True,
-        max_multiprocessing_chunksize: int = 2000,
+        max_multiprocessing_chunksize: int = 512,
         max_processes: int = 128,
         multiprocessing_strategy: Optional[str] = None,
         caching: bool = False,
@@ -59,9 +55,13 @@ class DataSilo:
                               values are rather large that might cause memory issues.
         :param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
                               It can be set to 1 to disable the use of multiprocessing or make debugging easier.
+                              .. deprecated:: 1.9
+                                    Multiprocessing has been removed in 1.9. This parameter will be ignored.
         :multiprocessing_strategy: Set the multiprocessing sharing strategy, this can be one of file_descriptor/file_system depending on your OS.
                                    If your system has low limits for the number of open file descriptors, and you can’t raise them,
                                    you should use the file_system strategy.
+                                   .. deprecated:: 1.9
+                                        Multiprocessing has been removed in 1.9. This parameter will be ignored.
         :param caching: save the processed datasets on disk to save time/compute if the same train data is used to run
                         multiple experiments. Each cache has a checksum based on the train_filename of the Processor
                         and the batch size.
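Because the parameters are now documented as ignored rather than removed, existing construction code keeps working. A hedged usage sketch; the model name, processor setup, and data_dir are placeholders, not taken from this diff:

    from transformers import AutoTokenizer
    from haystack.modeling.data_handler.data_silo import DataSilo
    from haystack.modeling.data_handler.processor import SquadProcessor

    tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
    # data_dir is a placeholder; point it at a real SQuAD-style dataset to actually load data.
    processor = SquadProcessor(tokenizer=tokenizer, max_seq_len=384, data_dir="data/squad20")

    data_silo = DataSilo(
        processor=processor,
        batch_size=32,
        automatic_loading=False,            # skip immediate loading for this sketch
        max_multiprocessing_chunksize=512,  # now also the preprocessing batch size (see the _get_dataset hunk below)
        max_processes=128,                  # accepted but ignored as of 1.9
        multiprocessing_strategy=None,      # accepted but ignored as of 1.9
    )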
@@ -103,25 +103,6 @@ class DataSilo:
             # later or load from dicts instead of file
             self._load_data()
 
-    @classmethod
-    def _dataset_from_chunk(cls, chunk: List[Tuple[int, Dict]], processor: Processor):
-        """
-        Creating a dataset for a chunk (= subset) of dicts. In multiprocessing:
-          * we read in all dicts from a file
-          * split all dicts into chunks
-          * feed *one chunk* to *one process*
-          => the *one chunk* gets converted to *one dataset* (that's what we do here)
-          * all datasets get collected and concatenated
-        :param chunk: Instead of only having a list of dicts here we also supply an index (ascending int) for each.
-            => [(0, dict), (1, dict) ...]
-        :param processor: Haystack basics Processor (e.g. SquadProcessor)
-        :return: PyTorch Dataset
-        """
-        dicts = [d[1] for d in chunk]
-        indices = [x[0] for x in chunk]
-        dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts=dicts, indices=indices)
-        return dataset, tensor_names, problematic_sample_ids
-
     def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[List[Dict]] = None):
         if not filename and not dicts:
             raise ValueError("You must either supply `filename` or `dicts`")
@@ -136,58 +117,18 @@ class DataSilo:
             random.shuffle(dicts)
 
         num_dicts = len(dicts)
-        multiprocessing_chunk_size, num_cpus_used = calc_chunksize(
-            num_dicts=num_dicts, max_processes=self.max_processes, max_chunksize=self.max_multiprocessing_chunksize
-        )
-
-        with ExitStack() as stack:
-            if self.max_processes > 1:  # use multiprocessing only when max_processes > 1
-                if self.multiprocessing_strategy:
-                    if self.multiprocessing_strategy in mp.get_all_sharing_strategies():
-                        mp.set_sharing_strategy(self.multiprocessing_strategy)
-                    else:
-                        logger.warning(
-                            f"{self.multiprocessing_strategy} is unavailable, "
-                            f"falling back to default multiprocessing sharing strategy of your OS."
-                        )
-
-                p = stack.enter_context(mp.Pool(processes=num_cpus_used))
-
-                logger.info(
-                    f"Got ya {num_cpus_used} parallel workers to convert {num_dicts} dictionaries "
-                    f"to pytorch datasets (chunksize = {multiprocessing_chunk_size})..."
-                )
-                log_ascii_workers(num_cpus_used, logger)
-
-                results = p.imap(
-                    partial(self._dataset_from_chunk, processor=self.processor),
-                    grouper(dicts, multiprocessing_chunk_size),
-                    chunksize=1,
-                )
-            else:
-                logger.info(
-                    f"Multiprocessing disabled, using a single worker to convert {num_dicts}"
-                    f"dictionaries to pytorch datasets."
-                )
-
-                # temporary fix
-                results = map(partial(self._dataset_from_chunk, processor=self.processor), grouper(dicts, 1))  # type: ignore
-
-            datasets = []
-            problematic_ids_all = set()
-
-            desc = f"Preprocessing Dataset"
-            if filename:
-                desc += f" {filename}"
-            with tqdm(total=len(dicts), unit=" Dicts", desc=desc) as pbar:
-                for dataset, tensor_names, problematic_samples in results:
-                    datasets.append(dataset)
-                    # update progress bar (last step can have less dicts than actual chunk_size)
-                    pbar.update(min(multiprocessing_chunk_size, pbar.total - pbar.n))
-                    problematic_ids_all.update(problematic_samples)
-
-            self.processor.log_problematic(problematic_ids_all)
-            # _dataset_from_chunk can return a None in cases where downsampling has occurred
-            datasets = [d for d in datasets if d]
-            concat_datasets = ConcatDataset(datasets)  # type: Dataset
-            return concat_datasets, tensor_names
+        datasets = []
+        problematic_ids_all = set()
+        batch_size = self.max_multiprocessing_chunksize
+        for i in tqdm(range(0, num_dicts, batch_size), desc="Preprocessing dataset", unit=" Dicts"):
+            processing_batch = dicts[i : i + batch_size]
+            dataset, tensor_names, problematic_sample_ids = self.processor.dataset_from_dicts(
+                dicts=processing_batch, indices=list(range(len(processing_batch)))  # TODO remove indices
+            )
+            datasets.append(dataset)
+            problematic_ids_all.update(problematic_sample_ids)
+
+        self.processor.log_problematic(problematic_ids_all)
+        datasets = [d for d in datasets if d]
+        concat_datasets = ConcatDataset(datasets)  # type: Dataset
+        return concat_datasets, tensor_names
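The new `_get_dataset` body slices the dicts into fixed-size batches and lets the processor's fast-tokenizer-based `dataset_from_dicts` handle each slice, instead of fanning chunks out to a process pool. A self-contained sketch of the same pattern, with illustrative names (the processor is assumed to follow Haystack's Processor API):

    from torch.utils.data import ConcatDataset
    from tqdm import tqdm

    def convert_dicts_to_dataset(dicts, processor, batch_size=512):
        datasets = []
        problematic_ids = set()
        tensor_names = None
        for i in tqdm(range(0, len(dicts), batch_size), desc="Preprocessing dataset", unit=" Dicts"):
            batch = dicts[i : i + batch_size]
            dataset, tensor_names, problem_ids = processor.dataset_from_dicts(
                dicts=batch, indices=list(range(len(batch)))
            )
            datasets.append(dataset)
            problematic_ids.update(problem_ids)
        datasets = [d for d in datasets if d]  # dataset_from_dicts may return None after downsampling
        return ConcatDataset(datasets), tensor_names, problematic_ids

Because `max_multiprocessing_chunksize` now feeds `batch_size`, its new default of 512 directly controls how many dicts are featurized per call.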
@@ -48,7 +48,7 @@ def sample_to_features_text(sample, tasks, max_seq_len, tokenizer):
     tokens_a = sample.tokenized["tokens"]
     tokens_b = sample.tokenized.get("tokens_b", None)
 
-    inputs = tokenizer.encode_plus(
+    inputs = tokenizer(
         tokens_a,
         tokens_b,
         add_special_tokens=True,
@@ -1092,8 +1092,8 @@ class TextSimilarityProcessor(Processor):
                 query = self._normalize_question(basket.raw["query"])
 
                 # featurize the query
-                query_inputs = self.query_tokenizer.encode_plus(
-                    text=query,
+                query_inputs = self.query_tokenizer(
+                    query,
                     max_length=self.max_seq_len_query,
                     add_special_tokens=True,
                     truncation=True,
@@ -1157,7 +1157,7 @@ class TextSimilarityProcessor(Processor):
                     # assign empty string tuples if hard_negative passages less than num_hard_negatives
                     all_ctx += [("", "")] * ((self.num_positives + self.num_hard_negatives) - len(all_ctx))
 
-                    ctx_inputs = self.passage_tokenizer.batch_encode_plus(
+                    ctx_inputs = self.passage_tokenizer(
                         all_ctx,
                         add_special_tokens=True,
                         truncation=True,
@@ -1568,8 +1568,8 @@ class TableTextSimilarityProcessor(Processor):
                 query = self._normalize_question(basket.raw["query"])
 
                 # featurize the query
-                query_inputs = self.query_tokenizer.encode_plus(
-                    text=query,
+                query_inputs = self.query_tokenizer(
+                    query,
                     max_length=self.max_seq_len_query,
                     add_special_tokens=True,
                     truncation=True,
@@ -1660,7 +1660,7 @@ class TableTextSimilarityProcessor(Processor):
                     # assign empty string tuples if hard_negative passages less than num_hard_negatives
                     all_ctx += [("", "")] * ((self.num_positives + self.num_hard_negatives) - len(all_ctx))
 
-                    inputs = self.passage_tokenizer.batch_encode_plus(
+                    inputs = self.passage_tokenizer(
                         all_ctx,
                         add_special_tokens=True,
                         truncation=True,
@@ -1858,7 +1858,7 @@ class TextClassificationProcessor(Processor):
         self.baskets = []
         # Tokenize in batches
         texts = [x["text"] for x in dicts]
-        tokenized_batch = self.tokenizer.batch_encode_plus(
+        tokenized_batch = self.tokenizer(
             texts,
             return_offsets_mapping=True,
             return_special_tokens_mask=True,
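The processors above now hand the whole list of texts to the tokenizer in a single call instead of tokenizing per dict in worker processes. A sketch of such a batched call, assuming a Hugging Face fast tokenizer (the model name and texts are only examples):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    texts = ["First document.", "A second, slightly longer document."]

    tokenized_batch = tokenizer(
        texts,
        return_offsets_mapping=True,        # only supported by fast (Rust-backed) tokenizers
        return_special_tokens_mask=True,
        add_special_tokens=True,
        truncation=True,
        max_length=128,
    )

    print(tokenized_batch["input_ids"][0])
    print(tokenized_batch["offset_mapping"][0])  # (start, end) character span of each token

`return_offsets_mapping` is only available on fast tokenizers, which is why the batched path assumes them.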
@@ -2093,7 +2093,7 @@ class UnlabeledTextProcessor(Processor):
         if return_baskets:
             raise NotImplementedError("return_baskets is not supported by UnlabeledTextProcessor")
         texts = [dict_["text"] for dict_ in dicts]
-        tokens = self.tokenizer.batch_encode_plus(
+        tokens = self.tokenizer(
             texts,
             add_special_tokens=True,
             return_tensors="pt",
@@ -100,7 +100,7 @@ def tokenize_batch_question_answering(
     baskets = []
     # # Tokenize texts in batch mode
     texts = [d["context"] for d in pre_baskets]
-    tokenized_docs_batch = tokenizer.batch_encode_plus(
+    tokenized_docs_batch = tokenizer(
         texts, return_offsets_mapping=True, return_special_tokens_mask=True, add_special_tokens=False, verbose=False
     )
 
@@ -108,24 +108,24 @@ def tokenize_batch_question_answering(
     tokenids_batch = tokenized_docs_batch["input_ids"]
     offsets_batch = []
     for o in tokenized_docs_batch["offset_mapping"]:
-        offsets_batch.append(np.array([x[0] for x in o]))
+        offsets_batch.append(np.asarray([x[0] for x in o], dtype="int16"))
     start_of_words_batch = []
     for e in tokenized_docs_batch.encodings:
-        start_of_words_batch.append(_get_start_of_word_QA(e.words))
+        start_of_words_batch.append(_get_start_of_word_QA(e.word_ids))
 
     for i_doc, d in enumerate(pre_baskets):
         document_text = d["context"]
         # # Tokenize questions one by one
         for i_q, q in enumerate(d["qas"]):
             question_text = q["question"]
-            tokenized_q = tokenizer.encode_plus(
+            tokenized_q = tokenizer(
                 question_text, return_offsets_mapping=True, return_special_tokens_mask=True, add_special_tokens=False
             )
 
             # Extract relevant data
             question_tokenids = tokenized_q["input_ids"]
             question_offsets = [x[0] for x in tokenized_q["offset_mapping"]]
-            question_sow = _get_start_of_word_QA(tokenized_q.encodings[0].words)
+            question_sow = _get_start_of_word_QA(tokenized_q.encodings[0].word_ids)
 
             external_id = q["id"]
             # The internal_id depends on unique ids created for each process before forking
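Two smaller changes in this hunk: per-token offsets are stored as compact int16 arrays, and word indices are read from `Encoding.word_ids` (the current accessor; `.words` is the older, deprecated spelling). A sketch of both, with an example model not taken from this diff:

    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    batch = tokenizer(["What is Haystack?"], return_offsets_mapping=True, add_special_tokens=False)

    encoding = batch.encodings[0]
    offsets = np.asarray([start for start, _ in batch["offset_mapping"][0]], dtype="int16")
    word_ids = encoding.word_ids  # word index of each token; `.words` is the deprecated accessor

    print(offsets)
    print(word_ids)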
@@ -150,7 +150,7 @@ def tokenize_batch_question_answering(
 
 
 def _get_start_of_word_QA(word_ids):
-    return [1] + list(np.ediff1d(np.array(word_ids)))
+    return [1] + list(np.ediff1d(np.asarray(word_ids, dtype="int16")))
 
 
 def truncate_sequences(
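`_get_start_of_word_QA` marks a token as the start of a word exactly when its word id differs from the previous token's, which is what the `np.ediff1d` difference plus a leading 1 encodes. A worked example with illustrative values (assumes no special tokens, as at the call sites above):

    import numpy as np

    word_ids = [0, 0, 1, 2, 2, 3]  # word index of each token
    start_of_word = [1] + list(np.ediff1d(np.asarray(word_ids, dtype="int16")))
    print(start_of_word)           # [1, 0, 1, 1, 0, 1]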
@@ -241,7 +241,7 @@ def tokenize_with_metadata(text: str, tokenizer: PreTrainedTokenizer) -> Dict[st
     # Fast Tokenizers return offsets, so we don't need to calculate them ourselves
     if tokenizer.is_fast:
         # tokenized = tokenizer(text, return_offsets_mapping=True, return_special_tokens_mask=True)
-        tokenized = tokenizer.encode_plus(text, return_offsets_mapping=True, return_special_tokens_mask=True)
+        tokenized = tokenizer(text, return_offsets_mapping=True, return_special_tokens_mask=True)
 
         tokens = tokenized["input_ids"]
         offsets = np.array([x[0] for x in tokenized["offset_mapping"]])
@@ -165,7 +165,7 @@ class RAGenerator(BaseGenerator):
             for i in range(len(texts))
         ]
 
-        contextualized_inputs = self.tokenizer.generator.batch_encode_plus(
+        contextualized_inputs = self.tokenizer.generator(
             rag_input_strings,
             max_length=self.model.config.max_combined_length,
             return_tensors=return_tensors,
@@ -653,8 +653,8 @@ class RCIReader(BaseReader):
             row_reps, column_reps = self._create_row_column_representations(table)
 
             # Get row logits
-            row_inputs = self.row_tokenizer.batch_encode_plus(
-                batch_text_or_text_pairs=[(query, row_rep) for row_rep in row_reps],
+            row_inputs = self.row_tokenizer(
+                [(query, row_rep) for row_rep in row_reps],
                 max_length=self.max_seq_len,
                 return_tensors="pt",
                 add_special_tokens=True,
@@ -665,8 +665,8 @@ class RCIReader(BaseReader):
             row_logits = self.row_model(**row_inputs)[0].detach().cpu().numpy()[:, 1]
 
             # Get column logits
-            column_inputs = self.column_tokenizer.batch_encode_plus(
-                batch_text_or_text_pairs=[(query, column_rep) for column_rep in column_reps],
+            column_inputs = self.column_tokenizer(
+                [(query, column_rep) for column_rep in column_reps],
                 max_length=self.max_seq_len,
                 return_tensors="pt",
                 add_special_tokens=True,
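For the RCIReader, the keyword-only `batch_text_or_text_pairs` argument disappears: passing a list of `(query, representation)` tuples to the tokenizer call encodes each tuple as a sentence pair. A sketch with made-up table rows and an example model (neither taken from this diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

    query = "Which country hosted the 2016 Olympics?"
    row_reps = ["Year 2016 , Host Brazil", "Year 2012 , Host United Kingdom"]

    row_inputs = tokenizer(
        [(query, row_rep) for row_rep in row_reps],  # each tuple becomes a (text, text_pair) encoding
        max_length=256,
        truncation=True,
        padding=True,
        return_tensors="pt",
        add_special_tokens=True,
    )
    print(row_inputs["input_ids"].shape)  # -> torch.Size([2, padded_length])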
@@ -119,13 +119,13 @@ def test_save_load(tmp_path, model_name: str):
     text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"
 
     tokenizer.add_tokens(new_tokens=["neverseentokens"])
-    original_encoding = tokenizer.encode_plus(text)
+    original_encoding = tokenizer(text)
 
     save_dir = tmp_path / "saved_tokenizer"
     tokenizer.save_pretrained(save_dir)
 
     tokenizer_loaded = get_tokenizer(pretrained_model_name_or_path=save_dir)
-    new_encoding = tokenizer_loaded.encode_plus(text)
+    new_encoding = tokenizer_loaded(text)
 
     assert original_encoding == new_encoding
 
@@ -168,7 +168,7 @@ def test_tokenization_on_edge_cases_full_sequence_tokenization(model_name: str,
     words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
     words = [x[0] for x in words_and_spans]
 
-    encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
+    encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
     expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split()))  # remove multiple whitespaces
 
     assert encoded.tokens == expected_tokenization
@@ -188,7 +188,7 @@ def test_tokenization_on_edge_cases_full_sequence_tokenization_roberta_exception
     words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
     words = [x[0] for x in words_and_spans]
 
-    encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
+    encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
     expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split()))  # remove multiple whitespaces
 
     assert encoded.tokens == expected_tokenization
@@ -218,7 +218,7 @@ def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str,
     words = [x[0] for x in words_and_spans]
     word_spans = [x[1] for x in words_and_spans]
 
-    encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
+    encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
 
     # subword-tokens have special chars depending on model type. To align with original text we get rid of them
     tokens = [token.replace(marker, "") for token in encoded.tokens]
@@ -248,7 +248,7 @@ def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str,
 def test_detokenization_for_bert(edge_case):
     tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False)
 
-    encoded = tokenizer.encode_plus(edge_case, add_special_tokens=False).encodings[0]
+    encoded = tokenizer(edge_case, add_special_tokens=False).encodings[0]
 
     detokenized = " ".join(encoded.tokens)
     detokenized = re.sub(r"(^|\s+)(##)", "", detokenized)
@@ -264,7 +264,7 @@ def test_encode_plus_for_bert():
     tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False)
     text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"
 
-    encoded_batch = tokenizer.encode_plus(text)
+    encoded_batch = tokenizer(text)
     encoded = encoded_batch.encodings[0]
 
     words = np.array(encoded.words)
@@ -316,7 +316,7 @@ def test_tokenize_custom_vocab_bert():
 
     tokenized = tokenizer.tokenize(text)
 
-    encoded = tokenizer.encode_plus(text, add_special_tokens=False).encodings[0]
+    encoded = tokenizer(text, add_special_tokens=False).encodings[0]
     offsets = [x[0] for x in encoded.offsets]
     start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0)
 