mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
145 lines
5.5 KiB
Python
145 lines
5.5 KiB
Python
import json
|
|
from typing import Optional, List
|
|
|
|
from dataclasses import dataclass, field
|
|
from transformers import HfArgumentParser
|
|
from FlagEmbedding import FlagAutoReranker
|
|
|
|
|
|
@dataclass
class ScoreArgs:
    """Input/output file arguments for the reranker scoring script.

    Both paths default to ``None``, so the argument parser treats them as
    optional flags; they must nevertheless be supplied for scoring to run.
    """

    # JSON Lines input: one record per line with "query", "pos" and "neg".
    input_file: Optional[str] = field(
        default=None, metadata={"help": "The input jsonl file, each line includes query, pos and neg."}
    )
    # JSON Lines output: the input records extended with "pos_scores" and "neg_scores".
    output_file: Optional[str] = field(
        default=None, metadata={"help": "The output jsonl file, it includes query, pos, neg, pos_scores and neg_scores."}
    )
|
|
|
|
|
|
@dataclass
class ModelArgs:
    """Reranker construction and inference options.

    Every field is forwarded to ``FlagAutoReranker.from_finetuned`` by the
    scoring entry point. Options that accept several values on the command
    line (``devices``, ``cutoff_layers``, ``compress_layers``) parse to lists,
    so they are annotated as ``Optional[List[...]]``.
    """

    # ================ model loading ================
    use_fp16: bool = field(
        default=True, metadata={"help": "whether to use fp16 for inference"}
    )
    # One or more device names (nargs="+"), parsed into a list of strings.
    devices: Optional[List[str]] = field(
        default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code"}
    )
    reranker_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The reranker name or path."}
    )
    # Restricted to the listed choices by the argument parser.
    reranker_model_class: Optional[str] = field(
        default=None, metadata={"help": "The reranker model class. Available classes: ['encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', 'decoder-only-lightweight']. Default: None. For the custom model, you need to specify the model class.", "choices": ["encoder-only-base", "decoder-only-base", "decoder-only-layerwise", "decoder-only-lightweight"]}
    )
    reranker_peft_path: Optional[str] = field(
        default=None, metadata={"help": "The reranker peft path."}
    )
    use_bf16: bool = field(
        default=False, metadata={"help": "whether to use bf16 for inference"}
    )
    query_instruction_for_rerank: Optional[str] = field(
        default=None, metadata={"help": "Instruction for query"}
    )
    # Format string combining instruction and query; default concatenates them.
    query_instruction_format_for_rerank: str = field(
        default="{}{}", metadata={"help": "Format for query instruction"}
    )
    passage_instruction_for_rerank: Optional[str] = field(
        default=None, metadata={"help": "Instruction for passage"}
    )
    # Format string combining instruction and passage; default concatenates them.
    passage_instruction_format_for_rerank: str = field(
        default="{}{}", metadata={"help": "Format for passage instruction"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Cache directory for models."}
    )

    # ================ for inference ===============
    reranker_batch_size: int = field(
        default=3000, metadata={"help": "Batch size for inference."}
    )
    # Passed as ``query_max_length``; None leaves the choice to the reranker.
    reranker_query_max_length: Optional[int] = field(
        default=None, metadata={"help": "Max length of the query for reranking."}
    )
    # Passed as ``max_length`` (the full query/passage input length).
    reranker_max_length: int = field(
        default=512, metadata={"help": "Max length for reranking."}
    )
    normalize: bool = field(
        default=False, metadata={"help": "whether to normalize the reranking scores"}
    )
    prompt: Optional[str] = field(
        default=None, metadata={"help": "The prompt for the reranker."}
    )
    # List type => parsed with nargs="+" as integers.
    cutoff_layers: Optional[List[int]] = field(
        default=None, metadata={"help": "The output layers of layerwise/lightweight reranker."}
    )
    compress_ratio: int = field(
        default=1, metadata={"help": "The compress ratio of lightweight reranker."}
    )
    # One or more integers (nargs="+"), parsed into a list.
    compress_layers: Optional[List[int]] = field(
        default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
    )
|
|
|
|
|
|
def _load_reranker(model_args: ModelArgs):
    """Instantiate the reranker described by ``model_args``."""
    return FlagAutoReranker.from_finetuned(
        model_name_or_path=model_args.reranker_name_or_path,
        model_class=model_args.reranker_model_class,
        peft_path=model_args.reranker_peft_path,
        use_fp16=model_args.use_fp16,
        use_bf16=model_args.use_bf16,
        query_instruction_for_rerank=model_args.query_instruction_for_rerank,
        query_instruction_format=model_args.query_instruction_format_for_rerank,
        passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
        passage_instruction_format=model_args.passage_instruction_format_for_rerank,
        cache_dir=model_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
        devices=model_args.devices,
        normalize=model_args.normalize,
        prompt=model_args.prompt,
        cutoff_layers=model_args.cutoff_layers,
        compress_layers=model_args.compress_layers,
        compress_ratio=model_args.compress_ratio,
        batch_size=model_args.reranker_batch_size,
        query_max_length=model_args.reranker_query_max_length,
        max_length=model_args.reranker_max_length,
    )


def _read_jsonl(path: str) -> List[dict]:
    """Load a UTF-8 JSON Lines file as a list of records, ignoring blank lines."""
    records = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            if line.strip():  # tolerate blank / whitespace-only lines
                records.append(json.loads(line))
    return records


def _build_pairs(data: List[dict]) -> List[tuple]:
    """Flatten records into (query, passage) pairs.

    For each record, all ``pos`` passages come first, then all ``neg``
    passages. ``_attach_scores`` consumes scores in exactly this order.
    """
    pairs = []
    for item in data:
        query = item['query']
        pairs.extend((query, passage) for passage in item['pos'])
        pairs.extend((query, passage) for passage in item['neg'])
    return pairs


def _to_float_list(scores) -> List[float]:
    """Normalize the reranker's output to a flat list of Python floats."""
    try:
        values = list(scores)
    except TypeError:
        # Not iterable: treat it as one score. Defensive — presumably a reranker
        # may return a bare scalar when given a single pair; TODO confirm.
        values = [scores]
    return [float(v) for v in values]


def _attach_scores(data: List[dict], scores: List[float]) -> None:
    """Store ``pos_scores`` / ``neg_scores`` on each record, in place.

    ``scores`` must be ordered as produced by ``_build_pairs``.
    """
    offset = 0
    for item in data:
        n_pos = len(item['pos'])
        item['pos_scores'] = scores[offset:offset + n_pos]
        offset += n_pos
        n_neg = len(item['neg'])
        item['neg_scores'] = scores[offset:offset + n_neg]
        offset += n_neg


def _write_jsonl(path: str, data: List[dict]) -> None:
    """Write records as JSON Lines, one ``json.dumps`` object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        for d in data:
            f.write(json.dumps(d) + '\n')


def main(score_args: ScoreArgs, model_args: ModelArgs):
    """Score every (query, passage) pair of the input file with a reranker.

    Each input record is written back unchanged except for two new keys,
    ``pos_scores`` and ``neg_scores``, aligned index-by-index with ``pos``
    and ``neg``.

    Args:
        score_args: input/output JSON Lines paths.
        model_args: options used to build the reranker.

    Raises:
        ValueError: if an input/output path is missing, or if the reranker
            returns a different number of scores than pairs submitted.
    """
    if not score_args.input_file or not score_args.output_file:
        raise ValueError("Both --input_file and --output_file must be set.")

    # Read the data before loading the model so a bad input fails fast.
    data = _read_jsonl(score_args.input_file)
    pairs = _build_pairs(data)

    scores: List[float] = []
    if pairs:  # nothing to score: skip model loading and inference entirely
        reranker = _load_reranker(model_args)
        scores = _to_float_list(reranker.compute_score(pairs))

    # Guard against silently writing scores next to the wrong passages.
    if len(scores) != len(pairs):
        raise ValueError(
            f"Reranker returned {len(scores)} scores for {len(pairs)} pairs; "
            "cannot align scores with passages."
        )

    _attach_scores(data, scores)
    _write_jsonl(score_args.output_file, data)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: map CLI flags onto the two argument dataclasses,
    # then hand them to the scoring routine.
    cli_parser = HfArgumentParser((ScoreArgs, ModelArgs))
    score_args: ScoreArgs
    model_args: ModelArgs
    score_args, model_args = cli_parser.parse_args_into_dataclasses()
    main(score_args, model_args)
|