mirror of https://github.com/FlagOpen/FlagEmbedding.git

commit f653727e83 ("reranker init")
parent 00a42ccd4f
@@ -29,6 +29,10 @@ class AbsReranker(ABC):
         passage_instruction_for_rerank: str = None,
         passage_instruction_format: str = "{}{}", # specify the format of passage_instruction_for_rerank
         devices: Union[str, int, List[str], List[int]] = None,
+        batch_size: int = 128,
+        query_max_length: int = None,
+        max_length: int = 512,
+        normalize: bool = False,
         **kwargs: Any,
     ):
         self.model_name_or_path = model_name_or_path
@@ -38,6 +42,13 @@ class AbsReranker(ABC):
         self.passage_instruction_for_rerank = passage_instruction_for_rerank
         self.passage_instruction_format = passage_instruction_format
         self.target_devices = self.get_target_devices(devices)
+        self.batch_size = batch_size
+        self.query_max_length = query_max_length
+        self.max_length = max_length
+        self.normalize = normalize
+
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
 
         self.kwargs = kwargs
 
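The two hunks above move batching/length/normalization defaults into the constructor; per-call arguments then fall back to them (the fallback pattern appears in the later hunks). A minimal usage sketch, with model name and values illustrative rather than taken from the diff:

from FlagEmbedding import FlagReranker

# defaults are stored once at construction...
reranker = FlagReranker(
    'BAAI/bge-reranker-large',
    batch_size=128,        # new constructor-level default
    query_max_length=256,  # new: separate token budget for the query side
    max_length=512,
    normalize=False,
)
# ...so compute_score can be called without repeating them; per-call
# values, when given, override the stored defaults.
scores = reranker.compute_score([('what is panda?', 'The giant panda is a bear.')])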
@@ -73,8 +84,8 @@ class AbsReranker(ABC):
         if isinstance(sentence_pairs, str):
             sentence_pairs = [sentence_pairs]
 
-        if self.query_instruction_format is not None:
-            if self.passage_instruction_format is None:
+        if self.query_instruction_for_rerank is not None:
+            if self.passage_instruction_for_rerank is None:
                 return [
                     [
                         self.get_detailed_instruct(self.query_instruction_format, self.query_instruction_for_rerank, sentence_pair[0]),
@@ -89,7 +100,7 @@ class AbsReranker(ABC):
                     ] for sentence_pair in sentence_pairs
                 ]
         else:
-            if self.passage_instruction_format is None:
+            if self.passage_instruction_for_rerank is None:
                 return [
                     [
                         sentence_pair[0],
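The branches above now switch on whether an instruction string was supplied (rather than on the format template, which always has a default); get_detailed_instruct then applies the template. A minimal sketch of that formatting, assuming the default "{}{}" template simply concatenates instruction and text:

def get_detailed_instruct(instruction_format: str, instruction: str, sentence: str) -> str:
    # "{}{}" prepends the instruction; a custom template such as
    # "Instruct: {}\nQuery: {}" changes the layout without touching callers
    return instruction_format.format(instruction, sentence)

get_detailed_instruct("{}{}", "query: ", "what is panda?")  # -> 'query: what is panda?'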
@@ -130,6 +141,7 @@ class AbsReranker(ABC):
         self,
         sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
         batch_size: int = 256,
+        query_max_length: int = None,
         max_length: int = 512,
         normalize: bool = False,
         device: str = None,
@@ -1,11 +1,14 @@
 import mteb
 
 from transformers import HfArgumentParser
 
 from FlagEmbedding import FlagAutoModel, FlagAutoReranker
 from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
 
-from utils.arguments import MSMARCOEvalArgs
-from utils.data_loader import MSMARCODataLoader
+from utils.arguments import MTEBEvalArgs
+from utils.prompts import get_task_def_by_task_name_and_type, tasks_desc
 
 
 def get_models(model_args: AbsModelArgs):
@@ -41,4 +44,37 @@ def get_models(model_args: AbsModelArgs):
             compress_layers=model_args.compress_layers,
             compress_ratio=model_args.compress_ratio,
         )
     return retriever, reranker
+
+
+def main():
+    parser = HfArgumentParser([AbsModelArgs, BEIREvalArgs])
+    model_args, eval_args = parser.parse_args_into_dataclasses()
+    model_args: AbsModelArgs
+    eval_args: BEIREvalArgs
+
+    retriever, reranker = get_models(model_args)
+
+    task_types = eval_args.task_types
+    tasks = eval_args.tasks
+    languages = eval_args.languages
+    tasks = mteb.get_tasks(
+        languages=languages,
+        tasks=tasks,
+        task_types=task_types
+    )
+    evaluation = mteb.MTEB(tasks=tasks)
+    results = evaluation.run(retriever, output_folder=f"results/{str(retriever)}")
+
+    # all_pairs = []
+    # for task_type in eval_args.task_types:
+    #     if task_type in tasks_desc.keys():
+    #         for task_name in tasks_desc[task_type]:
+    #             all_pairs.append((task_type, task_name))
+    # for task_type in tasks_desc.keys():
+    #     for v in tasks_desc[task_type]:
+    #         if v in eval_args.task_types:
+    #             all_pairs.append((task_type, v))
+    # all_pairs = list(set(all_pairs))
+
+
+if __name__ == "__main__":
+    main()
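A condensed sketch of the MTEB flow the new main() drives, for reference (the language/task-type values here are examples, not from the diff):

import mteb

# retriever comes from get_models(model_args), as above; get_tasks filters
# the task registry by language / name / type, and MTEB(...).run evaluates
# the model and writes JSON results under output_folder
tasks = mteb.get_tasks(languages=["eng"], task_types=["Retrieval"])
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(retriever, output_folder=f"results/{str(retriever)}")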
@@ -38,22 +38,40 @@ parser.add_argument('--batch_size', default=32, type=int)
 parser.add_argument('--examples-dir', default='/share/chaofan/code/embedder/evaluate_for_icl/examples', type=str)
 parser.add_argument('--eight-special-token', default=False, type=bool)
 parser.add_argument('--passage-prompt', default=False, type=bool)
 parser.add_argument('--pool-type', default='last', type=bool)
 
 args = parser.parse_args()
 base_name: str = args.model_name_or_path.split('/')[-1]
-if args.eight_special_token is True:
-    args.pool_type = 'last_eight'
-else:
-    args.pool_type = 'last'
+# if args.eight_special_token is True:
+#     args.pool_type = 'last_eight'
+# else:
+#     args.pool_type = 'last'
 
 logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
-assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last'
+# assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last'
 os.makedirs(args.output_dir, exist_ok=True)
 ALL_NUM = args.all_num
 
+def find_last_index(lst, element):
+    if not isinstance(element, list):
+        try:
+            reversed_list = lst[::-1]
+            index_in_reversed = reversed_list.index(element)
+            last_index = len(lst) - 1 - index_in_reversed
+            return last_index
+        except ValueError:
+            return -1
+    else:
+        last_index = -1
+        for i in range(len(lst) - len(element)):
+            if lst[i:i + len(element)] == element:
+                last_index = i
+                break
+        return last_index
+
+print(args.special_token)
 class DenseEncoder(torch.nn.Module):
-    def __init__(self, device, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__()
         self.encoder = AutoModel.from_pretrained(args.model_name_or_path, use_cache=False)
         if args.peft_name_or_path is not None:
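What the new find_last_index helper returns, for reference (note the list branch breaks at the first window match, and its range bound of len(lst) - len(element) never tests a window flush with the end of the list):

find_last_index([5, 9, 5, 7], 5)          # 2: last occurrence of a single id
find_last_index([5, 7, 9, 5, 7], [5, 7])  # 0: first (not last) match for a sequence
find_last_index([5, 9], 7)                # -1 when the element is absent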
@@ -74,13 +92,15 @@ class DenseEncoder(torch.nn.Module):
         self.prompt = None
         self.prefix = ''
         self.suffix = ''
-        self.gpu_count = torch.cuda.device_count()
 
         self.encoder.half()
+        self.gpu_count = torch.cuda.device_count()
 
         self.encoder.eval()
-        # self.encoder.cuda()
-        self.device = device
-        self.encoder = self.encoder.to(device)
+        self.encoder.cuda()
+
+        if self.gpu_count > 1:
+            self.encoder = torch.nn.DataParallel(self.encoder)
 
         self.eight_special_token = args.eight_special_token
         if args.eight_special_token:
@@ -92,9 +112,26 @@ class DenseEncoder(torch.nn.Module):
 
         self.batch_size = args.batch_size
 
+        if args.special_token:
+            self.index_type = 0
+            self.index_start_locs = self.tokenizer('<query>', add_special_tokens=False)['input_ids'][0]
+            self.index_end_locs = self.tokenizer('<response>', add_special_tokens=False)['input_ids'][0]
+        else:
+            self.index_type = 1
+            self.index_start_locs = self.tokenizer('\nQuery:', add_special_tokens=False)['input_ids'][1:]
+            self.index_end_locs = self.tokenizer('\nResponse:', add_special_tokens=False)['input_ids'][1:]
+
+        # if self.gpu_count > 1:
+        #     self.encoder = torch.nn.DataParallel(self.encoder)
+
+    def get_loc(self, sentence):
+        sentence = list(sentence)
+        if isinstance(self.index_start_locs, int):
+            return find_last_index(sentence, self.index_start_locs) + 1, find_last_index(sentence, self.index_end_locs)
+        else:
+            return find_last_index(sentence, self.index_start_locs) + len(self.index_start_locs), find_last_index(
+                sentence, self.index_end_locs)
+
     @torch.no_grad()
     def encode(self, sentences, **kwargs) -> np.ndarray:
         """ Returns a list of embeddings for the given sentences.
@@ -115,13 +152,22 @@ class DenseEncoder(torch.nn.Module):
 
             batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts, special_tokens=self.special_tokens)
-            # if self.device == 0:
-            #     print(self.tokenizer.decode(batch_dict['input_ids'][0]))
-            batch_dict = batch_dict.to(self.device)
-            # batch_dict = move_to_cuda(batch_dict)
+            # print(self.tokenizer.decode(batch_dict['input_ids'][0]))
+            # batch_dict = batch_dict.to(self.device)
+            batch_dict = move_to_cuda(batch_dict)
 
             with torch.cuda.amp.autocast():
                 outputs: BaseModelOutput = self.encoder(**batch_dict)
-                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if args.pool_type == 'mean':
+                    seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
+                    embeds = torch.stack(
+                        [
+                            outputs.last_hidden_state[i, start: end, :].mean(dim=0)
+                            for i, (start, end) in enumerate(seq_locs)
+                        ]
+                    )
+                else:
+                    embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                 if self.l2_normalize:
                     embeds = F.normalize(embeds, p=2, dim=-1)
                 encoded_embeds.append(embeds.cpu().numpy())
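The new 'mean' branch averages hidden states only over the token span located by get_loc (between the last query marker and the last response marker) instead of pooling over the full attention mask. The same span pooling in isolation, with illustrative tensor shapes:

import torch

def span_mean_pool(last_hidden_state: torch.Tensor, seq_locs):
    # last_hidden_state: [batch, seq_len, dim]; seq_locs holds one
    # (start, end) token-offset pair per example, as get_loc returns
    return torch.stack([
        last_hidden_state[i, start:end, :].mean(dim=0)
        for i, (start, end) in enumerate(seq_locs)
    ])

hidden = torch.randn(2, 6, 4)
span_mean_pool(hidden, [(1, 4), (0, 6)]).shape  # torch.Size([2, 4])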
@@ -151,7 +197,16 @@ class DenseEncoder(torch.nn.Module):
 
             with torch.cuda.amp.autocast():
                 outputs: BaseModelOutput = self.encoder(**batch_dict)
-                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if args.pool_type == 'mean':
+                    seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
+                    embeds = torch.stack(
+                        [
+                            outputs.last_hidden_state[i, start: end, :].mean(dim=0)
+                            for i, (start, end) in enumerate(seq_locs)
+                        ]
+                    )
+                else:
+                    embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                 if self.l2_normalize:
                     embeds = F.normalize(embeds, p=2, dim=-1)
                 encoded_embeds.append(embeds.cpu().numpy())
@@ -184,7 +239,16 @@ class DenseEncoder(torch.nn.Module):
 
             with torch.cuda.amp.autocast():
                 outputs: BaseModelOutput = self.encoder(**batch_dict)
-                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if args.pool_type == 'mean':
+                    seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
+                    embeds = torch.stack(
+                        [
+                            outputs.last_hidden_state[i, start: end, :].mean(dim=0)
+                            for i, (start, end) in enumerate(seq_locs)
+                        ]
+                    )
+                else:
+                    embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                 if self.l2_normalize:
                     embeds = F.normalize(embeds, p=2, dim=-1)
                 encoded_embeds.append(embeds.cpu().numpy())
@@ -201,44 +265,9 @@ class DenseEncoder(torch.nn.Module):
         self.suffix = suffix
 
 
-def main(device, all_pairs):
-    torch.cuda.set_device(device)
-
-    model = DenseEncoder(device)
-
-    os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
-
-    ArxivClusteringP2P_FLAG = False
-    tmp = None
-    for i, p in enumerate(all_pairs):
-        if p[1] == 'ArxivClusteringP2P':
-            ArxivClusteringP2P_FLAG = True
-            tmp = p
-            break
-    if ArxivClusteringP2P_FLAG:
-        all_pairs.remove(tmp)
-
-    os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
-    if ArxivClusteringP2P_FLAG is False:
-        length = len(all_pairs)
-        start = device * length // ALL_NUM
-        if device == ALL_NUM - 1:
-            end = length
-        else:
-            end = (device + 1) * length // ALL_NUM
-        all_pairs = all_pairs[start: end]
-    else:
-        if device == ALL_NUM - 1:
-            all_pairs = [tmp]
-        else:
-            all_num = ALL_NUM - 1
-            length = len(all_pairs)
-            start = device * length // ALL_NUM
-            if device == all_num - 1:
-                end = length
-            else:
-                end = (device + 1) * length // all_num
-            all_pairs = all_pairs[start: end]
+def main(all_pairs):
+    model = DenseEncoder()
 
     for (task_type, task_name) in all_pairs:
         task_def: str = get_task_def_by_task_name_and_type(task_name=task_name, task_type=task_type)
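For reference, the deleted block sharded all_pairs across ALL_NUM worker processes by integer slicing; its arithmetic, reduced to a helper:

def shard(pairs, device, all_num):
    # shard i gets the half-open slice [i*L//n : (i+1)*L//n); the last shard
    # takes the remainder, so the slices tile the list exactly
    length = len(pairs)
    start = device * length // all_num
    end = length if device == all_num - 1 else (device + 1) * length // all_num
    return pairs[start:end]

The rewritten main() drops that sharding (and the ArxivClusteringP2P special case) and processes every (task_type, task_name) pair in a single process.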
@@ -299,6 +328,10 @@ def main(device, all_pairs):
             except Exception as e:
                 model.batch_size -= 4
                 print(e)
+        # sub_eval.run(
+        #     model,
+        #     output_folder=args.output_dir
+        # )
 
 
 if __name__ == '__main__':
@@ -317,13 +350,5 @@ if __name__ == '__main__':
             if v in args.task_types:
                 all_pairs.append((task_type, v))
     all_pairs = list(set(all_pairs))
-    random.shuffle(all_pairs)
-
-    for i in range(ALL_NUM):
-        # i = 7
-        process = multiprocessing.Process(target=main, args=(i,all_pairs,))
-        processes.append(process)
-        process.start()
-
-    for process in processes:
-        process.join()
+
+    main(all_pairs)
FlagEmbedding/evaluation/mteb/utils/arguments.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs
+
+@dataclass
+class MTEBEvalArgs(AbsEvalArgs):
+    task_types: List[str] = field(
+        default=None, metadata={"help": "The tasks to evaluate. Default: None"}
+    )
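MTEBEvalArgs only adds a task_types list on top of AbsEvalArgs; with HfArgumentParser, each dataclass field becomes a command-line option. A minimal sketch, assuming the remaining AbsEvalArgs fields all have defaults:

from transformers import HfArgumentParser

parser = HfArgumentParser([MTEBEvalArgs])
# List[str] fields map to nargs="+" options, e.g.
#   python eval_mteb.py --task_types Retrieval STS
eval_args, = parser.parse_args_into_dataclasses(args=["--task_types", "Retrieval", "STS"])
eval_args.task_types  # ['Retrieval', 'STS']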
@@ -145,24 +145,28 @@ class BaseLLMReranker(AbsReranker):
         trust_remote_code: bool = False,
         devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
         prompt: str = None,
+        batch_size: int = 128,
+        query_max_length: int = None,
+        max_length: int = 512,
         normalize: bool = False,
         **kwargs: Any,
     ) -> None:
 
         super().__init__(
-            model_name_or_path,
-            use_fp16,
-            query_instruction_for_rerank,
-            query_instruction_format,
-            passage_instruction_for_rerank,
-            passage_instruction_format,
-            devices,
+            model_name_or_path=model_name_or_path,
+            use_fp16=use_fp16,
+            query_instruction_for_rerank=query_instruction_for_rerank,
+            query_instruction_format=query_instruction_format,
+            passage_instruction_for_rerank=passage_instruction_for_rerank,
+            passage_instruction_format=passage_instruction_format,
+            devices=devices,
+            batch_size=batch_size,
+            query_max_length=query_max_length,
+            max_length=max_length,
+            normalize=normalize,
             **kwargs
         )
 
         self.prompt = prompt
-        self.normalize = normalize
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
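The super().__init__ call switches from positional to keyword arguments so the new settings can be forwarded to AbsReranker without positional drift. The pattern, reduced to a sketch with illustrative class names:

class Base:
    def __init__(self, model_name_or_path, batch_size=128, max_length=512, **kwargs):
        self.batch_size = batch_size
        self.max_length = max_length

class Child(Base):
    def __init__(self, model_name_or_path, batch_size=128, max_length=512, **kwargs):
        # keyword forwarding: a parameter later inserted into Base's signature
        # cannot silently shift these values into the wrong slot
        super().__init__(
            model_name_or_path=model_name_or_path,
            batch_size=batch_size,
            max_length=max_length,
            **kwargs,
        )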
@@ -178,9 +182,6 @@ class BaseLLMReranker(AbsReranker):
         if peft_path:
             self.model = PeftModel.from_pretrained(self.model,peft_path)
             self.model = self.model.merge_and_unload()
-        self.model_name_or_path = model_name_or_path
-        self.cache_dir = cache_dir
-        self.kwargs = kwargs
 
         self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
 
@@ -198,6 +199,13 @@ class BaseLLMReranker(AbsReranker):
         **kwargs: Any
     ) -> List[float]:
         if prompt is None: prompt = self.prompt
+        if batch_size is None: batch_size = self.batch_size
+        if max_length is None: max_length = self.max_length
+        if query_max_length is None:
+            if self.query_max_length is not None:
+                query_max_length = self.query_max_length
+            else:
+                query_max_length = max_length * 3 // 4
+        if normalize is None: normalize = self.normalize
 
         if device is None:
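When no query budget is supplied, the query side gets three quarters of the total budget by integer division, leaving the remainder for the passage:

max_length = 512
query_max_length = max_length * 3 // 4  # 384 tokens; the passage keeps at least 128

This stored fallback replaces the hard-coded max_length * 3 // 4 at the tokenizer call below.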
@@ -206,9 +214,6 @@ class BaseLLMReranker(AbsReranker):
         if device == "cpu": self.use_fp16 = False
         if self.use_fp16: self.model.half()
 
-        if device == "cpu": self.use_fp16 = False
-        if self.use_fp16: self.model.half()
-
         self.model.to(device)
         self.model.eval()
 
@@ -227,7 +232,7 @@ class BaseLLMReranker(AbsReranker):
                 queries,
                 return_tensors=None,
                 add_special_tokens=False,
-                max_length=max_length * 3 // 4,
+                max_length=query_max_length,
                 truncation=True,
                 **kwargs
             )
@@ -42,24 +42,27 @@ class LayerWiseLLMReranker(AbsReranker):
         devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
         cutoff_layers: List[int] = None,
         prompt: str = None,
+        batch_size: int = 128,
+        query_max_length: int = None,
+        max_length: int = 512,
         normalize: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(
-            model_name_or_path,
-            use_fp16,
-            query_instruction_for_rerank,
-            query_instruction_format,
-            passage_instruction_for_rerank,
-            passage_instruction_format,
-            devices,
+            model_name_or_path=model_name_or_path,
+            use_fp16=use_fp16,
+            query_instruction_for_rerank=query_instruction_for_rerank,
+            query_instruction_format=query_instruction_format,
+            passage_instruction_for_rerank=passage_instruction_for_rerank,
+            passage_instruction_format=passage_instruction_format,
+            devices=devices,
+            batch_size=batch_size,
+            query_max_length=query_max_length,
+            max_length=max_length,
+            normalize=normalize,
            **kwargs
        )
 
        self.cutoff_layers = cutoff_layers
        self.prompt = prompt
-       self.normalize = normalize
 
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
@@ -87,8 +90,6 @@ class LayerWiseLLMReranker(AbsReranker):
         if peft_path:
             self.model = PeftModel.from_pretrained(self.model,peft_path)
             self.model = self.model.merge_and_unload()
-        self.model_name_or_path = model_name_or_path
-        self.cache_dir = cache_dir
 
     @torch.no_grad()
     def compute_score_single_gpu(
@@ -106,6 +107,13 @@ class LayerWiseLLMReranker(AbsReranker):
     ) -> List[float]:
         if cutoff_layers is None: cutoff_layers = self.cutoff_layers
         if prompt is None: prompt = self.prompt
+        if batch_size is None: batch_size = self.batch_size
+        if max_length is None: max_length = self.max_length
+        if query_max_length is None:
+            if self.query_max_length is not None:
+                query_max_length = self.query_max_length
+            else:
+                query_max_length = max_length * 3 // 4
+        if normalize is None: normalize = self.normalize
 
         if device is None:
@@ -114,9 +122,6 @@ class LayerWiseLLMReranker(AbsReranker):
         if device == "cpu": self.use_fp16 = False
         if self.use_fp16: self.model.half()
 
-        if device == "cpu": self.use_fp16 = False
-        if self.use_fp16: self.model.half()
-
         self.model.to(device)
         self.model.eval()
 
@@ -135,7 +140,7 @@ class LayerWiseLLMReranker(AbsReranker):
                 queries,
                 return_tensors=None,
                 add_special_tokens=False,
-                max_length=max_length * 3 // 4,
+                max_length=query_max_length,
                 truncation=True,
                 **kwargs
             )
@@ -91,27 +91,28 @@ class LightweightLLMReranker(AbsReranker):
         compress_layers: List[int] = [8],
         compress_ratio: int = 1,
         prompt: str = None,
+        batch_size: int = 128,
+        query_max_length: int = None,
+        max_length: int = 512,
         normalize: bool = False,
         **kwargs: Any,
     ) -> None:
 
         super().__init__(
-            model_name_or_path,
-            use_fp16,
-            query_instruction_for_rerank,
-            query_instruction_format,
-            passage_instruction_for_rerank,
-            passage_instruction_format,
-            devices,
+            model_name_or_path=model_name_or_path,
+            use_fp16=use_fp16,
+            query_instruction_for_rerank=query_instruction_for_rerank,
+            query_instruction_format=query_instruction_format,
+            passage_instruction_for_rerank=passage_instruction_for_rerank,
+            passage_instruction_format=passage_instruction_format,
+            devices=devices,
+            batch_size=batch_size,
+            query_max_length=query_max_length,
+            max_length=max_length,
+            normalize=normalize,
             **kwargs
         )
 
         self.cutoff_layers = cutoff_layers
         self.compress_layers = compress_layers
         self.compress_ratio = compress_ratio
         self.prompt = prompt
-        self.normalize = normalize
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -140,10 +141,6 @@ class LightweightLLMReranker(AbsReranker):
         if peft_path:
             self.model = PeftModel.from_pretrained(self.model,peft_path)
             self.model = self.model.merge_and_unload()
-        self.model_name_or_path = model_name_or_path
-        self.cache_dir = cache_dir
 
         self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
 
     @torch.no_grad()
     def compute_score_single_gpu(
@@ -164,6 +161,13 @@ class LightweightLLMReranker(AbsReranker):
         if compress_layers is None: compress_layers = self.compress_layers
         if compress_ratio is None: compress_ratio = self.compress_ratio
         if prompt is None: prompt = self.prompt
+        if batch_size is None: batch_size = self.batch_size
+        if max_length is None: max_length = self.max_length
+        if query_max_length is None:
+            if self.query_max_length is not None:
+                query_max_length = self.query_max_length
+            else:
+                query_max_length = max_length * 3 // 4
+        if normalize is None: normalize = self.normalize
 
         if device is None:
@@ -172,9 +176,6 @@ class LightweightLLMReranker(AbsReranker):
         if device == "cpu": self.use_fp16 = False
         if self.use_fp16: self.model.half()
 
-        if device == "cpu": self.use_fp16 = False
-        if self.use_fp16: self.model.half()
-
         self.model.to(device)
         self.model.eval()
 
@@ -193,7 +194,7 @@ class LightweightLLMReranker(AbsReranker):
                 queries,
                 return_tensors=None,
                 add_special_tokens=False,
-                max_length=max_length * 3 // 4,
+                max_length=query_max_length,
                 truncation=True,
                 **kwargs
             )
@@ -23,36 +23,54 @@ class BaseReranker(AbsReranker):
         trust_remote_code: bool = False,
         cache_dir: str = None,
         devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
+        batch_size: int = 128,
+        query_max_length: int = None,
+        max_length: int = 512,
+        normalize: bool = False,
         **kwargs: Any,
     ):
         super().__init__(
-            model_name_or_path,
-            use_fp16,
-            query_instruction_for_rerank,
-            query_instruction_format,
-            passage_instruction_for_rerank,
-            passage_instruction_format,
-            devices,
+            model_name_or_path=model_name_or_path,
+            use_fp16=use_fp16,
+            query_instruction_for_rerank=query_instruction_for_rerank,
+            query_instruction_format=query_instruction_format,
+            passage_instruction_for_rerank=passage_instruction_for_rerank,
+            passage_instruction_format=passage_instruction_format,
+            devices=devices,
+            batch_size=batch_size,
+            query_max_length=query_max_length,
+            max_length=max_length,
+            normalize=normalize,
             **kwargs
         )
-        self.normalize = normalize
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
-        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, cache_dir=cache_dir)
-
-        self.kwargs = kwargs
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name_or_path,
+            cache_dir=cache_dir
+        )
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            cache_dir=cache_dir
+        )
 
     @torch.no_grad()
     def compute_score_single_gpu(
         self,
         sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
-        batch_size: int = 256,
-        max_length: int = 512,
+        batch_size: int = None,
+        query_max_length: int = None,
+        max_length: int = None,
+        normalize: bool = None,
         device: str = None,
         **kwargs: Any
     ) -> List[float]:
+        if batch_size is None: batch_size = self.batch_size
+        if max_length is None: max_length = self.max_length
+        if query_max_length is None:
+            if self.query_max_length is not None:
+                query_max_length = self.query_max_length
+            else:
+                query_max_length = max_length * 3 // 4
+        if normalize is None: normalize = self.normalize
 
         if device is None:
@@ -61,31 +79,44 @@ class BaseReranker(AbsReranker):
         if device == "cpu": self.use_fp16 = False
         if self.use_fp16: self.model.half()
 
-        if device == "cpu": self.use_fp16 = False
-        if self.use_fp16: self.model.half()
-
         self.model.to(device)
         self.model.eval()
 
         assert isinstance(sentence_pairs, list)
         if isinstance(sentence_pairs[0], str):
             sentence_pairs = [sentence_pairs]
 
         # tokenize without padding to get the correct length
         all_inputs = []
         for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize"):
             sentences_batch = sentence_pairs[start_index:start_index + batch_size]
-            inputs_batch = self.tokenizer(
-                sentences_batch,
+            queries = [s[0] for s in sentences_batch]
+            passages = [s[1] for s in sentences_batch]
+            queries_inputs_batch = self.tokenizer(
+                queries,
                 return_tensors=None,
+                add_special_tokens=False,
+                max_length=query_max_length,
                 truncation=True,
-                max_length=max_length,
                 **kwargs
-            )
-            inputs_batch = [{
-                k: inputs_batch[k][i] for k in inputs_batch.keys()
-            } for i in range(len(sentences_batch))]
-            all_inputs.extend(inputs_batch)
+            )['input_ids']
+            passages_inputs_batch = self.tokenizer(
+                passages,
+                return_tensors=None,
+                add_special_tokens=False,
+                max_length=max_length,
+                truncation=True,
+                **kwargs
+            )['input_ids']
+            for q_inp, d_inp in zip(queries_inputs_batch, passages_inputs_batch):
+                item = self.tokenizer.prepare_for_model(
+                    q_inp,
+                    d_inp,
+                    truncation='only_second',
+                    max_length=max_length,
+                    padding=False,
+                )
+                all_inputs.append(item)
         # sort by length for less padding
         length_sorted_idx = np.argsort([-len(x['input_ids']) for x in all_inputs])
         all_inputs_sorted = [all_inputs[i] for i in length_sorted_idx]
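Pre-tokenization now encodes queries and passages separately (queries capped at query_max_length), then joins each pair with tokenizer.prepare_for_model, which adds the model's special tokens and truncates only the passage side when the joint sequence exceeds max_length. A sketch of that call, with an illustrative model name:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
q_inp = tok('what is panda?', add_special_tokens=False)['input_ids']
d_inp = tok('The giant panda is a bear native to China.', add_special_tokens=False)['input_ids']

# truncation='only_second' drops passage tokens first, preserving the
# already-capped query; padding is deferred to batch collation
item = tok.prepare_for_model(q_inp, d_inp, truncation='only_second', max_length=512, padding=False)
len(item['input_ids'])  # <= 512, special tokens included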
@@ -6,6 +6,9 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-large',
         use_fp16=True,
+        batch_size=128,
+        query_max_length=256,
+        max_length=512,
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
@@ -27,4 +30,4 @@ if __name__ == '__main__':
 
     print("--------------------------------")
     print("Expected Output:")
-    print("[ 7.9765625 -6.859375 -7.1484375 5.44921875]")
+    print("[ 7.97265625 -6.8515625 -7.15625 5.45703125]")
@@ -6,6 +6,9 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-large',
         use_fp16=True,
+        batch_size=128,
+        query_max_length=256,
+        max_length=512,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
@@ -27,4 +30,4 @@ if __name__ == '__main__':
 
     print("--------------------------------")
     print("Expected Output:")
-    print("[7.9765625, -6.859375, -7.15625, 5.44921875]")
+    print("[7.9765625, -6.84375, -7.15625, 5.453125]")
@@ -6,6 +6,9 @@ def test_base_multi_devices():
     model = FlagReranker(
         'BAAI/bge-reranker-large',
         use_fp16=True,
+        batch_size=128,
+        query_max_length=256,
+        max_length=512,
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
@@ -27,4 +30,4 @@ if __name__ == '__main__':
 
     print("--------------------------------")
    print("Expected Output:")
-    print("[ 7.9765625 -6.859375 -7.1484375 5.44921875]")
+    print("[ 7.97265625 -6.8515625 -7.15625 5.45703125]")
@@ -6,6 +6,9 @@ def test_base_multi_devices():
     model = FlagReranker(
         'BAAI/bge-reranker-large',
         use_fp16=True,
+        batch_size=128,
+        query_max_length=256,
+        max_length=512,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
@@ -27,4 +30,4 @@ if __name__ == '__main__':
 
     print("--------------------------------")
     print("Expected Output:")
-    print("[7.9765625, -6.859375, -7.15625, 5.44921875]")
+    print("[7.9765625, -6.84375, -7.15625, 5.453125]")