import json
from typing import Optional, List
from dataclasses import dataclass, field

from transformers import HfArgumentParser

from FlagEmbedding import FlagAutoReranker

@dataclass
class ScoreArgs:
    input_file: str = field(
        default=None, metadata={"help": "The input jsonl file, each line includes query, pos and neg."}
    )
    output_file: str = field(
        default=None, metadata={"help": "The output jsonl file, it includes query, pos, neg, pos_scores and neg_scores."}
    )
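
# Illustrative input record (hypothetical values; the schema follows the help
# strings above -- each JSONL line holds one query with its positive and
# negative passages):
#   {"query": "how do rerankers work?",
#    "pos": ["A reranker jointly encodes a query-passage pair and outputs a relevance score."],
#    "neg": ["The weather in Paris is mild in spring."]}
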
@dataclass
class ModelArgs:
    use_fp16: bool = field(
        default=True, metadata={"help": "whether to use fp16 for inference"}
    )
    devices: Optional[str] = field(
        default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code"}
    )
    reranker_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The reranker name or path."}
    )
    reranker_model_class: Optional[str] = field(
        default=None, metadata={"help": "The reranker model class. Available classes: ['encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', 'decoder-only-lightweight']. Default: None. For the custom model, you need to specify the model class.", "choices": ["encoder-only-base", "decoder-only-base", "decoder-only-layerwise", "decoder-only-lightweight"]}
    )
    reranker_peft_path: Optional[str] = field(
        default=None, metadata={"help": "The reranker peft path."}
    )
    use_bf16: bool = field(
        default=False, metadata={"help": "whether to use bf16 for inference"}
    )
    query_instruction_for_rerank: Optional[str] = field(
        default=None, metadata={"help": "Instruction for query"}
    )
    query_instruction_format_for_rerank: str = field(
        default="{}{}", metadata={"help": "Format for query instruction"}
    )
    passage_instruction_for_rerank: Optional[str] = field(
        default=None, metadata={"help": "Instruction for passage"}
    )
    passage_instruction_format_for_rerank: str = field(
        default="{}{}", metadata={"help": "Format for passage instruction"}
    )
    cache_dir: str = field(
        default=None, metadata={"help": "Cache directory for models."}
    )
    # ================ for inference ===============
    reranker_batch_size: int = field(
        default=3000, metadata={"help": "Batch size for inference."}
    )
    reranker_query_max_length: Optional[int] = field(
        default=None, metadata={"help": "Max length for reranking."}
    )
    reranker_max_length: int = field(
        default=512, metadata={"help": "Max length for reranking."}
    )
    normalize: bool = field(
        default=False, metadata={"help": "whether to normalize the reranking scores"}
    )
    prompt: Optional[str] = field(
        default=None, metadata={"help": "The prompt for the reranker."}
    )
    cutoff_layers: List[int] = field(
        default=None, metadata={"help": "The output layers of layerwise/lightweight reranker."}
    )
    compress_ratio: int = field(
        default=1, metadata={"help": "The compress ratio of lightweight reranker."}
    )
    compress_layers: Optional[int] = field(
        default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
    )

def main(score_args: ScoreArgs, model_args: ModelArgs):
    """Score every (query, positive/negative passage) pair with the reranker and write the results."""
    reranker = FlagAutoReranker.from_finetuned(
        model_name_or_path=model_args.reranker_name_or_path,
        model_class=model_args.reranker_model_class,
        peft_path=model_args.reranker_peft_path,
        use_fp16=model_args.use_fp16,
        use_bf16=model_args.use_bf16,
        query_instruction_for_rerank=model_args.query_instruction_for_rerank,
        query_instruction_format=model_args.query_instruction_format_for_rerank,
        passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
        passage_instruction_format=model_args.passage_instruction_format_for_rerank,
        cache_dir=model_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
        devices=model_args.devices,
        normalize=model_args.normalize,
        prompt=model_args.prompt,
        cutoff_layers=model_args.cutoff_layers,
        compress_layers=model_args.compress_layers,
        compress_ratio=model_args.compress_ratio,
        batch_size=model_args.reranker_batch_size,
        query_max_length=model_args.reranker_query_max_length,
        max_length=model_args.reranker_max_length,
    )

    # Collect (query, passage) pairs for every positive and negative passage,
    # preserving record order so the scores can be mapped back afterwards.
    pairs = []
    data = []
    with open(score_args.input_file) as f:
        for line in f:
            data.append(json.loads(line))
            for p in data[-1]['pos']:
                pairs.append((data[-1]['query'], p))
            for p in data[-1]['neg']:
                pairs.append((data[-1]['query'], p))

    # Score all pairs in one batched call, then distribute the scores back to
    # each record in the same order they were collected.
    scores = reranker.compute_score(pairs)
    score_idx = 0
    for i in range(len(data)):
        data[i]['pos_scores'] = []
        data[i]['neg_scores'] = []
        for _ in range(len(data[i]['pos'])):
            data[i]['pos_scores'].append(float(scores[score_idx]))
            score_idx += 1
        for _ in range(len(data[i]['neg'])):
            data[i]['neg_scores'].append(float(scores[score_idx]))
            score_idx += 1

    # Write the enriched records back out as JSONL.
    with open(score_args.output_file, 'w') as f:
        for d in data:
            f.write(json.dumps(d) + '\n')

if __name__ == "__main__":
    parser = HfArgumentParser((
        ScoreArgs,
        ModelArgs
    ))
    score_args, model_args = parser.parse_args_into_dataclasses()
    score_args: ScoreArgs
    model_args: ModelArgs
    main(score_args, model_args)
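
# Example invocation (illustrative; the script name, file paths, and model name
# are placeholders -- any reranker checkpoint supported by FlagAutoReranker works):
#
#   python add_reranker_score.py \
#       --input_file ./data/train_raw.jsonl \
#       --output_file ./data/train_scored.jsonl \
#       --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
#       --reranker_query_max_length 128 \
#       --reranker_max_length 512 \
#       --devices cuda:0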