mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-25 22:18:39 +00:00
Add inferencer for QA only (#1484)
* Add inferencer for QA only * Add latest docstring and tutorial changes * Add QA inferencer tests * Add type annotations for inferencer * Fix type annotations, move util functions * Fix type annotations * Move fixtures to the top of the file Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
d569e66bc7
commit
60471cecdf
@ -5,7 +5,6 @@ import random
|
||||
from contextlib import ExitStack
|
||||
from functools import partial
|
||||
from itertools import groupby
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Dict, Union
|
||||
|
||||
@ -19,8 +18,8 @@ from tqdm import tqdm
|
||||
|
||||
from haystack.modeling.data_handler.dataloader import NamedDataLoader
|
||||
from haystack.modeling.data_handler.processor import Processor
|
||||
from haystack.modeling.utils import MLFlowLogger as MlLogger
|
||||
from haystack.modeling.visual.ascii.images import TRACTOR_SMALL, WORKER_F, WORKER_M, WORKER_X
|
||||
from haystack.modeling.utils import MLFlowLogger as MlLogger, log_ascii_workers, grouper, calc_chunksize
|
||||
from haystack.modeling.visual.ascii.images import TRACTOR_SMALL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -716,114 +715,9 @@ class DataSiloForCrossVal:
|
||||
yield current_train_set, current_test_set
|
||||
|
||||
|
||||
def calc_chunksize(num_dicts, min_chunksize=4, max_chunksize=2000, max_processes=128):
|
||||
if mp.cpu_count() > 3:
|
||||
num_cpus = min(mp.cpu_count() - 1 or 1, max_processes) # -1 to keep a CPU core free for xxx
|
||||
else:
|
||||
num_cpus = min(mp.cpu_count(), max_processes) # when there are few cores, we use all of them
|
||||
|
||||
dicts_per_cpu = np.ceil(num_dicts / num_cpus)
|
||||
# automatic adjustment of multiprocessing chunksize
|
||||
# for small files (containing few dicts) we want small chunksize to ulitize all available cores but never less
|
||||
# than 2, because we need it to sample another random sentence in LM finetuning
|
||||
# for large files we want to minimize processor spawning without giving too much data to one process, so we
|
||||
# clip it at 5k
|
||||
multiprocessing_chunk_size = int(np.clip((np.ceil(dicts_per_cpu / 5)), a_min=min_chunksize, a_max=max_chunksize))
|
||||
# This lets us avoid cases in lm_finetuning where a chunk only has a single doc and hence cannot pick
|
||||
# a valid next sentence substitute from another document
|
||||
if num_dicts != 1:
|
||||
while num_dicts % multiprocessing_chunk_size == 1:
|
||||
multiprocessing_chunk_size -= -1
|
||||
dict_batches_to_process = int(num_dicts / multiprocessing_chunk_size)
|
||||
num_processes = min(num_cpus, dict_batches_to_process) or 1
|
||||
|
||||
return multiprocessing_chunk_size, num_processes
|
||||
|
||||
def log_ascii_workers(n, logger):
|
||||
m_worker_lines = WORKER_M.split("\n")
|
||||
f_worker_lines = WORKER_F.split("\n")
|
||||
x_worker_lines = WORKER_X.split("\n")
|
||||
all_worker_lines = []
|
||||
for i in range(n):
|
||||
rand = np.random.randint(low=0,high=3)
|
||||
if(rand % 3 == 0):
|
||||
all_worker_lines.append(f_worker_lines)
|
||||
elif(rand % 3 == 1):
|
||||
all_worker_lines.append(m_worker_lines)
|
||||
else:
|
||||
all_worker_lines.append(x_worker_lines)
|
||||
zipped = zip(*all_worker_lines)
|
||||
for z in zipped:
|
||||
logger.info(" ".join(z))
|
||||
|
||||
def get_dict_checksum(payload_dict):
|
||||
"""
|
||||
Get MD5 checksum for a dict.
|
||||
"""
|
||||
checksum = hashlib.md5(json.dumps(payload_dict, sort_keys=True).encode("utf-8")).hexdigest()
|
||||
return checksum
|
||||
|
||||
def grouper(iterable, n, worker_id=0, total_workers=1):
|
||||
"""
|
||||
Split an iterable into a list of n-sized chunks. Each element in the chunk is a tuple of (index_num, element).
|
||||
|
||||
Example:
|
||||
list(grouper('ABCDEFG', 3))
|
||||
[[(0, 'A'), (1, 'B'), (2, 'C')], [(3, 'D'), (4, 'E'), (5, 'F')], [(6, 'G')]]
|
||||
|
||||
|
||||
Use with the StreamingDataSilo
|
||||
|
||||
When StreamingDataSilo is used with multiple PyTorch DataLoader workers, the generator
|
||||
yielding dicts(that gets converted to datasets) is replicated across the workers.
|
||||
|
||||
To avoid duplicates, we split the dicts across workers by creating a new generator for
|
||||
each worker using this method.
|
||||
|
||||
Input --> [dictA, dictB, dictC, dictD, dictE, ...] with total worker=3 and n=2
|
||||
|
||||
Output for worker 1: [(dictA, dictB), (dictG, dictH), ...]
|
||||
Output for worker 2: [(dictC, dictD), (dictI, dictJ), ...]
|
||||
Output for worker 3: [(dictE, dictF), (dictK, dictL), ...]
|
||||
|
||||
This method also adds an index number to every dict yielded.
|
||||
|
||||
:param iterable: a generator object that yields dicts
|
||||
:type iterable: generator
|
||||
:param n: the dicts are grouped in n-sized chunks that gets converted to datasets
|
||||
:type n: int
|
||||
:param worker_id: the worker_id for the PyTorch DataLoader
|
||||
:type worker_id: int
|
||||
:param total_workers: total number of workers for the PyTorch DataLoader
|
||||
:type total_workers: int
|
||||
"""
|
||||
# TODO make me comprehensible :)
|
||||
def get_iter_start_pos(gen):
|
||||
start_pos = worker_id * n
|
||||
for i in gen:
|
||||
if start_pos:
|
||||
start_pos -= 1
|
||||
continue
|
||||
yield i
|
||||
|
||||
def filter_elements_per_worker(gen):
|
||||
x = n
|
||||
y = (total_workers - 1) * n
|
||||
for i in gen:
|
||||
if x:
|
||||
yield i
|
||||
x -= 1
|
||||
else:
|
||||
if y != 1:
|
||||
y -= 1
|
||||
continue
|
||||
else:
|
||||
x = n
|
||||
y = (total_workers - 1) * n
|
||||
|
||||
iterable = iter(enumerate(iterable))
|
||||
iterable = get_iter_start_pos(iterable)
|
||||
if total_workers > 1:
|
||||
iterable = filter_elements_per_worker(iterable)
|
||||
|
||||
return iter(lambda: list(islice(iterable, n)), [])
|
||||
520
haystack/modeling/infer.py
Normal file
520
haystack/modeling/infer.py
Normal file
@ -0,0 +1,520 @@
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.utils.data.sampler import SequentialSampler
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional, Dict, Union, Generator, Set, Any
|
||||
|
||||
from haystack.modeling.data_handler.dataloader import NamedDataLoader
|
||||
from haystack.modeling.data_handler.processor import Processor
|
||||
from haystack.modeling.data_handler.samples import SampleBasket
|
||||
from haystack.modeling.utils import grouper
|
||||
from haystack.modeling.data_handler.inputs import QAInput
|
||||
from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
|
||||
from haystack.modeling.utils import initialize_device_settings, MLFlowLogger
|
||||
from haystack.modeling.utils import set_all_seeds, calc_chunksize, log_ascii_workers
|
||||
from haystack.modeling.model.predictions import QAPred
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Inferencer:
|
||||
"""
|
||||
Loads a saved AdaptiveModel/ONNXAdaptiveModel from disk and runs it in inference mode. Can be used for a
|
||||
model with prediction head (down-stream predictions) and without (using LM as embedder).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: AdaptiveModel,
|
||||
processor: Processor,
|
||||
task_type: Optional[str],
|
||||
batch_size: int =4,
|
||||
gpu: bool =False,
|
||||
name: Optional[str] =None,
|
||||
return_class_probs: bool=False,
|
||||
num_processes: Optional[int] =None,
|
||||
disable_tqdm: bool =False
|
||||
):
|
||||
"""
|
||||
Initializes Inferencer from an AdaptiveModel and a Processor instance.
|
||||
|
||||
:param model: AdaptiveModel to run in inference mode
|
||||
:param processor: A dataset specific Processor object which will turn input (file or dict) into a Pytorch Dataset.
|
||||
:param task_type: Type of task the model should be used for. Currently supporting: "question_answering"
|
||||
:param batch_size: Number of samples computed once per batch
|
||||
:param gpu: If GPU shall be used
|
||||
:param name: Name for the current Inferencer model, displayed in the REST API
|
||||
:param return_class_probs: either return probability distribution over all labels or the prob of the associated label
|
||||
:param num_processes: the number of processes for `multiprocessing.Pool`.
|
||||
Set to value of 1 (or 0) to disable multiprocessing.
|
||||
Set to None to let Inferencer use all CPU cores minus one.
|
||||
If you want to debug the Language Model, you might need to disable multiprocessing!
|
||||
**Warning!** If you use multiprocessing you have to close the
|
||||
`multiprocessing.Pool` again! To do so call
|
||||
:func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
|
||||
done using this class. The garbage collector will not do this for you!
|
||||
:param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
|
||||
:return: An instance of the Inferencer.
|
||||
|
||||
"""
|
||||
MLFlowLogger.disable()
|
||||
|
||||
# Init device and distributed settings
|
||||
device, n_gpu = initialize_device_settings(use_cuda=gpu, local_rank=-1, use_amp=None)
|
||||
|
||||
self.processor = processor
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
self.language = self.model.get_language()
|
||||
self.task_type = task_type
|
||||
self.disable_tqdm = disable_tqdm
|
||||
self.problematic_sample_ids: Set[List[int]] = set() # type ignore
|
||||
|
||||
# TODO add support for multiple prediction heads
|
||||
|
||||
self.name = name if name != None else f"anonymous-{self.task_type}"
|
||||
self.return_class_probs = return_class_probs
|
||||
|
||||
model.connect_heads_with_processor(processor.tasks, require_labels=False)
|
||||
set_all_seeds(42)
|
||||
|
||||
self._set_multiprocessing_pool(num_processes)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str] = None,
|
||||
batch_size: int = 4,
|
||||
gpu: bool = False,
|
||||
task_type: Optional[str] =None,
|
||||
return_class_probs: bool = False,
|
||||
strict: bool = True,
|
||||
max_seq_len: int = 256,
|
||||
doc_stride: int = 128,
|
||||
num_processes: Optional[int] =None,
|
||||
disable_tqdm: bool = False,
|
||||
tokenizer_class: Optional[str] = None,
|
||||
use_fast: bool = True,
|
||||
tokenizer_args: Dict =None,
|
||||
multithreading_rust: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Load an Inferencer incl. all relevant components (model, tokenizer, processor ...) either by
|
||||
|
||||
1. specifying a public name from transformers' model hub (https://huggingface.co/models)
|
||||
2. or pointing to a local directory it is saved in.
|
||||
|
||||
:param model_name_or_path: Local directory or public name of the model to load.
|
||||
:param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
|
||||
:param batch_size: Number of samples computed once per batch
|
||||
:param gpu: If GPU shall be used
|
||||
:param task_type: Type of task the model should be used for. Currently supporting: "question_answering"
|
||||
:param return_class_probs: either return probability distribution over all labels or the prob of the associated label
|
||||
:param strict: whether to strictly enforce that the keys loaded from saved model match the ones in
|
||||
the PredictionHead (see torch.nn.module.load_state_dict()).
|
||||
Set to `False` for backwards compatibility with PHs saved with older version of FARM.
|
||||
:param max_seq_len: maximum length of one text sample
|
||||
:param doc_stride: Only QA: When input text is longer than max_seq_len it gets split into parts, strided by doc_stride
|
||||
:param num_processes: the number of processes for `multiprocessing.Pool`. Set to value of 0 to disable
|
||||
multiprocessing. Set to None to let Inferencer use all CPU cores minus one. If you want to
|
||||
debug the Language Model, you might need to disable multiprocessing!
|
||||
**Warning!** If you use multiprocessing you have to close the
|
||||
`multiprocessing.Pool` again! To do so call
|
||||
:func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
|
||||
done using this class. The garbage collector will not do this for you!
|
||||
:param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
|
||||
:param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
|
||||
:param use_fast: (Optional, True by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
|
||||
use the Python one (False).
|
||||
:param tokenizer_args: (Optional) Will be passed to the Tokenizer ``__init__`` method.
|
||||
See https://huggingface.co/transformers/main_classes/tokenizer.html and detailed tokenizer documentation
|
||||
on `Hugging Face Transformers <https://huggingface.co/transformers/>`_.
|
||||
:param multithreading_rust: Whether to allow multithreading in Rust, e.g. for FastTokenizers.
|
||||
Note: Enabling multithreading in Rust AND multiprocessing in python might cause
|
||||
deadlocks.
|
||||
:return: An instance of the Inferencer.
|
||||
|
||||
"""
|
||||
if tokenizer_args is None:
|
||||
tokenizer_args = {}
|
||||
|
||||
device, n_gpu = initialize_device_settings(use_cuda=gpu, local_rank=-1, use_amp=None)
|
||||
name = os.path.basename(model_name_or_path)
|
||||
|
||||
# a) either from local dir
|
||||
if os.path.exists(model_name_or_path):
|
||||
model = BaseAdaptiveModel.load(load_dir=model_name_or_path, device=device, strict=strict)
|
||||
processor = Processor.load_from_dir(model_name_or_path)
|
||||
|
||||
# b) or from remote transformers model hub
|
||||
else:
|
||||
if not task_type:
|
||||
raise ValueError("Please specify the 'task_type' of the model you want to load from transformers. "
|
||||
"Valid options for arg `task_type`:"
|
||||
"'question_answering'")
|
||||
|
||||
model = AdaptiveModel.convert_from_transformers(model_name_or_path,
|
||||
revision=revision,
|
||||
device=device,
|
||||
task_type=task_type,
|
||||
**kwargs)
|
||||
processor = Processor.convert_from_transformers(model_name_or_path,
|
||||
revision=revision,
|
||||
task_type=task_type,
|
||||
max_seq_len=max_seq_len,
|
||||
doc_stride=doc_stride,
|
||||
tokenizer_class=tokenizer_class,
|
||||
tokenizer_args=tokenizer_args,
|
||||
use_fast=use_fast,
|
||||
**kwargs)
|
||||
|
||||
# override processor attributes loaded from config or HF with inferencer params
|
||||
processor.max_seq_len = max_seq_len
|
||||
processor.multithreading_rust = multithreading_rust
|
||||
if hasattr(processor, "doc_stride"):
|
||||
assert doc_stride < max_seq_len, "doc_stride is longer than max_seq_len. This means that there will be gaps " \
|
||||
"as the passage windows slide, causing the model to skip over parts of the document. " \
|
||||
"Please set a lower value for doc_stride (Suggestions: doc_stride=128, max_seq_len=384) "
|
||||
processor.doc_stride = doc_stride
|
||||
|
||||
return cls(
|
||||
model,
|
||||
processor,
|
||||
task_type=task_type,
|
||||
batch_size=batch_size,
|
||||
gpu=gpu,
|
||||
name=name,
|
||||
return_class_probs=return_class_probs,
|
||||
num_processes=num_processes,
|
||||
disable_tqdm=disable_tqdm
|
||||
)
|
||||
|
||||
def _set_multiprocessing_pool(self, num_processes: Optional[int]):
|
||||
"""
|
||||
Initialize a multiprocessing.Pool for instances of Inferencer.
|
||||
|
||||
:param num_processes: the number of processes for `multiprocessing.Pool`.
|
||||
Set to value of 1 (or 0) to disable multiprocessing.
|
||||
Set to None to let Inferencer use all CPU cores minus one.
|
||||
If you want to debug the Language Model, you might need to disable multiprocessing!
|
||||
**Warning!** If you use multiprocessing you have to close the
|
||||
`multiprocessing.Pool` again! To do so call
|
||||
:func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
|
||||
done using this class. The garbage collector will not do this for you!
|
||||
:return:
|
||||
"""
|
||||
self.process_pool = None
|
||||
if num_processes == 0 or num_processes == 1: # disable multiprocessing
|
||||
self.process_pool = None
|
||||
else:
|
||||
if num_processes is None: # use all CPU cores
|
||||
if mp.cpu_count() > 3:
|
||||
num_processes = mp.cpu_count() - 1
|
||||
else:
|
||||
num_processes = mp.cpu_count()
|
||||
self.process_pool = mp.Pool(processes=num_processes)
|
||||
logger.info(
|
||||
f"Got ya {num_processes} parallel workers to do inference ..."
|
||||
)
|
||||
log_ascii_workers(n=num_processes,logger=logger)
|
||||
|
||||
def close_multiprocessing_pool(self, join: bool = False):
|
||||
"""Close the `multiprocessing.Pool` again.
|
||||
|
||||
If you use multiprocessing you have to close the `multiprocessing.Pool` again!
|
||||
To do so call this function after you are done using this class.
|
||||
The garbage collector will not do this for you!
|
||||
|
||||
:param join: wait for the worker processes to exit
|
||||
"""
|
||||
if self.process_pool is not None:
|
||||
self.process_pool.close()
|
||||
if join:
|
||||
self.process_pool.join()
|
||||
self.process_pool = None
|
||||
|
||||
def save(self, path: str):
|
||||
self.model.save(path)
|
||||
self.processor.save(path)
|
||||
|
||||
def inference_from_file(self, file: str, multiprocessing_chunksize: int = None, return_json: bool = True):
|
||||
"""
|
||||
Run down-stream inference on samples created from an input file.
|
||||
The file should be in the same format as the ones used during training
|
||||
(e.g. squad style for QA, tsv for doc classification ...) as the same Processor will be used for conversion.
|
||||
|
||||
:param file: path of the input file for Inference
|
||||
:param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
|
||||
:return: list of predictions
|
||||
"""
|
||||
dicts = self.processor.file_to_dicts(file)
|
||||
preds_all = self.inference_from_dicts(
|
||||
dicts,
|
||||
return_json=return_json,
|
||||
multiprocessing_chunksize=multiprocessing_chunksize,
|
||||
)
|
||||
return list(preds_all)
|
||||
|
||||
def inference_from_dicts(
|
||||
self, dicts: List[Dict], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
|
||||
) -> List:
|
||||
"""
|
||||
Runs down-stream inference on samples created from input dictionaries.
|
||||
|
||||
* QA (FARM style): [{"questions": ["What is X?"], "text": "Some context containing the answer"}]
|
||||
|
||||
:param dicts: Samples to run inference on provided as a list(or a generator object) of dicts.
|
||||
One dict per sample.
|
||||
:param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
|
||||
object where applicable, else it returns PredObj.to_json()
|
||||
:param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
|
||||
(only relevant if you do multiprocessing)
|
||||
:return: list of predictions
|
||||
|
||||
"""
|
||||
|
||||
# whether to aggregate predictions across different samples (e.g. for QA on long texts)
|
||||
# TODO remove or adjust after implmenting input objects properly
|
||||
# if set(dicts[0].keys()) == {"qas", "context"}:
|
||||
# warnings.warn("QA Input dictionaries with [qas, context] as keys will be deprecated in the future",
|
||||
# DeprecationWarning)
|
||||
|
||||
aggregate_preds = False
|
||||
if len(self.model.prediction_heads) > 0:
|
||||
aggregate_preds = hasattr(self.model.prediction_heads[0], "aggregate_preds")
|
||||
|
||||
if self.process_pool is None: # multiprocessing disabled (helpful for debugging or using in web frameworks)
|
||||
predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
|
||||
return predictions
|
||||
else: # use multiprocessing for inference
|
||||
# Calculate values of multiprocessing_chunksize and num_processes if not supplied in the parameters.
|
||||
|
||||
if multiprocessing_chunksize is None:
|
||||
_chunk_size, _ = calc_chunksize(len(dicts))
|
||||
multiprocessing_chunksize = _chunk_size
|
||||
|
||||
predictions = self._inference_with_multiprocessing(
|
||||
dicts, return_json, aggregate_preds, multiprocessing_chunksize,
|
||||
)
|
||||
|
||||
self.processor.log_problematic(self.problematic_sample_ids)
|
||||
# cast the generator to a list if it isnt already a list.
|
||||
if type(predictions) != list:
|
||||
return list(predictions)
|
||||
else:
|
||||
return predictions
|
||||
|
||||
def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: bool, aggregate_preds: bool) -> List:
|
||||
"""
|
||||
Implementation of inference from dicts without using Python multiprocessing. Useful for debugging or in API
|
||||
framework where spawning new processes could be expensive.
|
||||
|
||||
:param dicts: Samples to run inference on provided as a list of dicts. One dict per sample.
|
||||
:param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
|
||||
object where applicable, else it returns PredObj.to_json()
|
||||
:param aggregate_preds: whether to aggregate predictions across different samples (e.g. for QA on long texts)
|
||||
:return: list of predictions
|
||||
"""
|
||||
indices = list(range(len(dicts)))
|
||||
dataset, tensor_names, problematic_ids, baskets = self.processor.dataset_from_dicts(
|
||||
dicts, indices=indices, return_baskets=True
|
||||
)
|
||||
self.problematic_sample_ids = problematic_ids
|
||||
|
||||
# TODO change format of formatted_preds in QA (list of dicts)
|
||||
if aggregate_preds:
|
||||
preds_all = self._get_predictions_and_aggregate(dataset, tensor_names, baskets)
|
||||
else:
|
||||
preds_all = self._get_predictions(dataset, tensor_names, baskets)
|
||||
|
||||
if return_json:
|
||||
# TODO this try catch should be removed when all tasks return prediction objects
|
||||
try:
|
||||
preds_all = [x.to_json() for x in preds_all]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return preds_all
|
||||
|
||||
def _inference_with_multiprocessing(
|
||||
self, dicts: Union[List[Dict], Generator[Dict, None, None]], return_json: bool, aggregate_preds: bool, multiprocessing_chunksize: int
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""
|
||||
Implementation of inference. This method is a generator that yields the results.
|
||||
|
||||
:param dicts: Samples to run inference on provided as a list of dicts or a generator object that yield dicts.
|
||||
:param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
|
||||
object where applicable, else it returns PredObj.to_json()
|
||||
:param aggregate_preds: whether to aggregate predictions across different samples (e.g. for QA on long texts)
|
||||
:param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
|
||||
:return: generator object that yield predictions
|
||||
"""
|
||||
|
||||
# We group the input dicts into chunks and feed each chunk to a different process
|
||||
# in the pool, where it gets converted to a pytorch dataset
|
||||
if self.process_pool is not None:
|
||||
results = self.process_pool.imap(
|
||||
partial(self._create_datasets_chunkwise, processor=self.processor),
|
||||
grouper(iterable=dicts, n=multiprocessing_chunksize),
|
||||
1,
|
||||
)
|
||||
|
||||
# Once a process spits out a preprocessed chunk. we feed this dataset directly to the model.
|
||||
# So we don't need to wait until all preprocessing has finished before getting first predictions.
|
||||
for dataset, tensor_names, problematic_sample_ids, baskets in results:
|
||||
self.problematic_sample_ids.update(problematic_sample_ids)
|
||||
if dataset is None:
|
||||
logger.error(f"Part of the dataset could not be converted! \n"
|
||||
f"BE AWARE: The order of predictions will not conform with the input order!")
|
||||
else:
|
||||
# TODO change format of formatted_preds in QA (list of dicts)
|
||||
if aggregate_preds:
|
||||
predictions = self._get_predictions_and_aggregate(
|
||||
dataset, tensor_names, baskets
|
||||
)
|
||||
else:
|
||||
predictions = self._get_predictions(dataset, tensor_names, baskets)
|
||||
|
||||
if return_json:
|
||||
# TODO this try catch should be removed when all tasks return prediction objects
|
||||
try:
|
||||
predictions = [x.to_json() for x in predictions]
|
||||
except AttributeError:
|
||||
pass
|
||||
yield from predictions
|
||||
|
||||
@classmethod
|
||||
def _create_datasets_chunkwise(cls, chunk, processor: Processor):
|
||||
"""Convert ONE chunk of data (i.e. dictionaries) into ONE pytorch dataset.
|
||||
This is usually executed in one of many parallel processes.
|
||||
The resulting datasets of the processes are merged together afterwards"""
|
||||
dicts = [d[1] for d in chunk]
|
||||
indices = [d[0] for d in chunk]
|
||||
dataset, tensor_names, problematic_sample_ids, baskets = processor.dataset_from_dicts(dicts, indices, return_baskets=True)
|
||||
return dataset, tensor_names, problematic_sample_ids, baskets
|
||||
|
||||
def _get_predictions(self, dataset: Dataset, tensor_names: List, baskets):
|
||||
"""
|
||||
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
|
||||
|
||||
:param dataset: PyTorch Dataset with samples you want to predict
|
||||
:param tensor_names: Names of the tensors in the dataset
|
||||
:param baskets: For each item in the dataset, we need additional information to create formatted preds.
|
||||
Baskets contain all relevant infos for that.
|
||||
Example: QA - input string to convert the predicted answer from indices back to string space
|
||||
:return: list of predictions
|
||||
"""
|
||||
samples = [s for b in baskets for s in b.samples]
|
||||
|
||||
data_loader = NamedDataLoader(
|
||||
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
|
||||
) # type ignore
|
||||
preds_all = []
|
||||
for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=self.disable_tqdm)):
|
||||
batch = {key: batch[key].to(self.device) for key in batch}
|
||||
batch_samples = samples[i * self.batch_size : (i + 1) * self.batch_size]
|
||||
|
||||
# get logits
|
||||
with torch.no_grad():
|
||||
logits = self.model.forward(**batch)
|
||||
preds = self.model.formatted_preds(
|
||||
logits=logits,
|
||||
samples=batch_samples,
|
||||
tokenizer=self.processor.tokenizer,
|
||||
return_class_probs=self.return_class_probs,
|
||||
**batch)
|
||||
preds_all += preds
|
||||
return preds_all
|
||||
|
||||
def _get_predictions_and_aggregate(self, dataset: Dataset, tensor_names: List, baskets: List[SampleBasket]):
|
||||
"""
|
||||
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + logits_to_preds + formatted_preds).
|
||||
|
||||
Difference to _get_predictions():
|
||||
- Additional aggregation step across predictions of individual samples
|
||||
(e.g. For QA on long texts, we extract answers from multiple passages and then aggregate them on the "document level")
|
||||
|
||||
:param dataset: PyTorch Dataset with samples you want to predict
|
||||
:param tensor_names: Names of the tensors in the dataset
|
||||
:param baskets: For each item in the dataset, we need additional information to create formatted preds.
|
||||
Baskets contain all relevant infos for that.
|
||||
Example: QA - input string to convert the predicted answer from indices back to string space
|
||||
:return: list of predictions
|
||||
"""
|
||||
|
||||
data_loader = NamedDataLoader(
|
||||
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
|
||||
) # type ignore
|
||||
# TODO Sometimes this is the preds of one head, sometimes of two. We need a more advanced stacking operation
|
||||
# TODO so that preds of the right shape are passed in to formatted_preds
|
||||
unaggregated_preds_all = []
|
||||
|
||||
for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=self.disable_tqdm)):
|
||||
|
||||
batch = {key: batch[key].to(self.device) for key in batch}
|
||||
|
||||
# get logits
|
||||
with torch.no_grad():
|
||||
# Aggregation works on preds, not logits. We want as much processing happening in one batch + on GPU
|
||||
# So we transform logits to preds here as well
|
||||
logits = self.model.forward(**batch)
|
||||
# preds = self.model.logits_to_preds(logits, **batch)[0] (This must somehow be useful for SQuAD)
|
||||
preds = self.model.logits_to_preds(logits, **batch)
|
||||
unaggregated_preds_all.append(preds)
|
||||
|
||||
# In some use cases we want to aggregate the individual predictions.
|
||||
# This is mostly useful, if the input text is longer than the max_seq_len that the model can process.
|
||||
# In QA we can use this to get answers from long input texts by first getting predictions for smaller passages
|
||||
# and then aggregating them here.
|
||||
|
||||
# At this point unaggregated preds has shape [n_batches][n_heads][n_samples]
|
||||
|
||||
# can assume that we have only complete docs i.e. all the samples of one doc are in the current chunk
|
||||
logits = [None]
|
||||
preds_all = self.model.formatted_preds(logits=logits, # For QA we collected preds per batch and do not want to pass logits
|
||||
preds=unaggregated_preds_all,
|
||||
baskets=baskets) # type ignore
|
||||
return preds_all
|
||||
|
||||
|
||||
class QAInferencer(Inferencer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.task_type != "question_answering":
|
||||
logger.warning("QAInferencer always has task_type='question_answering' even if another value is provided "
|
||||
"to Inferencer.load() or QAInferencer()")
|
||||
self.task_type = "question_answering"
|
||||
|
||||
def inference_from_dicts(self,
|
||||
dicts: List[dict],
|
||||
return_json: bool = True,
|
||||
multiprocessing_chunksize: Optional[int] = None) -> List[QAPred]:
|
||||
return Inferencer.inference_from_dicts(self, dicts, return_json=return_json,
|
||||
multiprocessing_chunksize=multiprocessing_chunksize)
|
||||
|
||||
def inference_from_file(self,
|
||||
file: str,
|
||||
multiprocessing_chunksize: Optional[int] = None,
|
||||
return_json=True) -> List[QAPred]:
|
||||
return Inferencer.inference_from_file(self, file, return_json=return_json,
|
||||
multiprocessing_chunksize=multiprocessing_chunksize)
|
||||
|
||||
def inference_from_objects(self,
|
||||
objects: List[QAInput],
|
||||
return_json: bool = True,
|
||||
multiprocessing_chunksize: Optional[int] = None) -> List[QAPred]:
|
||||
dicts = [o.to_dict() for o in objects]
|
||||
# TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects, then we can and should use inference from (input) objects!
|
||||
#logger.warning("QAInferencer.inference_from_objects() will soon be deprecated. Use QAInferencer.inference_from_dicts() instead")
|
||||
return self.inference_from_dicts(dicts, return_json=return_json,
|
||||
multiprocessing_chunksize=multiprocessing_chunksize)
|
||||
@ -4,13 +4,17 @@ import pickle
|
||||
import random
|
||||
import signal
|
||||
from copy import deepcopy
|
||||
from itertools import islice
|
||||
|
||||
import mlflow
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import multiprocessing as mp
|
||||
from requests.exceptions import ConnectionError
|
||||
|
||||
from haystack.modeling.visual.ascii.images import WORKER_M, WORKER_F, WORKER_X
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GracefulKiller:
|
||||
@ -296,3 +300,113 @@ def all_gather_list(data, group=None, max_size=16384):
|
||||
'in your training script that can cause one worker to finish an epoch '
|
||||
'while other workers are still iterating over their portions of the data.'
|
||||
)
|
||||
|
||||
|
||||
def grouper(iterable, n, worker_id=0, total_workers=1):
|
||||
"""
|
||||
Split an iterable into a list of n-sized chunks. Each element in the chunk is a tuple of (index_num, element).
|
||||
|
||||
Example:
|
||||
|
||||
>>> list(grouper('ABCDEFG', 3))
|
||||
[[(0, 'A'), (1, 'B'), (2, 'C')], [(3, 'D'), (4, 'E'), (5, 'F')], [(6, 'G')]]
|
||||
|
||||
|
||||
|
||||
Use with the StreamingDataSilo
|
||||
|
||||
When StreamingDataSilo is used with multiple PyTorch DataLoader workers, the generator
|
||||
yielding dicts(that gets converted to datasets) is replicated across the workers.
|
||||
|
||||
To avoid duplicates, we split the dicts across workers by creating a new generator for
|
||||
each worker using this method.
|
||||
|
||||
Input --> [dictA, dictB, dictC, dictD, dictE, ...] with total worker=3 and n=2
|
||||
|
||||
Output for worker 1: [(dictA, dictB), (dictG, dictH), ...]
|
||||
Output for worker 2: [(dictC, dictD), (dictI, dictJ), ...]
|
||||
Output for worker 3: [(dictE, dictF), (dictK, dictL), ...]
|
||||
|
||||
This method also adds an index number to every dict yielded.
|
||||
|
||||
:param iterable: a generator object that yields dicts
|
||||
:type iterable: generator
|
||||
:param n: the dicts are grouped in n-sized chunks that gets converted to datasets
|
||||
:type n: int
|
||||
:param worker_id: the worker_id for the PyTorch DataLoader
|
||||
:type worker_id: int
|
||||
:param total_workers: total number of workers for the PyTorch DataLoader
|
||||
:type total_workers: int
|
||||
"""
|
||||
# TODO make me comprehensible :)
|
||||
def get_iter_start_pos(gen):
|
||||
start_pos = worker_id * n
|
||||
for i in gen:
|
||||
if start_pos:
|
||||
start_pos -= 1
|
||||
continue
|
||||
yield i
|
||||
|
||||
def filter_elements_per_worker(gen):
|
||||
x = n
|
||||
y = (total_workers - 1) * n
|
||||
for i in gen:
|
||||
if x:
|
||||
yield i
|
||||
x -= 1
|
||||
else:
|
||||
if y != 1:
|
||||
y -= 1
|
||||
continue
|
||||
else:
|
||||
x = n
|
||||
y = (total_workers - 1) * n
|
||||
|
||||
iterable = iter(enumerate(iterable))
|
||||
iterable = get_iter_start_pos(iterable)
|
||||
if total_workers > 1:
|
||||
iterable = filter_elements_per_worker(iterable)
|
||||
|
||||
return iter(lambda: list(islice(iterable, n)), [])
|
||||
|
||||
|
||||
def calc_chunksize(num_dicts, min_chunksize=4, max_chunksize=2000, max_processes=128):
|
||||
if mp.cpu_count() > 3:
|
||||
num_cpus = min(mp.cpu_count() - 1 or 1, max_processes) # -1 to keep a CPU core free for xxx
|
||||
else:
|
||||
num_cpus = min(mp.cpu_count(), max_processes) # when there are few cores, we use all of them
|
||||
|
||||
dicts_per_cpu = np.ceil(num_dicts / num_cpus)
|
||||
# automatic adjustment of multiprocessing chunksize
|
||||
# for small files (containing few dicts) we want small chunksize to ulitize all available cores but never less
|
||||
# than 2, because we need it to sample another random sentence in LM finetuning
|
||||
# for large files we want to minimize processor spawning without giving too much data to one process, so we
|
||||
# clip it at 5k
|
||||
multiprocessing_chunk_size = int(np.clip((np.ceil(dicts_per_cpu / 5)), a_min=min_chunksize, a_max=max_chunksize))
|
||||
# This lets us avoid cases in lm_finetuning where a chunk only has a single doc and hence cannot pick
|
||||
# a valid next sentence substitute from another document
|
||||
if num_dicts != 1:
|
||||
while num_dicts % multiprocessing_chunk_size == 1:
|
||||
multiprocessing_chunk_size -= -1
|
||||
dict_batches_to_process = int(num_dicts / multiprocessing_chunk_size)
|
||||
num_processes = min(num_cpus, dict_batches_to_process) or 1
|
||||
|
||||
return multiprocessing_chunk_size, num_processes
|
||||
|
||||
|
||||
def log_ascii_workers(n, logger):
|
||||
m_worker_lines = WORKER_M.split("\n")
|
||||
f_worker_lines = WORKER_F.split("\n")
|
||||
x_worker_lines = WORKER_X.split("\n")
|
||||
all_worker_lines = []
|
||||
for i in range(n):
|
||||
rand = np.random.randint(low=0,high=3)
|
||||
if(rand % 3 == 0):
|
||||
all_worker_lines.append(f_worker_lines)
|
||||
elif(rand % 3 == 1):
|
||||
all_worker_lines.append(m_worker_lines)
|
||||
else:
|
||||
all_worker_lines.append(x_worker_lines)
|
||||
zipped = zip(*all_worker_lines)
|
||||
for z in zipped:
|
||||
logger.info(" ".join(z))
|
||||
|
||||
@ -3,6 +3,7 @@ import time
|
||||
from subprocess import run
|
||||
from sys import platform
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
import requests
|
||||
from elasticsearch import Elasticsearch
|
||||
@ -17,6 +18,7 @@ from haystack.document_store.weaviate import WeaviateDocumentStore
|
||||
|
||||
from haystack.document_store.milvus import MilvusDocumentStore
|
||||
from haystack.generator.transformers import RAGenerator, RAGeneratorType
|
||||
from haystack.modeling.infer import Inferencer, QAInferencer
|
||||
from haystack.ranker import FARMRanker, SentenceTransformersRanker
|
||||
|
||||
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
|
||||
@ -465,3 +467,44 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
|
||||
raise Exception(f"No document store fixture for '{document_store_type}'")
|
||||
|
||||
return document_store
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def adaptive_model_qa(num_processes):
|
||||
"""
|
||||
PyTest Fixture for a Question Answering Inferencer based on PyTorch.
|
||||
"""
|
||||
try:
|
||||
model = Inferencer.load(
|
||||
"deepset/bert-base-cased-squad2",
|
||||
task_type="question_answering",
|
||||
batch_size=16,
|
||||
num_processes=num_processes,
|
||||
gpu=False,
|
||||
)
|
||||
yield model
|
||||
finally:
|
||||
if num_processes != 0:
|
||||
# close the pool
|
||||
# we pass join=True to wait for all sub processes to close
|
||||
# this is because below we want to test if all sub-processes
|
||||
# have exited
|
||||
model.close_multiprocessing_pool(join=True)
|
||||
|
||||
# check if all workers (sub processes) are closed
|
||||
current_process = psutil.Process()
|
||||
children = current_process.children()
|
||||
assert len(children) == 0
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def bert_base_squad2(request):
|
||||
model = QAInferencer.load(
|
||||
"deepset/minilm-uncased-squad2",
|
||||
task_type="question_answering",
|
||||
batch_size=4,
|
||||
num_processes=0,
|
||||
multithreading_rust=False,
|
||||
use_fast=True # TODO parametrize this to test slow as well
|
||||
)
|
||||
return model
|
||||
|
||||
72
test/test_modeling_inference.py
Normal file
72
test/test_modeling_inference.py
Normal file
@ -0,0 +1,72 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing_chunksize", [None, 2])
|
||||
@pytest.mark.parametrize("num_processes", [2, 0, None], scope="module")
|
||||
def test_qa_format_and_results(adaptive_model_qa, multiprocessing_chunksize):
|
||||
qa_inputs_dicts = [
|
||||
{
|
||||
"questions": ["In what country is Normandy"],
|
||||
"text": "The Normans are an ethnic group that arose in Normandy, a northern region "
|
||||
"of France, from contact between Viking settlers and indigenous Franks and Gallo-Romans",
|
||||
},
|
||||
{
|
||||
"questions": ["Who counted the game among the best ever made?"],
|
||||
"text": "Twilight Princess was released to universal critical acclaim and commercial success. It received "
|
||||
"perfect scores from major publications such as 1UP.com, Computer and Video Games, Electronic "
|
||||
"Gaming Monthly, Game Informer, GamesRadar, and GameSpy. On the review aggregators GameRankings "
|
||||
"and Metacritic, Twilight Princess has average scores of 95% and 95 for the Wii version and scores "
|
||||
"of 95% and 96 for the GameCube version. GameTrailers in their review called it one of the "
|
||||
"greatest games ever created.",
|
||||
},
|
||||
]
|
||||
ground_truths = ["France", "GameTrailers"]
|
||||
|
||||
results = adaptive_model_qa.inference_from_dicts(
|
||||
dicts=qa_inputs_dicts,
|
||||
multiprocessing_chunksize=multiprocessing_chunksize,
|
||||
)
|
||||
# sample results
|
||||
# [
|
||||
# {
|
||||
# "task": "qa",
|
||||
# "predictions": [
|
||||
# {
|
||||
# "question": "In what country is Normandy",
|
||||
# "question_id": "None",
|
||||
# "ground_truth": None,
|
||||
# "answers": [
|
||||
# {
|
||||
# "score": 1.1272038221359253,
|
||||
# "probability": -1,
|
||||
# "answer": "France",
|
||||
# "offset_answer_start": 54,
|
||||
# "offset_answer_end": 60,
|
||||
# "context": "The Normans gave their name to Normandy, a region in France.",
|
||||
# "offset_context_start": 0,
|
||||
# "offset_context_end": 60,
|
||||
# "document_id": None,
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ],
|
||||
# }
|
||||
# ]
|
||||
predictions = list(results)[0]["predictions"]
|
||||
|
||||
for prediction, ground_truth, qa_input_dict in zip(
|
||||
predictions, ground_truths, qa_inputs_dicts
|
||||
):
|
||||
assert prediction["question"] == qa_input_dict["questions"][0]
|
||||
answer = prediction["answers"][0]
|
||||
assert answer["answer"] in answer["context"]
|
||||
assert answer["answer"] == ground_truth
|
||||
assert (
|
||||
{"answer", "score", "probability", "offset_answer_start", "offset_answer_end", "context",
|
||||
"offset_context_start", "offset_context_end", "document_id"}
|
||||
== answer.keys()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_qa_format_and_results()
|
||||
233
test/test_modeling_question_answering.py
Normal file
233
test/test_modeling_question_answering.py
Normal file
@ -0,0 +1,233 @@
|
||||
import logging
|
||||
import pytest
|
||||
from math import isclose
|
||||
import numpy as np
|
||||
|
||||
from haystack.modeling.infer import QAInferencer
|
||||
from haystack.modeling.data_handler.inputs import QAInput, Question
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def span_inference_result(bert_base_squad2, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
obj_input = [QAInput(doc_text="Twilight Princess was released to universal critical acclaim and commercial success. It received perfect scores from major publications such as 1UP.com, Computer and Video Games, Electronic Gaming Monthly, Game Informer, GamesRadar, and GameSpy. On the review aggregators GameRankings and Metacritic, Twilight Princess has average scores of 95% and 95 for the Wii version and scores of 95% and 96 for the GameCube version. GameTrailers in their review called it one of the greatest games ever created.",
|
||||
questions=Question("Who counted the game among the best ever made?", uid="best_id_ever"))]
|
||||
result = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def no_answer_inference_result(bert_base_squad2, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
obj_input = [QAInput(doc_text="The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Venezuela, Ecuador, Bolivia, Guyana, Suriname and French Guiana. States or departments in four nations contain \"Amazonas\" in their names. The Amazon represents over half of the planet's remaining rainforests, and comprises the largest and most biodiverse tract of tropical rainforest in the world, with an estimated 390 billion individual trees divided into 16,000 species.",
|
||||
questions=Question("The Amazon represents less than half of the planets remaining what?", uid="best_id_ever"))]
|
||||
result = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
|
||||
return result
|
||||
|
||||
|
||||
def test_inference_different_inputs(bert_base_squad2):
|
||||
qa_format_1 = [
|
||||
{
|
||||
"questions": ["Who counted the game among the best ever made?"],
|
||||
"text": "Twilight Princess was released to universal critical acclaim and commercial success. It received perfect scores from major publications such as 1UP.com, Computer and Video Games, Electronic Gaming Monthly, Game Informer, GamesRadar, and GameSpy. On the review aggregators GameRankings and Metacritic, Twilight Princess has average scores of 95% and 95 for the Wii version and scores of 95% and 96 for the GameCube version. GameTrailers in their review called it one of the greatest games ever created."
|
||||
}]
|
||||
q = Question(text="Who counted the game among the best ever made?")
|
||||
qa_format_2 = QAInput(questions=[q],doc_text= "Twilight Princess was released to universal critical acclaim and commercial success. It received perfect scores from major publications such as 1UP.com, Computer and Video Games, Electronic Gaming Monthly, Game Informer, GamesRadar, and GameSpy. On the review aggregators GameRankings and Metacritic, Twilight Princess has average scores of 95% and 95 for the Wii version and scores of 95% and 96 for the GameCube version. GameTrailers in their review called it one of the greatest games ever created.")
|
||||
|
||||
|
||||
result1 = bert_base_squad2.inference_from_dicts(dicts=qa_format_1)
|
||||
result2 = bert_base_squad2.inference_from_objects(objects=[qa_format_2])
|
||||
assert result1 == result2
|
||||
|
||||
|
||||
def test_span_inference_result_ranking_by_confidence(bert_base_squad2, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
obj_input = [QAInput(doc_text="Twilight Princess was released to universal critical acclaim and commercial success. It received perfect scores from major publications such as 1UP.com, Computer and Video Games, Electronic Gaming Monthly, Game Informer, GamesRadar, and GameSpy. On the review aggregators GameRankings and Metacritic, Twilight Princess has average scores of 95% and 95 for the Wii version and scores of 95% and 96 for the GameCube version. GameTrailers in their review called it one of the greatest games ever created.",
|
||||
questions=Question("Who counted the game among the best ever made?", uid="best_id_ever"))]
|
||||
result = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
|
||||
|
||||
# by default, result is sorted by score and not by confidence
|
||||
assert all(result.prediction[i].score >= result.prediction[i + 1].score for i in range(len(result.prediction) - 1))
|
||||
assert not all(result.prediction[i].confidence >= result.prediction[i + 1].confidence for i in range(len(result.prediction) - 1))
|
||||
|
||||
# ranking can be adjusted so that result is sorted by confidence
|
||||
bert_base_squad2.model.prediction_heads[0].use_confidence_scores_for_ranking = True
|
||||
result_ranked_by_confidence = bert_base_squad2.inference_from_objects(obj_input, return_json=False)[0]
|
||||
assert all(result_ranked_by_confidence.prediction[i].confidence >= result_ranked_by_confidence.prediction[i + 1].confidence for i in range(len(result_ranked_by_confidence.prediction) - 1))
|
||||
assert not all(result_ranked_by_confidence.prediction[i].score >= result_ranked_by_confidence.prediction[i + 1].score for i in range(len(result_ranked_by_confidence.prediction) - 1))
|
||||
|
||||
|
||||
def test_inference_objs(span_inference_result, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
|
||||
assert span_inference_result
|
||||
|
||||
|
||||
def test_span_performance(span_inference_result, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
|
||||
best_pred = span_inference_result.prediction[0]
|
||||
|
||||
assert best_pred.answer == "GameTrailers"
|
||||
|
||||
best_score_gold = 13.4205
|
||||
best_score = best_pred.score
|
||||
assert isclose(best_score, best_score_gold, rel_tol=0.001)
|
||||
|
||||
no_answer_gap_gold = 13.9827
|
||||
no_answer_gap = span_inference_result.no_answer_gap
|
||||
assert isclose(no_answer_gap, no_answer_gap_gold, rel_tol=0.001)
|
||||
|
||||
|
||||
def test_no_answer_performance(no_answer_inference_result, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
best_pred = no_answer_inference_result.prediction[0]
|
||||
|
||||
assert best_pred.answer == "no_answer"
|
||||
|
||||
best_score_gold = 12.1445
|
||||
best_score = best_pred.score
|
||||
assert isclose(best_score, best_score_gold, rel_tol=0.001)
|
||||
|
||||
no_answer_gap_gold = -14.4646
|
||||
no_answer_gap = no_answer_inference_result.no_answer_gap
|
||||
assert isclose(no_answer_gap, no_answer_gap_gold, rel_tol=0.001)
|
||||
|
||||
|
||||
def test_qa_pred_attributes(span_inference_result, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
|
||||
qa_pred = span_inference_result
|
||||
attributes_gold = ['aggregation_level', 'answer_types', 'context', 'context_window_size', 'ground_truth_answer',
|
||||
'id', 'n_passages', 'no_answer_gap', 'prediction', 'question', 'to_json',
|
||||
'to_squad_eval', 'token_offsets']
|
||||
|
||||
for ag in attributes_gold:
|
||||
assert ag in dir(qa_pred)
|
||||
|
||||
|
||||
def test_qa_candidate_attributes(span_inference_result, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
|
||||
qa_candidate = span_inference_result.prediction[0]
|
||||
attributes_gold = ['aggregation_level', 'answer', 'answer_support', 'answer_type', 'context_window',
|
||||
'n_passages_in_doc', 'offset_answer_end', 'offset_answer_start', 'offset_answer_support_end',
|
||||
'offset_answer_support_start', 'offset_context_window_end', 'offset_context_window_start',
|
||||
'offset_unit', 'passage_id', 'probability', 'score', 'set_answer_string', 'set_context_window',
|
||||
'to_doc_level', 'to_list']
|
||||
|
||||
for ag in attributes_gold:
|
||||
assert ag in dir(qa_candidate)
|
||||
|
||||
|
||||
def test_id(span_inference_result, no_answer_inference_result):
|
||||
assert span_inference_result.id == "best_id_ever"
|
||||
assert no_answer_inference_result.id == "best_id_ever"
|
||||
|
||||
|
||||
def test_duplicate_answer_filtering(bert_base_squad2):
|
||||
qa_input = [
|
||||
{
|
||||
"questions": ["“In what country lies the Normandy?”"],
|
||||
"text": """The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\")
|
||||
raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia.
|
||||
The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries. Weird things happen in Normandy, France."""
|
||||
}]
|
||||
|
||||
bert_base_squad2.model.prediction_heads[0].n_best = 5
|
||||
bert_base_squad2.model.prediction_heads[0].n_best_per_sample = 5
|
||||
bert_base_squad2.model.prediction_heads[0].duplicate_filtering = 0
|
||||
|
||||
result = bert_base_squad2.inference_from_dicts(dicts=qa_input)
|
||||
offset_answer_starts = []
|
||||
offset_answer_ends = []
|
||||
for answer in result[0]["predictions"][0]["answers"]:
|
||||
offset_answer_starts.append(answer["offset_answer_start"])
|
||||
offset_answer_ends.append(answer["offset_answer_end"])
|
||||
|
||||
assert len(offset_answer_starts) == len(set(offset_answer_starts))
|
||||
assert len(offset_answer_ends) == len(set(offset_answer_ends))
|
||||
|
||||
|
||||
def test_no_duplicate_answer_filtering(bert_base_squad2):
|
||||
qa_input = [
|
||||
{
|
||||
"questions": ["“In what country lies the Normandy?”"],
|
||||
"text": """The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\")
|
||||
raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia.
|
||||
The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries. Weird things happen in Normandy, France."""
|
||||
}]
|
||||
|
||||
bert_base_squad2.model.prediction_heads[0].n_best = 5
|
||||
bert_base_squad2.model.prediction_heads[0].n_best_per_sample = 5
|
||||
bert_base_squad2.model.prediction_heads[0].duplicate_filtering = -1
|
||||
|
||||
result = bert_base_squad2.inference_from_dicts(dicts=qa_input)
|
||||
offset_answer_starts = []
|
||||
offset_answer_ends = []
|
||||
for answer in result[0]["predictions"][0]["answers"]:
|
||||
offset_answer_starts.append(answer["offset_answer_start"])
|
||||
offset_answer_ends.append(answer["offset_answer_end"])
|
||||
|
||||
assert len(offset_answer_starts) != len(set(offset_answer_starts))
|
||||
assert len(offset_answer_ends) != len(set(offset_answer_ends))
|
||||
|
||||
|
||||
def test_range_duplicate_answer_filtering(bert_base_squad2):
|
||||
qa_input = [
|
||||
{
|
||||
"questions": ["“In what country lies the Normandy?”"],
|
||||
"text": """The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\")
|
||||
raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia.
|
||||
The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries. Weird things happen in Normandy, France."""
|
||||
}]
|
||||
|
||||
bert_base_squad2.model.prediction_heads[0].n_best = 5
|
||||
bert_base_squad2.model.prediction_heads[0].n_best_per_sample = 5
|
||||
bert_base_squad2.model.prediction_heads[0].duplicate_filtering = 5
|
||||
|
||||
result = bert_base_squad2.inference_from_dicts(dicts=qa_input)
|
||||
offset_answer_starts = []
|
||||
offset_answer_ends = []
|
||||
for answer in result[0]["predictions"][0]["answers"]:
|
||||
offset_answer_starts.append(answer["offset_answer_start"])
|
||||
offset_answer_ends.append(answer["offset_answer_end"])
|
||||
|
||||
offset_answer_starts.sort()
|
||||
offset_answer_starts.remove(0)
|
||||
distances_answer_starts = [j-i for i, j in zip(offset_answer_starts[:-1],offset_answer_starts[1:])]
|
||||
assert all(distance > bert_base_squad2.model.prediction_heads[0].duplicate_filtering for distance in distances_answer_starts)
|
||||
|
||||
offset_answer_ends.sort()
|
||||
offset_answer_ends.remove(0)
|
||||
distances_answer_ends = [j-i for i, j in zip(offset_answer_ends[:-1], offset_answer_ends[1:])]
|
||||
assert all(distance > bert_base_squad2.model.prediction_heads[0].duplicate_filtering for distance in distances_answer_ends)
|
||||
|
||||
|
||||
def test_qa_confidence():
|
||||
inferencer = QAInferencer.load("deepset/roberta-base-squad2", task_type="question_answering", batch_size=40, gpu=True)
|
||||
QA_input = [
|
||||
{
|
||||
"questions": ["Who counted the game among the best ever made?"],
|
||||
"text": "Twilight Princess was released to universal critical acclaim and commercial success. It received perfect scores from major publications such as 1UP.com, Computer and Video Games, Electronic Gaming Monthly, Game Informer, GamesRadar, and GameSpy. On the review aggregators GameRankings and Metacritic, Twilight Princess has average scores of 95% and 95 for the Wii version and scores of 95% and 96 for the GameCube version. GameTrailers in their review called it one of the greatest games ever created."
|
||||
}]
|
||||
result = inferencer.inference_from_dicts(dicts=QA_input, return_json=False)[0]
|
||||
assert np.isclose(result.prediction[0].confidence, 0.990427553653717)
|
||||
assert result.prediction[0].answer == "GameTrailers"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_inference_different_inputs()
|
||||
test_inference_objs()
|
||||
test_duplicate_answer_filtering()
|
||||
test_no_duplicate_answer_filtering()
|
||||
test_range_duplicate_answer_filtering()
|
||||
test_qa_confidence()
|
||||
Loading…
x
Reference in New Issue
Block a user