reranker init

cfli 2024-10-25 14:51:47 +08:00
parent 00a42ccd4f
commit f653727e83
12 changed files with 292 additions and 155 deletions

View File

@ -29,6 +29,10 @@ class AbsReranker(ABC):
passage_instruction_for_rerank: str = None,
passage_instruction_format: str = "{}{}", # specify the format of passage_instruction_for_rerank
devices: Union[str, int, List[str], List[int]] = None,
batch_size: int = 128,
query_max_length: int = None,
max_length: int = 512,
normalize: bool = False,
**kwargs: Any,
):
self.model_name_or_path = model_name_or_path
@ -38,6 +42,13 @@ class AbsReranker(ABC):
self.passage_instruction_for_rerank = passage_instruction_for_rerank
self.passage_instruction_format = passage_instruction_format
self.target_devices = self.get_target_devices(devices)
self.batch_size = batch_size
self.query_max_length = query_max_length
self.max_length = max_length
self.normalize = normalize
for k in kwargs:
setattr(self, k, kwargs[k])
self.kwargs = kwargs
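The loop above promotes every extra keyword argument to an instance attribute (and also keeps the raw dict in self.kwargs), so subclasses can forward model-specific options without the base class having to declare them. A self-contained sketch of the pattern, with illustrative kwarg names:

class Demo:
    def __init__(self, **kwargs):
        # Mirror AbsReranker: expose each extra kwarg as an attribute.
        for k in kwargs:
            setattr(self, k, kwargs[k])
        self.kwargs = kwargs

d = Demo(cutoff_layers=[28], trust_remote_code=True)
print(d.cutoff_layers)  # [28]
print(d.kwargs)         # {'cutoff_layers': [28], 'trust_remote_code': True}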
@ -73,8 +84,8 @@ class AbsReranker(ABC):
if isinstance(sentence_pairs, str):
sentence_pairs = [sentence_pairs]
if self.query_instruction_format is not None:
if self.passage_instruction_format is None:
if self.query_instruction_for_rerank is not None:
if self.passage_instruction_for_rerank is None:
return [
[
self.get_detailed_instruct(self.query_instruction_format, self.query_instruction_for_rerank, sentence_pair[0]),
@ -89,7 +100,7 @@ class AbsReranker(ABC):
] for sentence_pair in sentence_pairs
]
else:
if self.passage_instruction_format is None:
if self.passage_instruction_for_rerank is None:
return [
[
sentence_pair[0],
@ -130,6 +141,7 @@ class AbsReranker(ABC):
self,
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
batch_size: int = 256,
query_max_length: int = None,
max_length: int = 512,
normalize: bool = False,
device: str = None,
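With these constructor-level defaults in place, each compute_score_single_gpu implementation below resolves its effective settings the same way: an explicit call argument wins, then the value stored at construction time, and query_max_length finally falls back to three quarters of max_length. A minimal sketch of that resolution order (the helper name is illustrative, not part of the API):

def resolve(call_value, instance_value, fallback=None):
    # Call argument > constructor default > computed fallback.
    if call_value is not None:
        return call_value
    if instance_value is not None:
        return instance_value
    return fallback

max_length = resolve(None, 512)                              # constructor default -> 512
query_max_length = resolve(None, None, max_length * 3 // 4)  # computed fallback  -> 384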

View File

@ -1,11 +1,14 @@
import mteb
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
from utils.arguments import MSMARCOEvalArgs
from utils.data_loader import MSMARCODataLoader
from utils.arguments import MTEBEvalArgs
from utils.prompts import get_task_def_by_task_name_and_type, tasks_desc
def get_models(model_args: AbsModelArgs):
@ -41,4 +44,37 @@ def get_models(model_args: AbsModelArgs):
compress_layers=model_args.compress_layers,
compress_ratio=model_args.compress_ratio,
)
return retriever, reranker
return retriever, reranker
def main():
parser = HfArgumentParser([AbsModelArgs, MTEBEvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: AbsModelArgs
eval_args: MTEBEvalArgs
retriever, reranker = get_models(model_args)
task_types = eval_args.task_types
tasks = eval_args.tasks
languages = eval_args.languages
tasks = mteb.get_tasks(
languages=languages,
tasks=tasks,
task_types=task_types
)
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(retriever, output_folder=f"results/{str(retriever)}")
# all_pairs = []
# for task_type in eval_args.task_types:
# if task_type in tasks_desc.keys():
# for task_name in tasks_desc[task_type]:
# all_pairs.append((task_type, task_name))
# for task_type in tasks_desc.keys():
# for v in tasks_desc[task_type]:
# if v in eval_args.task_types:
# all_pairs.append((task_type, v))
# all_pairs = list(set(all_pairs))
if __name__ == "__main__":
main()

View File

@ -38,22 +38,40 @@ parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--examples-dir', default='/share/chaofan/code/embedder/evaluate_for_icl/examples', type=str)
parser.add_argument('--eight-special-token', default=False, type=bool)
parser.add_argument('--passage-prompt', default=False, type=bool)
parser.add_argument('--pool-type', default='last', type=str)
args = parser.parse_args()
base_name: str = args.model_name_or_path.split('/')[-1]
if args.eight_special_token is True:
args.pool_type = 'last_eight'
else:
args.pool_type = 'last'
# if args.eight_special_token is True:
# args.pool_type = 'last_eight'
# else:
# args.pool_type = 'last'
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last'
# assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last'
os.makedirs(args.output_dir, exist_ok=True)
ALL_NUM = args.all_num
def find_last_index(lst, element):
if not isinstance(element, list):
try:
reversed_list = lst[::-1]
index_in_reversed = reversed_list.index(element)
last_index = len(lst) - 1 - index_in_reversed
return last_index
except ValueError:
return -1
else:
last_index = -1
# Scan from the end (inclusive of the final slot) so the *last* occurrence wins.
for i in range(len(lst) - len(element), -1, -1):
if lst[i:i + len(element)] == element:
last_index = i
break
return last_index
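find_last_index is used below to locate the final '<query>'/'<response>' (or '\nQuery:'/'\nResponse:') marker inside a token id sequence; the element may be a single id or a sub-list of ids. A quick check of both branches on toy input:

ids = [5, 9, 7, 9, 3]
print(find_last_index(ids, 9))       # 3  -> last scalar occurrence
print(find_last_index(ids, [9, 3]))  # 3  -> last sub-list occurrence
print(find_last_index(ids, 4))       # -1 -> not found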
print(args.special_token)
class DenseEncoder(torch.nn.Module):
def __init__(self, device, **kwargs):
def __init__(self, **kwargs):
super().__init__()
self.encoder = AutoModel.from_pretrained(args.model_name_or_path, use_cache=False)
if args.peft_name_or_path is not None:
@ -74,13 +92,15 @@ class DenseEncoder(torch.nn.Module):
self.prompt = None
self.prefix = ''
self.suffix = ''
self.gpu_count = torch.cuda.device_count()
self.encoder.half()
self.gpu_count = torch.cuda.device_count()
self.encoder.eval()
# self.encoder.cuda()
self.device = device
self.encoder = self.encoder.to(device)
self.encoder.cuda()
if self.gpu_count > 1:
self.encoder = torch.nn.DataParallel(self.encoder)
self.eight_special_token = args.eight_special_token
if args.eight_special_token:
@ -92,9 +112,26 @@ class DenseEncoder(torch.nn.Module):
self.batch_size = args.batch_size
if args.special_token:
self.index_type = 0
self.index_start_locs = self.tokenizer('<query>', add_special_tokens=False)['input_ids'][0]
self.index_end_locs = self.tokenizer('<response>', add_special_tokens=False)['input_ids'][0]
else:
self.index_type = 1
self.index_start_locs = self.tokenizer('\nQuery:', add_special_tokens=False)['input_ids'][1:]
self.index_end_locs = self.tokenizer('\nResponse:', add_special_tokens=False)['input_ids'][1:]
# if self.gpu_count > 1:
# self.encoder = torch.nn.DataParallel(self.encoder)
def get_loc(self, sentence):
sentence = list(sentence)
if isinstance(self.index_start_locs, int):
return find_last_index(sentence, self.index_start_locs) + 1, find_last_index(sentence, self.index_end_locs)
else:
return find_last_index(sentence, self.index_start_locs) + len(self.index_start_locs), find_last_index(
sentence, self.index_end_locs)
@torch.no_grad()
def encode(self, sentences, **kwargs) -> np.ndarray:
""" Returns a list of embeddings for the given sentences.
@ -115,13 +152,22 @@ class DenseEncoder(torch.nn.Module):
batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts, special_tokens=self.special_tokens)
# if self.device == 0:
# print(self.tokenizer.decode(batch_dict['input_ids'][0]))
batch_dict = batch_dict.to(self.device)
# batch_dict = move_to_cuda(batch_dict)
# print(self.tokenizer.decode(batch_dict['input_ids'][0]))
# batch_dict = batch_dict.to(self.device)
batch_dict = move_to_cuda(batch_dict)
with torch.cuda.amp.autocast():
outputs: BaseModelOutput = self.encoder(**batch_dict)
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
if args.pool_type == 'mean':
seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
embeds = torch.stack(
[
outputs.last_hidden_state[i, start: end, :].mean(dim=0)
for i, (start, end) in enumerate(seq_locs)
]
)
else:
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
if self.l2_normalize:
embeds = F.normalize(embeds, p=2, dim=-1)
encoded_embeds.append(embeds.cpu().numpy())
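The new 'mean' branch averages hidden states only over the (start, end) span that get_loc recovers, rather than pooling the whole sequence. A self-contained sketch of that span pooling with dummy tensors (shapes are illustrative):

import torch

last_hidden_state = torch.randn(2, 10, 8)  # [batch, seq_len, hidden]
seq_locs = [(3, 7), (2, 9)]                # per-sequence (start, end), as returned by get_loc

embeds = torch.stack([
    last_hidden_state[i, start:end, :].mean(dim=0)
    for i, (start, end) in enumerate(seq_locs)
])
print(embeds.shape)  # torch.Size([2, 8])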
@ -151,7 +197,16 @@ class DenseEncoder(torch.nn.Module):
with torch.cuda.amp.autocast():
outputs: BaseModelOutput = self.encoder(**batch_dict)
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
if args.pool_type == 'mean':
seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
embeds = torch.stack(
[
outputs.last_hidden_state[i, start: end, :].mean(dim=0)
for i, (start, end) in enumerate(seq_locs)
]
)
else:
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
if self.l2_normalize:
embeds = F.normalize(embeds, p=2, dim=-1)
encoded_embeds.append(embeds.cpu().numpy())
@ -184,7 +239,16 @@ class DenseEncoder(torch.nn.Module):
with torch.cuda.amp.autocast():
outputs: BaseModelOutput = self.encoder(**batch_dict)
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
if args.pool_type == 'mean':
seq_locs = [self.get_loc(sentence) for sentence in batch_dict['input_ids']]
embeds = torch.stack(
[
outputs.last_hidden_state[i, start: end, :].mean(dim=0)
for i, (start, end) in enumerate(seq_locs)
]
)
else:
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
if self.l2_normalize:
embeds = F.normalize(embeds, p=2, dim=-1)
encoded_embeds.append(embeds.cpu().numpy())
@ -201,44 +265,9 @@ class DenseEncoder(torch.nn.Module):
self.suffix = suffix
def main(device, all_pairs):
torch.cuda.set_device(device)
def main(all_pairs):
model = DenseEncoder(device)
os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
ArxivClusteringP2P_FLAG = False
tmp = None
for i, p in enumerate(all_pairs):
if p[1] == 'ArxivClusteringP2P':
ArxivClusteringP2P_FLAG = True
tmp = p
break
if ArxivClusteringP2P_FLAG:
all_pairs.remove(tmp)
os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
if ArxivClusteringP2P_FLAG is False:
length = len(all_pairs)
start = device * length // ALL_NUM
if device == ALL_NUM - 1:
end = length
else:
end = (device + 1) * length // ALL_NUM
all_pairs = all_pairs[start: end]
else:
if device == ALL_NUM - 1:
all_pairs = [tmp]
else:
all_num = ALL_NUM - 1
length = len(all_pairs)
start = device * length // ALL_NUM
if device == all_num - 1:
end = length
else:
end = (device + 1) * length // all_num
all_pairs = all_pairs[start: end]
model = DenseEncoder()
for (task_type, task_name) in all_pairs:
task_def: str = get_task_def_by_task_name_and_type(task_name=task_name, task_type=task_type)
@ -299,6 +328,10 @@ def main(device, all_pairs):
except Exception as e:
model.batch_size -= 4
print(e)
# sub_eval.run(
# model,
# output_folder=args.output_dir
# )
if __name__ == '__main__':
@ -317,13 +350,5 @@ if __name__ == '__main__':
if v in args.task_types:
all_pairs.append((task_type, v))
all_pairs = list(set(all_pairs))
random.shuffle(all_pairs)
for i in range(ALL_NUM):
# i = 7
process = multiprocessing.Process(target=main, args=(i,all_pairs,))
processes.append(process)
process.start()
for process in processes:
process.join()
main(all_pairs)

View File

@ -0,0 +1,10 @@
from dataclasses import dataclass, field
from typing import List
from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs
@dataclass
class MTEBEvalArgs(AbsEvalArgs):
task_types: List[str] = field(
default=None, metadata={"help": "The task types to evaluate. Default: None"}
)
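HfArgumentParser derives a --task_types flag (accepting multiple values) from this List[str] field. A minimal sketch of how the evaluation script above consumes it, assuming the fields inherited from AbsEvalArgs all have defaults:

from transformers import HfArgumentParser
from utils.arguments import MTEBEvalArgs

parser = HfArgumentParser([MTEBEvalArgs])
(eval_args,) = parser.parse_args_into_dataclasses(
    args=["--task_types", "Retrieval", "Reranking"]
)
print(eval_args.task_types)  # ['Retrieval', 'Reranking']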

View File

@ -145,24 +145,28 @@ class BaseLLMReranker(AbsReranker):
trust_remote_code: bool = False,
devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
prompt: str = None,
batch_size: int = 128,
query_max_length: int = None,
max_length: int = 512,
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
model_name_or_path,
use_fp16,
query_instruction_for_rerank,
query_instruction_format,
passage_instruction_for_rerank,
passage_instruction_format,
devices,
model_name_or_path=model_name_or_path,
use_fp16=use_fp16,
query_instruction_for_rerank=query_instruction_for_rerank,
query_instruction_format=query_instruction_format,
passage_instruction_for_rerank=passage_instruction_for_rerank,
passage_instruction_format=passage_instruction_format,
devices=devices,
batch_size=batch_size,
query_max_length=query_max_length,
max_length=max_length,
normalize=normalize,
**kwargs
)
self.prompt = prompt
self.normalize = normalize
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
@ -178,9 +182,6 @@ class BaseLLMReranker(AbsReranker):
if peft_path:
self.model = PeftModel.from_pretrained(self.model,peft_path)
self.model = self.model.merge_and_unload()
self.model_name_or_path = model_name_or_path
self.cache_dir = cache_dir
self.kwargs = kwargs
self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
@ -198,6 +199,13 @@ class BaseLLMReranker(AbsReranker):
**kwargs: Any
) -> List[float]:
if prompt is None: prompt = self.prompt
if batch_size is None: batch_size = self.batch_size
if max_length is None: max_length = self.max_length
if query_max_length is None:
if self.query_max_length is not None:
query_max_length = self.query_max_length
else:
query_max_length = max_length * 3 // 4
if normalize is None: normalize = self.normalize
if device is None:
@ -206,9 +214,6 @@ class BaseLLMReranker(AbsReranker):
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
self.model.to(device)
self.model.eval()
@ -227,7 +232,7 @@ class BaseLLMReranker(AbsReranker):
queries,
return_tensors=None,
add_special_tokens=False,
max_length=max_length * 3 // 4,
max_length=query_max_length,
truncation=True,
**kwargs
)
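The only behavioral change in the tokenization itself is that the query budget is now configurable: the old code always truncated queries to max_length * 3 // 4 tokens (384 of 512 by default), whereas passing query_max_length=256 at construction or call time now reserves more of the window for the passage. The same change is applied to the layerwise and lightweight rerankers below.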

View File

@ -42,24 +42,27 @@ class LayerWiseLLMReranker(AbsReranker):
devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
cutoff_layers: List[int] = None,
prompt: str = None,
batch_size: int = 128,
query_max_length: int = None,
max_length: int = 512,
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
model_name_or_path,
use_fp16,
query_instruction_for_rerank,
query_instruction_format,
passage_instruction_for_rerank,
passage_instruction_format,
devices,
model_name_or_path=model_name_or_path,
use_fp16=use_fp16,
query_instruction_for_rerank=query_instruction_for_rerank,
query_instruction_format=query_instruction_format,
passage_instruction_for_rerank=passage_instruction_for_rerank,
passage_instruction_format=passage_instruction_format,
devices=devices,
batch_size=batch_size,
query_max_length=query_max_length,
max_length=max_length,
normalize=normalize,
**kwargs
)
self.cutoff_layers = cutoff_layers
self.prompt = prompt
self.normalize = normalize
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
@ -87,8 +90,6 @@ class LayerWiseLLMReranker(AbsReranker):
if peft_path:
self.model = PeftModel.from_pretrained(self.model,peft_path)
self.model = self.model.merge_and_unload()
self.model_name_or_path = model_name_or_path
self.cache_dir = cache_dir
@torch.no_grad()
def compute_score_single_gpu(
@ -106,6 +107,13 @@ class LayerWiseLLMReranker(AbsReranker):
) -> List[float]:
if cutoff_layers is None: cutoff_layers = self.cutoff_layers
if prompt is None: prompt = self.prompt
if batch_size is None: batch_size = self.batch_size
if max_length is None: max_length = self.max_length
if query_max_length is None:
if self.query_max_length is not None:
query_max_length = self.query_max_length
else:
query_max_length = max_length * 3 // 4
if normalize is None: normalize = self.normalize
if device is None:
@ -114,9 +122,6 @@ class LayerWiseLLMReranker(AbsReranker):
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
self.model.to(device)
self.model.eval()
@ -135,7 +140,7 @@ class LayerWiseLLMReranker(AbsReranker):
queries,
return_tensors=None,
add_special_tokens=False,
max_length=max_length * 3 // 4,
max_length=query_max_length,
truncation=True,
**kwargs
)

View File

@ -91,27 +91,28 @@ class LightweightLLMReranker(AbsReranker):
compress_layers: List[int] = [8],
compress_ratio: int = 1,
prompt: str = None,
batch_size: int = 128,
query_max_length: int = None,
max_length: int = 512,
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
model_name_or_path,
use_fp16,
query_instruction_for_rerank,
query_instruction_format,
passage_instruction_for_rerank,
passage_instruction_format,
devices,
model_name_or_path=model_name_or_path,
use_fp16=use_fp16,
query_instruction_for_rerank=query_instruction_for_rerank,
query_instruction_format=query_instruction_format,
passage_instruction_for_rerank=passage_instruction_for_rerank,
passage_instruction_format=passage_instruction_format,
devices=devices,
batch_size=batch_size,
query_max_length=query_max_length,
max_length=max_length,
normalize=normalize,
**kwargs
)
self.cutoff_layers = cutoff_layers
self.compress_layers = compress_layers
self.compress_ratio = compress_ratio
self.prompt = prompt
self.normalize = normalize
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
@ -140,10 +141,6 @@ class LightweightLLMReranker(AbsReranker):
if peft_path:
self.model = PeftModel.from_pretrained(self.model,peft_path)
self.model = self.model.merge_and_unload()
self.model_name_or_path = model_name_or_path
self.cache_dir = cache_dir
self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
@torch.no_grad()
def compute_score_single_gpu(
@ -164,6 +161,13 @@ class LightweightLLMReranker(AbsReranker):
if compress_layers is None: compress_layers = self.compress_layers
if compress_ratio is None: compress_ratio = self.compress_ratio
if prompt is None: prompt = self.prompt
if batch_size is None: batch_size = self.batch_size
if max_length is None: max_length = self.max_length
if query_max_length is None:
if self.query_max_length is not None:
query_max_length = self.query_max_length
else:
query_max_length = max_length * 3 // 4
if normalize is None: normalize = self.normalize
if device is None:
@ -172,9 +176,6 @@ class LightweightLLMReranker(AbsReranker):
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
self.model.to(device)
self.model.eval()
@ -193,7 +194,7 @@ class LightweightLLMReranker(AbsReranker):
queries,
return_tensors=None,
add_special_tokens=False,
max_length=max_length * 3 // 4,
max_length=query_max_length,
truncation=True,
**kwargs
)

View File

@ -23,36 +23,54 @@ class BaseReranker(AbsReranker):
trust_remote_code: bool = False,
cache_dir: str = None,
devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
batch_size: int = 128,
query_max_length: int = None,
max_length: int = 512,
normalize: bool = False,
**kwargs: Any,
):
super().__init__(
model_name_or_path,
use_fp16,
query_instruction_for_rerank,
query_instruction_format,
passage_instruction_for_rerank,
passage_instruction_format,
devices,
model_name_or_path=model_name_or_path,
use_fp16=use_fp16,
query_instruction_for_rerank=query_instruction_for_rerank,
query_instruction_format=query_instruction_format,
passage_instruction_for_rerank=passage_instruction_for_rerank,
passage_instruction_format=passage_instruction_format,
devices=devices,
batch_size=batch_size,
query_max_length=query_max_length,
max_length=max_length,
normalize=normalize,
**kwargs
)
self.normalize = normalize
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, cache_dir=cache_dir)
self.kwargs = kwargs
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir
)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir
)
@torch.no_grad()
def compute_score_single_gpu(
self,
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
batch_size: int = 256,
max_length: int = 512,
batch_size: int = None,
query_max_length: int = None,
max_length: int = None,
normalize: bool = None,
device: str = None,
**kwargs: Any
) -> List[float]:
if batch_size is None: batch_size = self.batch_size
if max_length is None: max_length = self.max_length
if query_max_length is None:
if self.query_max_length is not None:
query_max_length = self.query_max_length
else:
query_max_length = max_length * 3 // 4
if normalize is None: normalize = self.normalize
if device is None:
@ -61,31 +79,44 @@ class BaseReranker(AbsReranker):
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
self.model.to(device)
self.model.eval()
assert isinstance(sentence_pairs, list)
if isinstance(sentence_pairs[0], str):
sentence_pairs = [sentence_pairs]
# tokenize without padding to get the correct length
all_inputs = []
for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize"):
sentences_batch = sentence_pairs[start_index:start_index + batch_size]
inputs_batch = self.tokenizer(
sentences_batch,
queries = [s[0] for s in sentences_batch]
passages = [s[1] for s in sentences_batch]
queries_inputs_batch = self.tokenizer(
queries,
return_tensors=None,
add_special_tokens=False,
max_length=query_max_length,
truncation=True,
max_length=max_length,
**kwargs
)
inputs_batch = [{
k: inputs_batch[k][i] for k in inputs_batch.keys()
} for i in range(len(sentences_batch))]
all_inputs.extend(inputs_batch)
)['input_ids']
passages_inputs_batch = self.tokenizer(
passages,
return_tensors=None,
add_special_tokens=False,
max_length=max_length,
truncation=True,
**kwargs
)['input_ids']
for q_inp, d_inp in zip(queries_inputs_batch, passages_inputs_batch):
item = self.tokenizer.prepare_for_model(
q_inp,
d_inp,
truncation='only_second',
max_length=max_length,
padding=False,
)
all_inputs.append(item)
# sort by length for less padding
length_sorted_idx = np.argsort([-len(x['input_ids']) for x in all_inputs])
all_inputs_sorted = [all_inputs[i] for i in length_sorted_idx]
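The rewritten pre-tokenization works in two stages: queries and passages are first tokenized separately (each with its own max_length budget and no special tokens), then tokenizer.prepare_for_model joins each pair, inserting the model's special tokens and truncating with truncation='only_second' so any overflow is cut from the passage rather than the query; finally, inputs are sorted longest-first so each batch pads to a similar length. A minimal standalone sketch of the pairing step (the example strings are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')

q_ids = tokenizer('what is panda?', add_special_tokens=False,
                  max_length=256, truncation=True)['input_ids']
p_ids = tokenizer('The giant panda is a bear species endemic to China.',
                  add_special_tokens=False, max_length=512, truncation=True)['input_ids']

item = tokenizer.prepare_for_model(
    q_ids, p_ids,
    truncation='only_second',  # overflow is removed from the passage, never the query
    max_length=512,
    padding=False,
)
print(tokenizer.decode(item['input_ids']))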

View File

@ -6,6 +6,9 @@ def test_base_multi_devices():
model = FlagAutoReranker.from_finetuned(
'BAAI/bge-reranker-large',
use_fp16=True,
batch_size=128,
query_max_length=256,
max_length=512,
devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
cache_dir=os.getenv('HF_HUB_CACHE', None),
)
@ -27,4 +30,4 @@ if __name__ == '__main__':
print("--------------------------------")
print("Expected Output:")
print("[ 7.9765625 -6.859375 -7.1484375 5.44921875]")
print("[ 7.97265625 -6.8515625 -7.15625 5.45703125]")

View File

@ -6,6 +6,9 @@ def test_base_multi_devices():
model = FlagAutoReranker.from_finetuned(
'BAAI/bge-reranker-large',
use_fp16=True,
batch_size=128,
query_max_length=256,
max_length=512,
devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
cache_dir=os.getenv('HF_HUB_CACHE', None),
)
@ -27,4 +30,4 @@ if __name__ == '__main__':
print("--------------------------------")
print("Expected Output:")
print("[7.9765625, -6.859375, -7.15625, 5.44921875]")
print("[7.9765625, -6.84375, -7.15625, 5.453125]")

View File

@ -6,6 +6,9 @@ def test_base_multi_devices():
model = FlagReranker(
'BAAI/bge-reranker-large',
use_fp16=True,
batch_size=128,
query_max_length=256,
max_length=512,
devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
cache_dir=os.getenv('HF_HUB_CACHE', None),
)
@ -27,4 +30,4 @@ if __name__ == '__main__':
print("--------------------------------")
print("Expected Output:")
print("[ 7.9765625 -6.859375 -7.1484375 5.44921875]")
print("[ 7.97265625 -6.8515625 -7.15625 5.45703125]")

View File

@ -6,6 +6,9 @@ def test_base_multi_devices():
model = FlagReranker(
'BAAI/bge-reranker-large',
use_fp16=True,
batch_size=128,
query_max_length=256,
max_length=512,
devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
cache_dir=os.getenv('HF_HUB_CACHE', None),
)
@ -27,4 +30,4 @@ if __name__ == '__main__':
print("--------------------------------")
print("Expected Output:")
print("[7.9765625, -6.859375, -7.15625, 5.44921875]")
print("[7.9765625, -6.84375, -7.15625, 5.453125]")