# Source: FlagEmbedding/FlagEmbedding_old/flag_reranker.py
# Snapshot: 2024-10-16 19:14:36 +08:00 (811 lines, 38 KiB, Python)

from typing import Union, List, Tuple, Any
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, is_torch_npu_available
from peft import PeftModel
import warnings
from torch.utils.data import Dataset
import os
# Explicitly opt in to HuggingFace `tokenizers` parallelism for the whole process.
# NOTE(review): setting the variable is presumably also meant to silence the
# library's fork warning when DataLoader workers start after tokenisation — confirm.
os.environ.update(TOKENIZERS_PARALLELISM='true')
class DatasetForReranker(Dataset):
    """Map-style dataset that tokenises (query, passage) pairs for an LLM reranker.

    ``dataset[i]`` must unpack into ``(query, passage)``. Each returned item holds
    ``input_ids`` laid out as::

        [bos] + ids(query_prefix + query)        # at most max_len * 3 // 4 tokens
        + sep + ids(passage_prefix + passage)    # passage at most max_len tokens
        # ^ the pair above is cut from the passage side to encode_max_length
        + sep + prompt

    where ``sep`` is the tokenised newline and ``prompt`` the tokenised
    instruction. ``attention_mask`` is all ones; no padding happens here.
    """

    def __init__(
            self,
            dataset,
            tokenizer_path: str,
            max_len: int = 512,
            query_prefix: str = 'A: ',
            passage_prefix: str = 'B: ',
            cache_dir: str = None,
            prompt: str = None
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, trust_remote_code=True, cache_dir=cache_dir
        )
        self.dataset = dataset
        self.max_len = max_len
        self.query_prefix, self.passage_prefix = query_prefix, passage_prefix
        self.total_len = len(self.dataset)

        instruction = prompt
        if instruction is None:
            instruction = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."

        def token_ids(text):
            # Plain python id lists: no special tokens, no tensor conversion.
            return self.tokenizer(text, return_tensors=None, add_special_tokens=False)['input_ids']

        self.prompt_inputs = token_ids(instruction)
        self.sep_inputs = token_ids("\n")
        # Length budget for the (bos + query, sep + passage) pair given to prepare_for_model.
        self.encode_max_length = self.max_len + len(self.sep_inputs) + len(self.prompt_inputs)

    def __len__(self):
        return self.total_len

    def __getitem__(self, item):
        query, passage = self.dataset[item]
        query_ids = self.tokenizer(self.query_prefix + query,
                                   return_tensors=None,
                                   add_special_tokens=False,
                                   max_length=self.max_len * 3 // 4,
                                   truncation=True)['input_ids']
        passage_ids = self.tokenizer(self.passage_prefix + passage,
                                     return_tensors=None,
                                     add_special_tokens=False,
                                     max_length=self.max_len,
                                     truncation=True)['input_ids']
        encoded = self.tokenizer.prepare_for_model(
            [self.tokenizer.bos_token_id] + query_ids,
            self.sep_inputs + passage_ids,
            truncation='only_second',
            max_length=self.encode_max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False,
        )
        full_ids = encoded['input_ids'] + self.sep_inputs + self.prompt_inputs
        encoded['input_ids'] = full_ids
        encoded['attention_mask'] = [1] * len(full_ids)
        encoded.pop('token_type_ids', None)
        # Keep any tokenizer-provided position ids in step with the extended sequence.
        if 'position_ids' in encoded:
            encoded['position_ids'] = list(range(len(full_ids)))
        return encoded
class collater():
    """Collate a list of tokenised feature dicts into a padded batch of tensors.

    If the features carry ``labels``, they are padded first with
    ``label_pad_token_id`` on ``tokenizer.padding_side``. The target length is
    the longest label, rounded up to ``pad_to_multiple_of``. This happens in
    place: the caller's feature dicts are modified. Everything is then padded
    by ``tokenizer.pad`` with the same ``pad_to_multiple_of``.
    """

    def __init__(self, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_to_multiple_of = 8
        self.label_pad_token_id = -100
        # tokenizer.pad warns that max_length is unused when padding=True; silence that.
        warnings.filterwarnings("ignore",
                                message="`max_length` is ignored when `padding`=`True` and there is no truncation strategy.")

    def __call__(self, data):
        """Pad ``data`` (a non-empty list of feature dicts) and return ``tokenizer.pad``'s result."""
        # tokenizer.pad does not pad labels but needs equal lengths to build tensors.
        if "labels" in data[0]:
            self._pad_labels(data)
        return self.tokenizer.pad(
            data,
            padding=True,
            max_length=self.max_len,
            # Use the same multiple as the labels so both end up the same length.
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors='pt',
        )

    def _pad_labels(self, features):
        """Pad every feature's ``labels`` in place to one common length."""
        max_label_length = max(len(feature["labels"]) for feature in features)
        if self.pad_to_multiple_of is not None:
            max_label_length = (
                (max_label_length + self.pad_to_multiple_of - 1)
                // self.pad_to_multiple_of
                * self.pad_to_multiple_of
            )
        padding_side = self.tokenizer.padding_side
        for feature in features:
            remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
            if isinstance(feature["labels"], list):
                feature["labels"] = (
                    feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
                )
            elif padding_side == "right":
                feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
            else:
                feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
class collater_for_lightweight(collater):
    """Collate function for the lightweight reranker.

    ``data`` is indexed as ``(features, query_lengths, prompt_lengths)``.
    ``features`` goes through exactly the same label and input padding as
    ``collater``, which this class inherits its constructor and padding from.
    The two length lists are passed through untouched.

    Returns:
        ``(padded_batch, query_lengths, prompt_lengths)``.
    """

    def __call__(self, data):
        features = data[0]
        query_lengths = data[1]
        prompt_lengths = data[2]
        return super().__call__(features), query_lengths, prompt_lengths
def last_logit_pool(logits: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Select, for each sequence, the logits at its last non-padding position.

    Args:
        logits: tensor indexed as ``[batch, position, vocab]``.
        attention_mask: ``[batch, position]`` mask, 1 for real tokens.

    Returns:
        ``[batch, vocab]`` tensor.
    """
    # If every row's final slot is a real token, the batch is left-padded (or
    # unpadded) and the last column is the answer for all rows.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return logits[:, -1, :]
    # Right padding: index each row at (number of real tokens - 1) in one gather.
    sequence_lengths = (attention_mask.sum(dim=1) - 1).to(logits.device)
    batch_indices = torch.arange(logits.shape[0], device=logits.device)
    return logits[batch_indices, sequence_lengths, :]
def last_logit_pool_layerwise(logits: Tensor,
                              attention_mask: Tensor) -> Tensor:
    """Pick each row's entry at its last non-padding position.

    ``logits`` is indexed by ``[batch, position, ...]``; ``attention_mask`` is
    ``[batch, position]`` with 1 for real tokens. Returns ``logits`` with the
    position axis removed.
    """
    rows_in_mask = attention_mask.shape[0]
    if attention_mask[:, -1].sum() == rows_in_mask:
        # Every row ends on a real token: left padding (or none), take the last column.
        return logits[:, -1]
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(logits.shape[0], device=logits.device)
    return logits[row_index, last_positions]
def last_logit_pool_lightweight(logits: Tensor,
                                attention_mask: Tensor) -> Tensor:
    """Pick each row's entry at its last non-padding position.

    ``logits`` is indexed by ``[batch, position, ...]``; ``attention_mask`` is
    ``[batch, position]`` with 1 for real tokens. Returns ``logits`` with the
    position axis removed.
    """
    # Every row ending on a real token means left padding (or none).
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return logits[:, -1]
    # Right padding: one vectorised gather instead of a per-row Python loop + stack.
    sequence_lengths = (attention_mask.sum(dim=1) - 1).to(logits.device)
    batch_indices = torch.arange(logits.shape[0], device=logits.device)
    return logits[batch_indices, sequence_lengths]
def sigmoid(x):
    """Numerically stable logistic function ``1 / (1 + exp(-x))``.

    Works on Python/NumPy scalars and arrays. It is written as
    ``exp(-log(1 + exp(-x)))`` using ``np.logaddexp`` so that very negative
    inputs cannot overflow ``exp`` (the naive form overflows for x < ~-709).
    """
    return np.exp(-np.logaddexp(0, -x))
class FlagReranker:
    """Cross-encoder reranker backed by ``AutoModelForSequenceClassification``.

    Each (query, passage) pair is tokenised as a text pair. Its score is the
    model's output logit: logits are flattened with ``view(-1)``, so the model is
    presumably expected to have a single label. TODO confirm for custom checkpoints.
    """

    def __init__(
            self,
            model_name_or_path: str = None,
            use_fp16: bool = False,
            cache_dir: str = None,
            device: Union[str, int] = None
    ) -> None:
        """Load tokenizer and model and move the model to a device.

        Args:
            model_name_or_path: model id or local path passed to ``from_pretrained``.
            use_fp16: cast the model to half precision. Forced off for
                ``device='cpu'`` or when auto-selection falls back to CPU.
            cache_dir: HuggingFace cache directory.
            device: a device string (e.g. ``'cuda:1'``, ``'cpu'``), or a CUDA
                device index, or None to auto-select CUDA > MPS > NPU > CPU.
                With None and several visible GPUs the model is wrapped in
                ``torch.nn.DataParallel``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)

        if device and isinstance(device, str):
            self.device = torch.device(device)
            if device == 'cpu':
                use_fp16 = False
        else:
            if torch.cuda.is_available():
                if device is not None:
                    self.device = torch.device(f"cuda:{device}")
                else:
                    self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            elif is_torch_npu_available():
                self.device = torch.device("npu")
            else:
                self.device = torch.device("cpu")
                use_fp16 = False
        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()

        if device is None:
            # Auto mode: spread each batch over every visible GPU.
            self.num_gpus = torch.cuda.device_count()
            if self.num_gpus > 1:
                print(f"----------using {self.num_gpus}*GPUs----------")
                self.model = torch.nn.DataParallel(self.model)
        else:
            self.num_gpus = 1

    @torch.no_grad()
    def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 256,
                      max_length: int = 512, normalize: bool = False) -> List[float]:
        """Score (query, passage) pairs.

        Args:
            sentence_pairs: a list of (query, passage) pairs, or one pair given
                as a 2-element list or tuple of strings.
            batch_size: pairs per forward pass per GPU. It is multiplied by
                ``num_gpus`` when that is positive.
            max_length: truncation length for each tokenised pair.
            normalize: map each score through a sigmoid.

        Returns:
            One score per pair, in input order; ``[]`` for empty input.
        """
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus

        # The annotation allows a bare tuple pair; treat it like the list form.
        if isinstance(sentence_pairs, tuple):
            sentence_pairs = list(sentence_pairs)
        assert isinstance(sentence_pairs, list)
        if len(sentence_pairs) == 0:
            return []
        if isinstance(sentence_pairs[0], str):
            # A single pair was given; wrap it as a one-element batch.
            sentence_pairs = [sentence_pairs]

        all_scores = []
        for start_index in tqdm(range(0, len(sentence_pairs), batch_size), desc="Compute Scores",
                                disable=len(sentence_pairs) < 128):
            sentences_batch = sentence_pairs[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
            ).to(self.device)

            scores = self.model(**inputs, return_dict=True).logits.view(-1, ).float()
            all_scores.extend(scores.cpu().numpy().tolist())

        if normalize:
            all_scores = [sigmoid(score) for score in all_scores]

        return all_scores
class FlagLLMReranker:
    """Reranker that scores (query, passage) pairs with a causal LM.

    Each pair is encoded as
    ``[bos] + ids('A: ' + query) + sep + ids('B: ' + passage) + sep + ids(prompt)``,
    where ``sep`` is the tokenised newline. The score is the logit of the first
    token of ``'Yes'`` at the last non-padding position.
    """

    def __init__(
            self,
            model_name_or_path: str = None,
            peft_path: str = None,
            use_fp16: bool = False,
            use_bf16: bool = False,
            cache_dir: str = None,
            device: Union[str, int] = None
    ) -> None:
        """Load tokenizer/model, optionally merge a PEFT adapter, and pick a device.

        Args:
            model_name_or_path: model id or local path passed to ``from_pretrained``.
            peft_path: optional PEFT adapter, merged into the base weights.
            use_fp16: cast to half precision when ``use_bf16`` is False. Forced
                off when auto-selection falls back to CPU.
            use_bf16: load the weights in bfloat16.
            cache_dir: HuggingFace cache directory.
            device: a device string (e.g. ``'cuda:1'``), or a CUDA device index
                (made current via ``torch.cuda.set_device``), or None (index 0).
                For non-string devices without CUDA: MPS, then NPU, then CPU.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                       cache_dir=cache_dir,
                                                       trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                          cache_dir=cache_dir,
                                                          trust_remote_code=True,
                                                          torch_dtype=torch.bfloat16 if use_bf16 else torch.float32)
        if peft_path:
            # Fold the adapter into the base weights so inference runs a plain model.
            self.model = PeftModel.from_pretrained(self.model, peft_path)
            self.model = self.model.merge_and_unload()
        # Kept so the DataLoader path can build its own tokenizer in DatasetForReranker.
        self.model_name_or_path = model_name_or_path
        self.cache_dir = cache_dir

        if device and isinstance(device, str):
            self.device = torch.device(device)
        else:
            device = 0 if device is None else device
            if torch.cuda.is_available():
                torch.cuda.set_device(device)
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            elif is_torch_npu_available():
                self.device = torch.device("npu")
            else:
                self.device = torch.device("cpu")
                use_fp16 = False
        if use_fp16 and use_bf16 is False:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()

        # Id of the first sub-token of 'Yes'; its logit is the relevance score.
        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]

    @torch.no_grad()
    def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
                      max_length: int = 512, prompt: str = None, normalize: bool = False,
                      use_dataloader: bool = False, num_workers: int = None) -> List[float]:
        """Score (query, passage) pairs.

        Args:
            sentence_pairs: a list of (query, passage) pairs, or one pair given
                as a 2-element list or tuple of strings.
            batch_size: pairs per forward pass.
            max_length: token budget used when truncating query and passage.
            prompt: instruction placed after the passage. A default Yes/No
                instruction is used when None.
            normalize: map each score through a sigmoid.
            use_dataloader: tokenise through a ``DataLoader`` over
                ``DatasetForReranker`` instead of in this process.
            num_workers: DataLoader workers; defaults to ``min(batch_size, 16)``.

        Returns:
            One score per pair, in input order; ``[]`` for empty input.
        """
        # The annotation allows a bare tuple pair; treat it like the list form.
        if isinstance(sentence_pairs, tuple):
            sentence_pairs = list(sentence_pairs)
        assert isinstance(sentence_pairs, list)
        if len(sentence_pairs) == 0:
            return []
        if isinstance(sentence_pairs[0], str):
            sentence_pairs = [sentence_pairs]

        # Longest pairs first keeps per-batch padding small; order is restored below.
        length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs])
        sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx]

        all_scores = []
        if use_dataloader:
            if num_workers is None:
                num_workers = min(batch_size, 16)
            dataset = DatasetForReranker(sentences_sorted,
                                         self.model_name_or_path,
                                         max_length,
                                         cache_dir=self.cache_dir,
                                         prompt=prompt)
            dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
                                    num_workers=num_workers,
                                    collate_fn=collater(self.tokenizer, max_length))
            for inputs in tqdm(dataloader):
                inputs = inputs.to(self.device)
                all_scores.extend(self._yes_scores(inputs))
        else:
            # The collater keeps no per-batch state, so one instance serves all batches.
            collater_instance = collater(self.tokenizer, max_length)
            sep_inputs, prompt_inputs = self._encode_sep_and_prompt(prompt)
            encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs)
            for batch_start in trange(0, len(sentences_sorted), batch_size):
                batch_sentences = sentences_sorted[batch_start:batch_start + batch_size]
                features = self._encode_pairs(batch_sentences, max_length, encode_max_length,
                                              sep_inputs, prompt_inputs)
                batch_inputs = collater_instance(features)
                batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()}
                all_scores.extend(self._yes_scores(batch_inputs))

        # Undo the length sort.
        all_scores = [all_scores[idx] for idx in np.argsort(length_sorted_idx)]
        if normalize:
            all_scores = [sigmoid(score) for score in all_scores]
        return all_scores

    def _yes_scores(self, inputs):
        """Run the model on a padded batch and return the 'Yes' logit at each last real token."""
        outputs = self.model(**inputs, output_hidden_states=True)
        scores = last_logit_pool(outputs.logits, inputs['attention_mask'])
        return scores[:, self.yes_loc].cpu().float().tolist()

    def _encode_sep_and_prompt(self, prompt):
        """Return ``(sep_ids, prompt_ids)``; ``prompt=None`` selects the default instruction."""
        if prompt is None:
            prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
        prompt_inputs = self.tokenizer(prompt,
                                       return_tensors=None,
                                       add_special_tokens=False)['input_ids']
        sep_inputs = self.tokenizer("\n",
                                    return_tensors=None,
                                    add_special_tokens=False)['input_ids']
        return sep_inputs, prompt_inputs

    def _encode_pairs(self, pairs, max_length, encode_max_length, sep_inputs, prompt_inputs):
        """Tokenise a batch of pairs into unpadded ``input_ids``/``attention_mask`` dicts.

        The query is capped at ``max_length * 3 // 4`` tokens and the passage at
        ``max_length``. The (bos + query, sep + passage) pair is then cut from
        the passage side to ``encode_max_length``, and ``sep + prompt`` is appended.
        """
        queries = [f'A: {q}' for q, _ in pairs]
        passages = [f'B: {p}' for _, p in pairs]
        queries_inputs = self.tokenizer(queries,
                                        return_tensors=None,
                                        add_special_tokens=False,
                                        max_length=max_length * 3 // 4,
                                        truncation=True)
        passages_inputs = self.tokenizer(passages,
                                         return_tensors=None,
                                         add_special_tokens=False,
                                         max_length=max_length,
                                         truncation=True)
        features = []
        for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']):
            item = self.tokenizer.prepare_for_model(
                [self.tokenizer.bos_token_id] + query_inputs,
                sep_inputs + passage_inputs,
                truncation='only_second',
                max_length=encode_max_length,
                padding=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False
            )
            input_ids = item['input_ids'] + sep_inputs + prompt_inputs
            features.append({'input_ids': input_ids, 'attention_mask': [1] * len(input_ids)})
        return features

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Help function to get the length for the input text. Text can be either
        a list of ints (which means a single text as input), or a tuple of list of ints
        (representing several text inputs to the model).
        """
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, '__len__'):  # Object has no len() method
            return 1
        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
            return len(text)
        else:
            return sum([len(t) for t in text])  # Sum of length of individual strings
class LayerWiseFlagLLMReranker:
    """LLM reranker whose custom (``trust_remote_code``) model yields scores per layer.

    Pairs are encoded like ``FlagLLMReranker``:
    ``[bos] + ids('A: ' + query) + sep + ids('B: ' + passage) + sep + ids(prompt)``.
    The model is called with ``cutoff_layers``. Its ``logits`` output is iterated
    as a sequence of per-layer score tensors, and each one is pooled at the last
    non-padding token.
    """

    def __init__(
            self,
            model_name_or_path: str = None,
            peft_path: str = None,
            use_fp16: bool = False,
            use_bf16: bool = False,
            cache_dir: str = None,
            device: Union[str, int] = None
    ) -> None:
        """Load tokenizer/model, optionally merge a PEFT adapter, and pick a device.

        Args:
            model_name_or_path: model id or local path passed to ``from_pretrained``.
            peft_path: optional PEFT adapter, merged into the base weights.
            use_fp16: cast to half precision. It is forced on when both flags
                are False, and forced off when auto-selection falls back to CPU.
            use_bf16: load the weights in bfloat16.
            cache_dir: HuggingFace cache directory.
            device: a device string (``'cpu'`` is replaced by ``'cuda'`` with a
                warning), or a CUDA device index, or None (index 0).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                       cache_dir=cache_dir,
                                                       trust_remote_code=True)

        if use_bf16 is False and use_fp16 is False:
            warnings.warn("Due to model constraints, `use_bf16` and `use_fp16` cannot both be `False`. Here, `use_fp16` is set to `True` by default.", UserWarning)
            use_fp16 = True

        # NOTE(review): local_files_only=True requires the weights to be present
        # locally already; FlagLLMReranker does not pass it — confirm intent.
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                          cache_dir=cache_dir,
                                                          trust_remote_code=True,
                                                          local_files_only=True,
                                                          torch_dtype=torch.bfloat16 if use_bf16 else torch.float32)
        if peft_path:
            self.model = PeftModel.from_pretrained(self.model, peft_path)
            self.model = self.model.merge_and_unload()
        # Kept so the DataLoader path can build its own tokenizer in DatasetForReranker.
        self.model_name_or_path = model_name_or_path
        self.cache_dir = cache_dir

        if device and isinstance(device, str):
            if device == 'cpu':
                warnings.warn('The LLM-based layer-wise reranker does not support CPU; it has been set to CUDA.')
                device = 'cuda'
            self.device = torch.device(device)
        else:
            device = 0 if device is None else device
            if torch.cuda.is_available():
                torch.cuda.set_device(device)
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            elif is_torch_npu_available():
                self.device = torch.device("npu")
            else:
                self.device = torch.device("cpu")
                use_fp16 = False
        if use_fp16 and use_bf16 is False:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()

        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]

    @torch.no_grad()
    def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
                      max_length: int = 512, cutoff_layers: List[int] = None, prompt: str = None,
                      normalize: bool = False, use_dataloader: bool = False,
                      num_workers: int = None) -> Union[float, List[float], List[List[float]]]:
        """Score (query, passage) pairs at each layer returned by the model.

        Args:
            sentence_pairs: a list of (query, passage) pairs, or one pair given
                as a 2-element list or tuple of strings.
            batch_size: pairs per forward pass.
            max_length: token budget used when truncating query and passage.
            cutoff_layers: forwarded to the model to choose which layers score.
            prompt: instruction placed after the passage; a default is used when None.
            normalize: map each score through a sigmoid.
            use_dataloader: tokenise through a ``DataLoader`` over ``DatasetForReranker``.
            num_workers: DataLoader workers; defaults to ``min(batch_size, 16)``.

        Returns:
            A list with one inner list per returned layer. Each inner list holds
            one score per pair, in input order. Empty input gives ``[]``.
        """
        if isinstance(sentence_pairs, tuple):
            sentence_pairs = list(sentence_pairs)
        assert isinstance(sentence_pairs, list)
        if len(sentence_pairs) == 0:
            return []
        if isinstance(sentence_pairs[0], str):
            sentence_pairs = [sentence_pairs]

        # Longest pairs first keeps per-batch padding small; order is restored below.
        length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs])
        sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx]

        all_scores = []  # one accumulator list per returned layer
        if use_dataloader:
            if num_workers is None:
                num_workers = min(batch_size, 16)
            dataset = DatasetForReranker(sentences_sorted,
                                         self.model_name_or_path,
                                         max_length,
                                         cache_dir=self.cache_dir,
                                         prompt=prompt)
            dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
                                    num_workers=num_workers,
                                    collate_fn=collater(self.tokenizer, max_length))
            for inputs in tqdm(dataloader):
                inputs = inputs.to(self.device)
                self._append_layer_scores(all_scores, self._layer_scores(inputs, cutoff_layers))
        else:
            # The collater keeps no per-batch state, so one instance serves all batches.
            collater_instance = collater(self.tokenizer, max_length)
            sep_inputs, prompt_inputs = self._encode_sep_and_prompt(prompt)
            encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs)
            for batch_start in trange(0, len(sentences_sorted), batch_size):
                batch_sentences = sentences_sorted[batch_start:batch_start + batch_size]
                features = self._encode_pairs(batch_sentences, max_length, encode_max_length,
                                              sep_inputs, prompt_inputs)
                batch_inputs = collater_instance(features)
                batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()}
                self._append_layer_scores(all_scores, self._layer_scores(batch_inputs, cutoff_layers))

        # Undo the length sort once per layer, reusing a single inverse permutation.
        inverse_order = np.argsort(length_sorted_idx)
        for i, layer_scores in enumerate(all_scores):
            layer_scores = [layer_scores[idx] for idx in inverse_order]
            if normalize:
                layer_scores = [sigmoid(score) for score in layer_scores]
            all_scores[i] = layer_scores
        return all_scores

    def _layer_scores(self, inputs, cutoff_layers):
        """Run the model once and return, per returned layer, the scores at each last real token."""
        outputs = self.model(**inputs, output_hidden_states=True, cutoff_layers=cutoff_layers)
        return [last_logit_pool_layerwise(logits, inputs['attention_mask']).contiguous().cpu().float().tolist()
                for logits in outputs.logits]

    @staticmethod
    def _append_layer_scores(all_scores, batch_scores):
        """Extend the per-layer accumulators with one batch; create them on the first batch."""
        if len(all_scores) == 0:
            all_scores.extend([] for _ in batch_scores)
        for i, scores in enumerate(batch_scores):
            all_scores[i].extend(scores)

    def _encode_sep_and_prompt(self, prompt):
        """Return ``(sep_ids, prompt_ids)``; ``prompt=None`` selects the default instruction."""
        if prompt is None:
            prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
        prompt_inputs = self.tokenizer(prompt,
                                       return_tensors=None,
                                       add_special_tokens=False)['input_ids']
        sep_inputs = self.tokenizer("\n",
                                    return_tensors=None,
                                    add_special_tokens=False)['input_ids']
        return sep_inputs, prompt_inputs

    def _encode_pairs(self, pairs, max_length, encode_max_length, sep_inputs, prompt_inputs):
        """Tokenise a batch of pairs into unpadded ``input_ids``/``attention_mask`` dicts.

        The query is capped at ``max_length * 3 // 4`` tokens and the passage at
        ``max_length``. The (bos + query, sep + passage) pair is then cut from
        the passage side to ``encode_max_length``, and ``sep + prompt`` is appended.
        """
        queries = [f'A: {q}' for q, _ in pairs]
        passages = [f'B: {p}' for _, p in pairs]
        queries_inputs = self.tokenizer(queries,
                                        return_tensors=None,
                                        add_special_tokens=False,
                                        max_length=max_length * 3 // 4,
                                        truncation=True)
        passages_inputs = self.tokenizer(passages,
                                         return_tensors=None,
                                         add_special_tokens=False,
                                         max_length=max_length,
                                         truncation=True)
        features = []
        for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']):
            item = self.tokenizer.prepare_for_model(
                [self.tokenizer.bos_token_id] + query_inputs,
                sep_inputs + passage_inputs,
                truncation='only_second',
                max_length=encode_max_length,
                padding=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False
            )
            input_ids = item['input_ids'] + sep_inputs + prompt_inputs
            features.append({'input_ids': input_ids, 'attention_mask': [1] * len(input_ids)})
        return features

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Help function to get the length for the input text. Text can be either
        a list of ints (which means a single text as input), or a tuple of list of ints
        (representing several text inputs to the model).
        """
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, '__len__'):  # Object has no len() method
            return 1
        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
            return len(text)
        else:
            return sum([len(t) for t in text])  # Sum of length of individual strings
class LightWeightFlagLLMReranker:
    """LLM reranker driven by a custom (``trust_remote_code``) lightweight model.

    Pairs are encoded as
    ``[bos] + ids('A: ' + query) + sep + ids('B: ' + passage) + sep + ids(prompt)``.
    The per-pair query and prompt lengths are passed to the model together with
    ``compress_layer``, ``compress_ratio`` and ``cutoff_layers``; what those do
    is defined by the model's remote code. The model's ``logits`` and
    ``attention_masks`` outputs are iterated per layer.
    """

    def __init__(
            self,
            model_name_or_path: str = None,
            peft_path: str = None,
            use_fp16: bool = False,
            use_bf16: bool = False,
            cache_dir: str = None,
            device: Union[str, int] = None
    ) -> None:
        """Load tokenizer/model, optionally merge a PEFT adapter, and pick a device.

        Args:
            model_name_or_path: model id or local path passed to ``from_pretrained``.
            peft_path: optional PEFT adapter, merged into the base weights.
            use_fp16: cast to half precision. It is forced on when both flags
                are False, and forced off when auto-selection falls back to CPU.
            use_bf16: load the weights in bfloat16.
            cache_dir: HuggingFace cache directory.
            device: a device string (``'cpu'`` is replaced by ``'cuda'`` with a
                warning), or a CUDA device index, or None (index 0).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                       cache_dir=cache_dir,
                                                       trust_remote_code=True)
        # NOTE(review): presumably the model relies on right padding together with
        # query_lengths/prompt_lengths — confirm against the model's remote code.
        self.tokenizer.padding_side = 'right'

        if use_bf16 is False and use_fp16 is False:
            warnings.warn("Due to model constraints, `use_bf16` and `use_fp16` cannot both be `False`. Here, `use_fp16` is set to `True` by default.", UserWarning)
            use_fp16 = True

        # NOTE(review): local_files_only=True requires the weights to be present
        # locally already; FlagLLMReranker does not pass it — confirm intent.
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                          cache_dir=cache_dir,
                                                          trust_remote_code=True,
                                                          local_files_only=True,
                                                          torch_dtype=torch.bfloat16 if use_bf16 else torch.float32)
        if peft_path:
            self.model = PeftModel.from_pretrained(self.model, peft_path)
            self.model = self.model.merge_and_unload()
        self.model_name_or_path = model_name_or_path
        self.cache_dir = cache_dir

        if device and isinstance(device, str):
            if device == 'cpu':
                warnings.warn('The LLM-based layer-wise reranker does not support CPU; it has been set to CUDA.')
                device = 'cuda'
            self.device = torch.device(device)
        else:
            device = 0 if device is None else device
            if torch.cuda.is_available():
                torch.cuda.set_device(device)
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            elif is_torch_npu_available():
                self.device = torch.device("npu")
            else:
                self.device = torch.device("cpu")
                use_fp16 = False
        if use_fp16 and use_bf16 is False:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()

        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]

    @torch.no_grad()
    def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
                      max_length: int = 512,
                      cutoff_layers: List[int] = None, compress_layer: List[int] = None, compress_ratio: int = 1,
                      prompt: str = None, normalize: bool = False) -> Union[float, List[float], List[List[float]]]:
        """Score (query, passage) pairs at each layer returned by the model.

        Args:
            sentence_pairs: a list of (query, passage) pairs, or one pair given
                as a 2-element list or tuple of strings.
            batch_size: pairs per forward pass.
            max_length: token budget used when truncating query and passage.
            cutoff_layers: forwarded to the model.
            compress_layer: forwarded to the model. None means ``[8]``; a new
                list is built on every call.
            compress_ratio: forwarded to the model.
            prompt: instruction placed after the passage; a default is used when None.
            normalize: map each score through a sigmoid.

        Returns:
            A list with one inner list per returned layer. Each inner list holds
            one score per pair, in input order. Empty input gives ``[]``.
        """
        # Fresh list per call: a mutable default would be shared across calls, and
        # it is handed to model code this file does not control.
        if compress_layer is None:
            compress_layer = [8]

        if isinstance(sentence_pairs, tuple):
            sentence_pairs = list(sentence_pairs)
        assert isinstance(sentence_pairs, list)
        if len(sentence_pairs) == 0:
            return []
        if isinstance(sentence_pairs[0], str):
            sentence_pairs = [sentence_pairs]

        # Longest pairs first keeps per-batch padding small; order is restored below.
        length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs])
        sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx]

        if prompt is None:
            prompt = "Predict whether passage B contains an answer to query A."
        prompt_inputs = self.tokenizer(prompt,
                                       return_tensors=None,
                                       add_special_tokens=False)['input_ids']
        sep_inputs = self.tokenizer("\n",
                                    return_tensors=None,
                                    add_special_tokens=False)['input_ids']
        encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs)
        # The trailing sep + prompt has the same length for every pair.
        prompt_length = len(sep_inputs) + len(prompt_inputs)
        collater_instance = collater_for_lightweight(self.tokenizer, max_length)

        all_scores = []  # one accumulator list per returned layer
        for batch_start in trange(0, len(sentences_sorted), batch_size):
            batch_sentences = sentences_sorted[batch_start:batch_start + batch_size]
            features, query_lengths = self._encode_pairs(batch_sentences, max_length, encode_max_length,
                                                         sep_inputs, prompt_inputs)
            prompt_lengths = [prompt_length] * len(features)

            batch_inputs = collater_instance([features, query_lengths, prompt_lengths])[0]
            batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()}

            outputs = self.model(**batch_inputs,
                                 output_hidden_states=True,
                                 compress_layer=compress_layer,
                                 compress_ratio=compress_ratio,
                                 query_lengths=query_lengths,
                                 prompt_lengths=prompt_lengths,
                                 cutoff_layers=cutoff_layers)
            # Pool each layer with that layer's own attention mask.
            batch_scores = [
                last_logit_pool_lightweight(logits, outputs.attention_masks[i]).cpu().float().tolist()
                for i, logits in enumerate(outputs.logits)
            ]
            if len(all_scores) == 0:
                all_scores.extend([] for _ in batch_scores)
            for i, scores in enumerate(batch_scores):
                all_scores[i].extend(scores)

        # Undo the length sort once per layer, reusing a single inverse permutation.
        inverse_order = np.argsort(length_sorted_idx)
        for i, layer_scores in enumerate(all_scores):
            layer_scores = [layer_scores[idx] for idx in inverse_order]
            if normalize:
                layer_scores = [sigmoid(score) for score in layer_scores]
            all_scores[i] = layer_scores
        return all_scores

    def _encode_pairs(self, pairs, max_length, encode_max_length, sep_inputs, prompt_inputs):
        """Tokenise a batch of pairs.

        Returns:
            ``(features, query_lengths)``. ``features`` are unpadded
            ``input_ids``/``attention_mask`` dicts. ``query_lengths[i]`` counts
            bos + query + sep tokens. The query is never cut here, because
            truncation is ``'only_second'``.
        """
        queries = [f'A: {q}' for q, _ in pairs]
        passages = [f'B: {p}' for _, p in pairs]
        queries_inputs = self.tokenizer(queries,
                                        return_tensors=None,
                                        add_special_tokens=False,
                                        max_length=max_length * 3 // 4,
                                        truncation=True)
        passages_inputs = self.tokenizer(passages,
                                         return_tensors=None,
                                         add_special_tokens=False,
                                         max_length=max_length,
                                         truncation=True)
        features = []
        query_lengths = []
        for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']):
            item = self.tokenizer.prepare_for_model(
                [self.tokenizer.bos_token_id] + query_inputs,
                sep_inputs + passage_inputs,
                truncation='only_second',
                max_length=encode_max_length,
                padding=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False
            )
            input_ids = item['input_ids'] + sep_inputs + prompt_inputs
            features.append({'input_ids': input_ids, 'attention_mask': [1] * len(input_ids)})
            query_lengths.append(1 + len(query_inputs) + len(sep_inputs))
        return features, query_lengths

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Help function to get the length for the input text. Text can be either
        a list of ints (which means a single text as input), or a tuple of list of ints
        (representing several text inputs to the model).
        """
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, '__len__'):  # Object has no len() method
            return 1
        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
            return len(text)
        else:
            return sum([len(t) for t in text])  # Sum of length of individual strings