From f653727e83e51862063812c690982c1c43faef8f Mon Sep 17 00:00:00 2001
From: cfli <545999961@qq.com>
Date: Fri, 25 Oct 2024 14:51:47 +0800
Subject: [PATCH] reranker init

---
 FlagEmbedding/abc/inference/AbsReranker.py    |  18 ++-
 FlagEmbedding/evaluation/mteb/__main__.py     |  42 ++++-
 FlagEmbedding/evaluation/mteb/evaluate.py     | 149 ++++++++++--------
 .../evaluation/mteb/utils/arguments.py        |  10 ++
 .../inference/reranker/decoder_only/base.py   |  39 +++--
 .../reranker/decoder_only/layerwise.py        |  39 +++--
 .../reranker/decoder_only/lightweight.py      |  43 ++---
 .../inference/reranker/encoder_only/base.py   |  87 ++++++----
 .../encoder_only/auto_base_multi_devices.py   |   5 +-
 .../encoder_only/auto_base_single_device.py   |   5 +-
 .../encoder_only/base_multi_devices.py        |   5 +-
 .../encoder_only/base_single_device.py        |   5 +-
 12 files changed, 292 insertions(+), 155 deletions(-)
 create mode 100644 FlagEmbedding/evaluation/mteb/utils/arguments.py

diff --git a/FlagEmbedding/abc/inference/AbsReranker.py b/FlagEmbedding/abc/inference/AbsReranker.py
index 580c849..071b1b0 100644
--- a/FlagEmbedding/abc/inference/AbsReranker.py
+++ b/FlagEmbedding/abc/inference/AbsReranker.py
@@ -29,6 +29,10 @@ class AbsReranker(ABC):
         passage_instruction_for_rerank: str = None,
         passage_instruction_format: str = "{}{}", # specify the format of passage_instruction_for_rerank
         devices: Union[str, int, List[str], List[int]] = None,
+        batch_size: int = 128,
+        query_max_length: int = None,
+        max_length: int = 512,
+        normalize: bool = False,
         **kwargs: Any,
     ):
         self.model_name_or_path = model_name_or_path
@@ -38,6 +42,13 @@ class AbsReranker(ABC):
         self.passage_instruction_for_rerank = passage_instruction_for_rerank
         self.passage_instruction_format = passage_instruction_format
         self.target_devices = self.get_target_devices(devices)
+        self.batch_size = batch_size
+        self.query_max_length = query_max_length
+        self.max_length = max_length
+        self.normalize = normalize
+
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
 
         self.kwargs = kwargs
 
@@ -73,8 +84,8 @@ class AbsReranker(ABC):
         if isinstance(sentence_pairs, str):
             sentence_pairs = [sentence_pairs]
 
-        if self.query_instruction_format is not None:
-            if self.passage_instruction_format is None:
+        if self.query_instruction_for_rerank is not None:
+            if self.passage_instruction_for_rerank is None:
                 return [
                     [
                         self.get_detailed_instruct(self.query_instruction_format, self.query_instruction_for_rerank, sentence_pair[0]),
@@ -89,7 +100,7 @@ class AbsReranker(ABC):
                 ] for sentence_pair in sentence_pairs
             ]
         else:
-            if self.passage_instruction_format is None:
+            if self.passage_instruction_for_rerank is None:
                 return [
                     [
                         sentence_pair[0],
@@ -130,6 +141,7 @@ class AbsReranker(ABC):
         self,
         sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
         batch_size: int = 256,
+        query_max_length: int = None,
         max_length: int = 512,
         normalize: bool = False,
         device: str = None,
diff --git a/FlagEmbedding/evaluation/mteb/__main__.py b/FlagEmbedding/evaluation/mteb/__main__.py
index 56aec42..5c3491a 100644
--- a/FlagEmbedding/evaluation/mteb/__main__.py
+++ b/FlagEmbedding/evaluation/mteb/__main__.py
@@ -1,11 +1,14 @@
+import mteb
+
+
 from transformers import HfArgumentParser
 
 from FlagEmbedding import FlagAutoModel, FlagAutoReranker
 from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
 
-from utils.arguments import MSMARCOEvalArgs
-from utils.data_loader import MSMARCODataLoader
+from utils.arguments import MTEBEvalArgs
+from utils.prompts import get_task_def_by_task_name_and_type, tasks_desc
 
 
 def get_models(model_args: AbsModelArgs):
@@ -41,4 +44,37 @@ def get_models(model_args: AbsModelArgs):
         compress_layers=model_args.compress_layers,
         compress_ratio=model_args.compress_ratio,
     )
-    return retriever, reranker
\ No newline at end of file
+    return retriever, reranker
+
+def main():
+    parser = HfArgumentParser([AbsModelArgs, MTEBEvalArgs])
+    model_args, eval_args = parser.parse_args_into_dataclasses()
+    model_args: AbsModelArgs
+    eval_args: MTEBEvalArgs
+
+    retriever, reranker = get_models(model_args)
+
+    task_types = eval_args.task_types
+    tasks = eval_args.tasks
+    languages = eval_args.languages
+    tasks = mteb.get_tasks(
+        languages=languages,
+        tasks=tasks,
+        task_types=task_types
+    )
+    evaluation = mteb.MTEB(tasks=tasks)
+    results = evaluation.run(retriever, output_folder=f"results/{str(retriever)}")
+
+    # all_pairs = []
+    # for task_type in eval_args.task_types:
+    #     if task_type in tasks_desc.keys():
+    #         for task_name in tasks_desc[task_type]:
+    #             all_pairs.append((task_type, task_name))
+    # for task_type in tasks_desc.keys():
+    #     for v in tasks_desc[task_type]:
+    #         if v in eval_args.task_types:
+    #             all_pairs.append((task_type, v))
+    # all_pairs = list(set(all_pairs))
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/FlagEmbedding/evaluation/mteb/evaluate.py b/FlagEmbedding/evaluation/mteb/evaluate.py
index 3da10ba..9e3cca2 100644
--- a/FlagEmbedding/evaluation/mteb/evaluate.py
+++ b/FlagEmbedding/evaluation/mteb/evaluate.py
@@ -38,22 +38,40 @@ parser.add_argument('--batch_size', default=32, type=int)
 parser.add_argument('--examples-dir', default='/share/chaofan/code/embedder/evaluate_for_icl/examples', type=str)
 parser.add_argument('--eight-special-token', default=False, type=bool)
 parser.add_argument('--passage-prompt', default=False, type=bool)
+parser.add_argument('--pool-type', default='last', type=str)
 
 args = parser.parse_args()
 base_name: str = args.model_name_or_path.split('/')[-1]
 
-if args.eight_special_token is True:
-    args.pool_type = 'last_eight'
-else:
-    args.pool_type = 'last'
+# if args.eight_special_token is True:
+#     args.pool_type = 'last_eight'
+# else:
+#     args.pool_type = 'last'
 
 logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
-assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last'
+# assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last'
 os.makedirs(args.output_dir, exist_ok=True)
 
 ALL_NUM = args.all_num
 
+def find_last_index(lst, element):
+    if not isinstance(element, list):
+        try:
+            reversed_list = lst[::-1]
+            index_in_reversed = reversed_list.index(element)
+            last_index = len(lst) - 1 - index_in_reversed
+            return last_index
+        except ValueError:
+            return -1
+    else:
+        last_index = -1
+        for i in range(len(lst) - len(element), -1, -1):
+            if lst[i:i + len(element)] == element:
+                last_index = i
+                break
+        return last_index
+
 print(args.special_token)
 
 class DenseEncoder(torch.nn.Module):
-    def __init__(self, device, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__()
         self.encoder = AutoModel.from_pretrained(args.model_name_or_path, use_cache=False)
         if args.peft_name_or_path is not None:
@@ -74,13 +92,15 @@ class DenseEncoder(torch.nn.Module):
             self.prompt = None
             self.prefix = ''
             self.suffix = ''
-        self.gpu_count = torch.cuda.device_count()
         self.encoder.half()
+        self.gpu_count = torch.cuda.device_count()
+
         self.encoder.eval()
-        # self.encoder.cuda()
-        self.device = device
-        self.encoder = self.encoder.to(device)
+        self.encoder.cuda()
+
+        if self.gpu_count > 1:
+            self.encoder = torch.nn.DataParallel(self.encoder)
 
         self.eight_special_token = args.eight_special_token
         if args.eight_special_token:
@@ -92,9 +112,26 @@ class DenseEncoder(torch.nn.Module):
 
         self.batch_size = args.batch_size
 
+        if args.special_token:
+            self.index_type = 0
+            self.index_start_locs = self.tokenizer('', add_special_tokens=False)['input_ids'][0]
+            self.index_end_locs = self.tokenizer('', add_special_tokens=False)['input_ids'][0]
+        else:
+            self.index_type = 1
+            self.index_start_locs = self.tokenizer('\nQuery:', add_special_tokens=False)['input_ids'][1:]
+            self.index_end_locs = self.tokenizer('\nResponse:', add_special_tokens=False)['input_ids'][1:]
+
         # if self.gpu_count > 1:
         #     self.encoder = torch.nn.DataParallel(self.encoder)
 
+    def get_loc(self, sentence):
+        sentence = list(sentence)
+        if isinstance(self.index_start_locs, int):
+            return find_last_index(sentence, self.index_start_locs) + 1, find_last_index(sentence, self.index_end_locs)
+        else:
+            return find_last_index(sentence, self.index_start_locs) + len(self.index_start_locs), find_last_index(
+                sentence, self.index_end_locs)
+
     @torch.no_grad()
     def encode(self, sentences, **kwargs) -> np.ndarray:
         """ Returns a list of embeddings for the given sentences.
@@ -115,13 +152,22 @@ class DenseEncoder(torch.nn.Module):
             batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts,
                                                  special_tokens=self.special_tokens)
             # if self.device == 0:
-            #     print(self.tokenizer.decode(batch_dict['input_ids'][0]))
-            batch_dict = batch_dict.to(self.device)
-            # batch_dict = move_to_cuda(batch_dict)
+            #     print(self.tokenizer.decode(batch_dict['input_ids'][0]))
+            # batch_dict = batch_dict.to(self.device)
+            batch_dict = move_to_cuda(batch_dict)
 
             with torch.cuda.amp.autocast():
                 outputs: BaseModelOutput = self.encoder(**batch_dict)
-                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if args.pool_type == 'mean':
+                    seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
+                    embeds = torch.stack(
+                        [
+                            outputs.last_hidden_state[i, start: end, :].mean(dim=0)
+                            for i, (start, end) in enumerate(seq_locs)
+                        ]
+                    )
+                else:
+                    embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                 if self.l2_normalize:
                     embeds = F.normalize(embeds, p=2, dim=-1)
                 encoded_embeds.append(embeds.cpu().numpy())
@@ -151,7 +197,16 @@ class DenseEncoder(torch.nn.Module):
 
             with torch.cuda.amp.autocast():
                 outputs: BaseModelOutput = self.encoder(**batch_dict)
-                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if args.pool_type == 'mean':
+                    seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
+                    embeds = torch.stack(
+                        [
+                            outputs.last_hidden_state[i, start: end, :].mean(dim=0)
+                            for i, (start, end) in enumerate(seq_locs)
+                        ]
+                    )
+                else:
+                    embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                 if self.l2_normalize:
                     embeds = F.normalize(embeds, p=2, dim=-1)
                 encoded_embeds.append(embeds.cpu().numpy())
@@ -184,7 +239,16 @@ class DenseEncoder(torch.nn.Module):
 
             with torch.cuda.amp.autocast():
                 outputs: BaseModelOutput = self.encoder(**batch_dict)
-                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if args.pool_type == 'mean':
+                    seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
+                    embeds = torch.stack(
+                        [
+                            outputs.last_hidden_state[i, start: end, :].mean(dim=0)
+                            for i, (start, end) in enumerate(seq_locs)
+                        ]
+                    )
+                else:
+                    embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                 if self.l2_normalize:
                     embeds = F.normalize(embeds, p=2, dim=-1)
                 encoded_embeds.append(embeds.cpu().numpy())
@@ -201,44 +265,9 @@ class DenseEncoder(torch.nn.Module):
         self.suffix = suffix
 
 
-def main(device, all_pairs):
-    torch.cuda.set_device(device)
+def main(all_pairs):
 
-    model = DenseEncoder(device)
-
-    os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
-
-    ArxivClusteringP2P_FLAG = False
-    tmp = None
-    for i, p in enumerate(all_pairs):
-        if p[1] == 'ArxivClusteringP2P':
-            ArxivClusteringP2P_FLAG = True
-            tmp = p
-            break
-    if ArxivClusteringP2P_FLAG:
-        all_pairs.remove(tmp)
-
-    os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
-    if ArxivClusteringP2P_FLAG is False:
-        length = len(all_pairs)
-        start = device * length // ALL_NUM
-        if device == ALL_NUM - 1:
-            end = length
-        else:
-            end = (device + 1) * length // ALL_NUM
-        all_pairs = all_pairs[start: end]
-    else:
-        if device == ALL_NUM - 1:
-            all_pairs = [tmp]
-        else:
-            all_num = ALL_NUM - 1
-            length = len(all_pairs)
-            start = device * length // ALL_NUM
-            if device == all_num - 1:
-                end = length
-            else:
-                end = (device + 1) * length // all_num
-            all_pairs = all_pairs[start: end]
+    model = DenseEncoder()
 
     for (task_type, task_name) in all_pairs:
         task_def: str = get_task_def_by_task_name_and_type(task_name=task_name, task_type=task_type)
@@ -299,6 +328,10 @@ def main(device, all_pairs):
             except Exception as e:
                 model.batch_size -= 4
                 print(e)
+        # sub_eval.run(
+        #     model,
+        #     output_folder=args.output_dir
+        # )
 
 
 if __name__ == '__main__':
@@ -317,13 +350,5 @@ if __name__ == '__main__':
             if v in args.task_types:
                 all_pairs.append((task_type, v))
     all_pairs = list(set(all_pairs))
-    random.shuffle(all_pairs)
-    for i in range(ALL_NUM):
-        # i = 7
-        process = multiprocessing.Process(target=main, args=(i,all_pairs,))
-        processes.append(process)
-        process.start()
-
-    for process in processes:
-        process.join()
\ No newline at end of file
+    main(all_pairs)
\ No newline at end of file
diff --git a/FlagEmbedding/evaluation/mteb/utils/arguments.py b/FlagEmbedding/evaluation/mteb/utils/arguments.py
new file mode 100644
index 0000000..cdf6bbf
--- /dev/null
+++ b/FlagEmbedding/evaluation/mteb/utils/arguments.py
@@ -0,0 +1,10 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs
+
+@dataclass
+class MTEBEvalArgs(AbsEvalArgs):
+    task_types: List[str] = field(
Default: None"} + ) \ No newline at end of file diff --git a/FlagEmbedding/inference/reranker/decoder_only/base.py b/FlagEmbedding/inference/reranker/decoder_only/base.py index 6c9ed7b..c05dd90 100644 --- a/FlagEmbedding/inference/reranker/decoder_only/base.py +++ b/FlagEmbedding/inference/reranker/decoder_only/base.py @@ -145,24 +145,28 @@ class BaseLLMReranker(AbsReranker): trust_remote_code: bool = False, devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"] prompt: str = None, + batch_size: int = 128, + query_max_length: int = None, + max_length: int = 512, normalize: bool = False, **kwargs: Any, ) -> None: super().__init__( - model_name_or_path, - use_fp16, - query_instruction_for_rerank, - query_instruction_format, - passage_instruction_for_rerank, - passage_instruction_format, - devices, + model_name_or_path=model_name_or_path, + use_fp16=use_fp16, + query_instruction_for_rerank=query_instruction_for_rerank, + query_instruction_format=query_instruction_format, + passage_instruction_for_rerank=passage_instruction_for_rerank, + passage_instruction_format=passage_instruction_format, + devices=devices, + batch_size=batch_size, + query_max_length=query_max_length, + max_length=max_length, + normalize=normalize, **kwargs ) - self.prompt = prompt - self.normalize = normalize - self.tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, cache_dir=cache_dir, @@ -178,9 +182,6 @@ class BaseLLMReranker(AbsReranker): if peft_path: self.model = PeftModel.from_pretrained(self.model,peft_path) self.model = self.model.merge_and_unload() - self.model_name_or_path = model_name_or_path - self.cache_dir = cache_dir - self.kwargs = kwargs self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0] @@ -198,6 +199,13 @@ class BaseLLMReranker(AbsReranker): **kwargs: Any ) -> List[float]: if prompt is None: prompt = self.prompt + if batch_size is None: batch_size = self.batch_size + if max_length is None: max_length = self.max_length + if query_max_length is None: + if self.query_max_length is not None: + query_max_length = self.query_max_length + else: + query_max_length = max_length * 3 // 4 if normalize is None: normalize = self.normalize if device is None: @@ -206,9 +214,6 @@ class BaseLLMReranker(AbsReranker): if device == "cpu": self.use_fp16 = False if self.use_fp16: self.model.half() - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() - self.model.to(device) self.model.eval() @@ -227,7 +232,7 @@ class BaseLLMReranker(AbsReranker): queries, return_tensors=None, add_special_tokens=False, - max_length=max_length * 3 // 4, + max_length=query_max_length, truncation=True, **kwargs ) diff --git a/FlagEmbedding/inference/reranker/decoder_only/layerwise.py b/FlagEmbedding/inference/reranker/decoder_only/layerwise.py index 303021e..bd07128 100644 --- a/FlagEmbedding/inference/reranker/decoder_only/layerwise.py +++ b/FlagEmbedding/inference/reranker/decoder_only/layerwise.py @@ -42,24 +42,27 @@ class LayerWiseLLMReranker(AbsReranker): devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"] cutoff_layers: List[int] = None, prompt: str = None, + batch_size: int = 128, + query_max_length: int = None, + max_length: int = 512, normalize: bool = False, **kwargs: Any, ) -> None: super().__init__( - model_name_or_path, - use_fp16, - query_instruction_for_rerank, - query_instruction_format, - passage_instruction_for_rerank, - passage_instruction_format, - devices, + 
+            model_name_or_path=model_name_or_path,
+            use_fp16=use_fp16,
+            query_instruction_for_rerank=query_instruction_for_rerank,
+            query_instruction_format=query_instruction_format,
+            passage_instruction_for_rerank=passage_instruction_for_rerank,
+            passage_instruction_format=passage_instruction_format,
+            devices=devices,
+            batch_size=batch_size,
+            query_max_length=query_max_length,
+            max_length=max_length,
+            normalize=normalize,
             **kwargs
         )
-        self.cutoff_layers = cutoff_layers
-        self.prompt = prompt
-        self.normalize = normalize
-
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -87,8 +90,6 @@ class LayerWiseLLMReranker(AbsReranker):
         if peft_path:
             self.model = PeftModel.from_pretrained(self.model,peft_path)
             self.model = self.model.merge_and_unload()
-        self.model_name_or_path = model_name_or_path
-        self.cache_dir = cache_dir
 
     @torch.no_grad()
     def compute_score_single_gpu(
@@ -106,6 +107,13 @@ class LayerWiseLLMReranker(AbsReranker):
     ) -> List[float]:
         if cutoff_layers is None: cutoff_layers = self.cutoff_layers
         if prompt is None: prompt = self.prompt
+        if batch_size is None: batch_size = self.batch_size
+        if max_length is None: max_length = self.max_length
+        if query_max_length is None:
+            if self.query_max_length is not None:
+                query_max_length = self.query_max_length
+            else:
+                query_max_length = max_length * 3 // 4
         if normalize is None: normalize = self.normalize
         if device is None:
@@ -114,9 +122,6 @@ class LayerWiseLLMReranker(AbsReranker):
         if device == "cpu": self.use_fp16 = False
         if self.use_fp16: self.model.half()
 
-        if device == "cpu": self.use_fp16 = False
-        if self.use_fp16: self.model.half()
-
         self.model.to(device)
         self.model.eval()
 
@@ -135,7 +140,7 @@ class LayerWiseLLMReranker(AbsReranker):
             queries,
             return_tensors=None,
             add_special_tokens=False,
-            max_length=max_length * 3 // 4,
+            max_length=query_max_length,
             truncation=True,
             **kwargs
         )
diff --git a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
index 5d10e37..bf846d4 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
@@ -91,27 +91,28 @@ class LightweightLLMReranker(AbsReranker):
         compress_layers: List[int] = [8],
         compress_ratio: int = 1,
         prompt: str = None,
+        batch_size: int = 128,
+        query_max_length: int = None,
+        max_length: int = 512,
         normalize: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(
-            model_name_or_path,
-            use_fp16,
-            query_instruction_for_rerank,
-            query_instruction_format,
-            passage_instruction_for_rerank,
-            passage_instruction_format,
-            devices,
+            model_name_or_path=model_name_or_path,
+            use_fp16=use_fp16,
+            query_instruction_for_rerank=query_instruction_for_rerank,
+            query_instruction_format=query_instruction_format,
+            passage_instruction_for_rerank=passage_instruction_for_rerank,
+            passage_instruction_format=passage_instruction_format,
+            devices=devices,
+            batch_size=batch_size,
+            query_max_length=query_max_length,
+            max_length=max_length,
+            normalize=normalize,
             **kwargs
         )
-        self.cutoff_layers = cutoff_layers
-        self.compress_layers = compress_layers
-        self.compress_ratio = compress_ratio
-        self.prompt = prompt
-        self.normalize = normalize
-
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -140,10 +141,6 @@ class LightweightLLMReranker(AbsReranker):
         if peft_path:
             self.model = PeftModel.from_pretrained(self.model,peft_path)
             self.model = self.model.merge_and_unload()
-        self.model_name_or_path = model_name_or_path
-        self.cache_dir = cache_dir
-
-
         self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
 
     @torch.no_grad()
     def compute_score_single_gpu(
@@ -164,6 +161,13 @@ class LightweightLLMReranker(AbsReranker):
         if compress_layers is None: compress_layers = self.compress_layers
         if compress_ratio is None: compress_ratio = self.compress_ratio
         if prompt is None: prompt = self.prompt
+        if batch_size is None: batch_size = self.batch_size
+        if max_length is None: max_length = self.max_length
+        if query_max_length is None:
+            if self.query_max_length is not None:
+                query_max_length = self.query_max_length
+            else:
+                query_max_length = max_length * 3 // 4
         if normalize is None: normalize = self.normalize
         if device is None:
@@ -172,9 +176,6 @@ class LightweightLLMReranker(AbsReranker):
         if device == "cpu": self.use_fp16 = False
         if self.use_fp16: self.model.half()
 
-        if device == "cpu": self.use_fp16 = False
-        if self.use_fp16: self.model.half()
-
         self.model.to(device)
         self.model.eval()
 
@@ -193,7 +194,7 @@ class LightweightLLMReranker(AbsReranker):
             queries,
             return_tensors=None,
             add_special_tokens=False,
-            max_length=max_length * 3 // 4,
+            max_length=query_max_length,
             truncation=True,
             **kwargs
         )
diff --git a/FlagEmbedding/inference/reranker/encoder_only/base.py b/FlagEmbedding/inference/reranker/encoder_only/base.py
index d069eb0..00be721 100644
--- a/FlagEmbedding/inference/reranker/encoder_only/base.py
+++ b/FlagEmbedding/inference/reranker/encoder_only/base.py
@@ -23,36 +23,54 @@ class BaseReranker(AbsReranker):
         trust_remote_code: bool = False,
         cache_dir: str = None,
         devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
+        batch_size: int = 128,
+        query_max_length: int = None,
+        max_length: int = 512,
         normalize: bool = False,
         **kwargs: Any,
     ):
         super().__init__(
-            model_name_or_path,
-            use_fp16,
-            query_instruction_for_rerank,
-            query_instruction_format,
-            passage_instruction_for_rerank,
-            passage_instruction_format,
-            devices,
+            model_name_or_path=model_name_or_path,
+            use_fp16=use_fp16,
+            query_instruction_for_rerank=query_instruction_for_rerank,
+            query_instruction_format=query_instruction_format,
+            passage_instruction_for_rerank=passage_instruction_for_rerank,
+            passage_instruction_format=passage_instruction_format,
+            devices=devices,
+            batch_size=batch_size,
+            query_max_length=query_max_length,
+            max_length=max_length,
+            normalize=normalize,
            **kwargs
         )
-        self.normalize = normalize
-
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
-        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, cache_dir=cache_dir)
-
-        self.kwargs = kwargs
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name_or_path,
+            cache_dir=cache_dir
+        )
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            cache_dir=cache_dir
+        )
 
     @torch.no_grad()
     def compute_score_single_gpu(
         self,
         sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
-        batch_size: int = 256,
-        max_length: int = 512,
+        batch_size: int = None,
+        query_max_length: int = None,
+        max_length: int = None,
         normalize: bool = None,
         device: str = None,
         **kwargs: Any
     ) -> List[float]:
+        if batch_size is None: batch_size = self.batch_size
+        if max_length is None: max_length = self.max_length
+        if query_max_length is None:
+            if self.query_max_length is not None:
+                query_max_length = self.query_max_length
+            else:
+                query_max_length = max_length * 3 // 4
         if normalize is None: normalize = self.normalize
         if device is None:
@@ -61,31 +79,44 @@ class BaseReranker(AbsReranker):
         if device == "cpu": self.use_fp16 = False
         if self.use_fp16: self.model.half()
 
-        if device == "cpu": self.use_fp16 = False
-        if self.use_fp16: self.model.half()
-
         self.model.to(device)
         self.model.eval()
 
         assert isinstance(sentence_pairs, list)
         if isinstance(sentence_pairs[0], str):
             sentence_pairs = [sentence_pairs]
-
+        # tokenize without padding to get the correct length
         all_inputs = []
         for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize"):
             sentences_batch = sentence_pairs[start_index:start_index + batch_size]
-            inputs_batch = self.tokenizer(
-                sentences_batch,
+            queries = [s[0] for s in sentences_batch]
+            passages = [s[1] for s in sentences_batch]
+            queries_inputs_batch = self.tokenizer(
+                queries,
+                return_tensors=None,
+                add_special_tokens=False,
+                max_length=query_max_length,
                 truncation=True,
-                max_length=max_length,
                 **kwargs
-            )
-            inputs_batch = [{
-                k: inputs_batch[k][i] for k in inputs_batch.keys()
-            } for i in range(len(sentences_batch))]
-            all_inputs.extend(inputs_batch)
-
+            )['input_ids']
+            passages_inputs_batch = self.tokenizer(
+                passages,
+                return_tensors=None,
+                add_special_tokens=False,
+                max_length=max_length,
+                truncation=True,
+                **kwargs
+            )['input_ids']
+            for q_inp, d_inp in zip(queries_inputs_batch, passages_inputs_batch):
+                item = self.tokenizer.prepare_for_model(
+                    q_inp,
+                    d_inp,
+                    truncation='only_second',
+                    max_length=max_length,
+                    padding=False,
+                )
+                all_inputs.append(item)
         # sort by length for less padding
         length_sorted_idx = np.argsort([-len(x['input_ids']) for x in all_inputs])
         all_inputs_sorted = [all_inputs[i] for i in length_sorted_idx]
diff --git a/examples/inference/reranker/encoder_only/auto_base_multi_devices.py b/examples/inference/reranker/encoder_only/auto_base_multi_devices.py
index f948e33..299fc3b 100644
--- a/examples/inference/reranker/encoder_only/auto_base_multi_devices.py
+++ b/examples/inference/reranker/encoder_only/auto_base_multi_devices.py
@@ -6,6 +6,9 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-large',
         use_fp16=True,
+        batch_size=128,
+        query_max_length=256,
+        max_length=512,
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
@@ -27,4 +30,4 @@ if __name__ == '__main__':
 
     print("--------------------------------")
     print("Expected Output:")
-    print("[ 7.9765625 -6.859375 -7.1484375 5.44921875]")
+    print("[ 7.97265625 -6.8515625 -7.15625 5.45703125]")
diff --git a/examples/inference/reranker/encoder_only/auto_base_single_device.py b/examples/inference/reranker/encoder_only/auto_base_single_device.py
index a6153cb..8309901 100644
--- a/examples/inference/reranker/encoder_only/auto_base_single_device.py
+++ b/examples/inference/reranker/encoder_only/auto_base_single_device.py
@@ -6,6 +6,9 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-large',
         use_fp16=True,
+        batch_size=128,
+        query_max_length=256,
+        max_length=512,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
@@ -27,4 +30,4 @@ if __name__ == '__main__':
 
     print("--------------------------------")
     print("Expected Output:")
-    print("[7.9765625, -6.859375, -7.15625, 5.44921875]")
+    print("[7.9765625, -6.84375, -7.15625, 5.453125]")
diff --git a/examples/inference/reranker/encoder_only/base_multi_devices.py b/examples/inference/reranker/encoder_only/base_multi_devices.py
index 262d592..36ac1b1 100644
--- a/examples/inference/reranker/encoder_only/base_multi_devices.py
+++ b/examples/inference/reranker/encoder_only/base_multi_devices.py
@@ -6,6 +6,9 @@ def test_base_multi_devices():
     model = FlagReranker(
         'BAAI/bge-reranker-large',
         use_fp16=True,
+        batch_size=128,
+        query_max_length=256,
+        max_length=512,
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
@@ -27,4 +30,4 @@ if __name__ == '__main__':
 
     print("--------------------------------")
     print("Expected Output:")
-    print("[ 7.9765625 -6.859375 -7.1484375 5.44921875]")
+    print("[ 7.97265625 -6.8515625 -7.15625 5.45703125]")
diff --git a/examples/inference/reranker/encoder_only/base_single_device.py b/examples/inference/reranker/encoder_only/base_single_device.py
index ed3bfb3..debfec1 100644
--- a/examples/inference/reranker/encoder_only/base_single_device.py
+++ b/examples/inference/reranker/encoder_only/base_single_device.py
@@ -6,6 +6,9 @@ def test_base_multi_devices():
     model = FlagReranker(
         'BAAI/bge-reranker-large',
         use_fp16=True,
+        batch_size=128,
+        query_max_length=256,
+        max_length=512,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
@@ -27,4 +30,4 @@ if __name__ == '__main__':
 
     print("--------------------------------")
     print("Expected Output:")
-    print("[7.9765625, -6.859375, -7.15625, 5.44921875]")
+    print("[7.9765625, -6.84375, -7.15625, 5.453125]")
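-- 

Editor's note, appended after the patch body: the user-visible change above is that every
reranker now accepts batch_size, query_max_length, max_length, and normalize at
construction time, and compute_score_single_gpu falls back to those instance defaults
whenever a per-call value is omitted, with query_max_length itself defaulting to
max_length * 3 // 4. A minimal usage sketch under those assumptions; compute_score is the
existing public scoring entry point, and the exact scores vary slightly across hardware,
as the updated expected outputs in the examples show:

    from FlagEmbedding import FlagReranker

    # Constructor-level defaults introduced by this patch; compute_score will
    # fall back to these whenever no per-call values are given.
    model = FlagReranker(
        'BAAI/bge-reranker-large',
        use_fp16=True,
        batch_size=128,
        query_max_length=256,  # if omitted, falls back to max_length * 3 // 4
        max_length=512,
        normalize=False,
        devices=["cpu"],       # or e.g. ["cuda:0"], as in the examples above
    )

    # No per-call batch_size/max_length needed any more.
    scores = model.compute_score([
        ['what is panda?', 'The giant panda is a bear species endemic to China.'],
        ['what is panda?', 'Paris is the capital of France.'],
    ])
    print(scores)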