FlagEmbedding/scripts/hn_mine.py
2025-02-13 22:41:54 +08:00

247 lines
8.9 KiB
Python

import json
import random
import numpy as np
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field
import faiss
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoModel
from FlagEmbedding.abc.inference import AbsEmbedder
@dataclass
class DataArgs:
"""
Data arguments for hard negative mining.
"""
input_file: str = field(
metadata={"help": "The input file for hard negative mining."}
)
output_file: str = field(
metadata={"help": "The output file for hard negative mining."}
)
candidate_pool: Optional[str] = field(
default=None, metadata={"help": "The candidate pool for hard negative mining. If provided, it should be a jsonl file, each line is a dict with a key 'text'."}
)
range_for_sampling: str = field(
default="10-210", metadata={"help": "The range to sample negatives."}
)
negative_number: int = field(
default=15, metadata={"help": "The number of negatives."}
)
use_gpu_for_searching: bool = field(
default=False, metadata={"help": "Whether to use faiss-gpu for searching."}
)
search_batch_size: int = field(
default=64, metadata={"help": "The batch size for searching."}
)
@dataclass
class ModelArgs:
"""
Model arguments for embedder.
"""
embedder_name_or_path: str = field(
metadata={"help": "The embedder name or path.", "required": True}
)
embedder_model_class: Optional[str] = field(
default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For the custom model, you need to specifiy the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
)
normalize_embeddings: bool = field(
default=True, metadata={"help": "whether to normalize the embeddings"}
)
pooling_method: str = field(
default="cls", metadata={"help": "The pooling method fot the embedder."}
)
use_fp16: bool = field(
default=True, metadata={"help": "whether to use fp16 for inference"}
)
devices: Optional[str] = field(
default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
)
query_instruction_for_retrieval: Optional[str] = field(
default=None, metadata={"help": "Instruction for query"}
)
query_instruction_format_for_retrieval: str = field(
default="{}{}", metadata={"help": "Format for query instruction"}
)
examples_for_task: Optional[str] = field(
default=None, metadata={"help": "Examples for task"}
)
examples_instruction_format: str = field(
default="{}{}", metadata={"help": "Format for examples instruction"}
)
trust_remote_code: bool = field(
default=False, metadata={"help": "Trust remote code"}
)
cache_dir: str = field(
default=None, metadata={"help": "Cache directory for models."}
)
# ================ for inference ===============
batch_size: int = field(
default=3000, metadata={"help": "Batch size for inference."}
)
embedder_query_max_length: int = field(
default=512, metadata={"help": "Max length for query."}
)
embedder_passage_max_length: int = field(
default=512, metadata={"help": "Max length for passage."}
)
def __post_init__(self):
# replace "\\n" with "\n"
if "\\n" in self.query_instruction_format_for_retrieval:
self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n")
if "\\n" in self.examples_instruction_format:
self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n")
def create_index(embeddings: np.ndarray, use_gpu: bool = False):
index = faiss.IndexFlatIP(len(embeddings[0]))
embeddings = np.asarray(embeddings, dtype=np.float32)
if use_gpu:
co = faiss.GpuMultipleClonerOptions()
co.shard = True
co.useFloat16 = True
index = faiss.index_cpu_to_all_gpus(index, co=co)
index.add(embeddings)
return index
def batch_search(
index: faiss.Index,
query: np.ndarray,
topk: int = 200,
batch_size: int = 64
):
all_scores, all_inxs = [], []
for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
batch_query = query[start_index:start_index + batch_size]
batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk)
all_scores.extend(batch_scores.tolist())
all_inxs.extend(batch_inxs.tolist())
return all_scores, all_inxs
def get_corpus(candidate_pool: str):
corpus = []
with open(candidate_pool, "r", encoding="utf-8") as f:
for line in f.readlines():
line = json.loads(line.strip())
corpus.append(line['text'])
return corpus
def find_knn_neg(
model: AbsEmbedder,
input_file: str,
output_file: str,
candidate_pool: Optional[str] = None,
sample_range: str = "10-210",
negative_number: int = 15,
use_gpu: bool = False
):
corpus = []
queries = []
train_data = []
for line in open(input_file):
line = json.loads(line.strip())
train_data.append(line)
corpus.extend(line['pos'])
if 'neg' in line:
corpus.extend(line['neg'])
queries.append(line['query'])
if candidate_pool is not None:
if not isinstance(candidate_pool, list):
candidate_pool = get_corpus(candidate_pool)
corpus = list(set(candidate_pool))
else:
corpus = list(set(corpus))
print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
p_vecs = model.encode(corpus)
print(f'inferencing embedding for queries (number={len(queries)})--------------')
q_vecs = model.encode_queries(queries)
# check if the embeddings are in dictionary format: M3Embedder
if isinstance(p_vecs, dict):
p_vecs = p_vecs["dense_vecs"]
if isinstance(q_vecs, dict):
q_vecs = q_vecs["dense_vecs"]
print('create index and search------------------')
index = create_index(p_vecs, use_gpu=use_gpu)
_, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
assert len(all_inxs) == len(train_data)
for i, data in enumerate(train_data):
query = data['query']
inxs = all_inxs[i][sample_range[0]:sample_range[1]]
filtered_inx = []
for inx in inxs:
if inx == -1: break
if corpus[inx] not in data['pos'] and corpus[inx] != query:
filtered_inx.append(inx)
if len(filtered_inx) > negative_number:
filtered_inx = random.sample(filtered_inx, negative_number)
data['neg'] = [corpus[inx] for inx in filtered_inx]
with open(output_file, 'w') as f:
for data in train_data:
if len(data['neg']) < negative_number:
samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
samples = [sent for sent in samples if sent not in data['pos']]
data['neg'].extend(samples[: negative_number - len(data['neg'])])
f.write(json.dumps(data, ensure_ascii=False) + '\n')
def load_model(model_args: ModelArgs):
model = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
model_class=model_args.embedder_model_class,
normalize_embeddings=model_args.normalize_embeddings,
pooling_method=model_args.pooling_method,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir,
batch_size=model_args.batch_size,
query_max_length=model_args.embedder_query_max_length,
passage_max_length=model_args.embedder_passage_max_length,
)
return model
def main(data_args: DataArgs, model_args: ModelArgs):
model = load_model(model_args)
find_knn_neg(
model=model,
input_file=data_args.input_file,
output_file=data_args.output_file,
candidate_pool=data_args.candidate_pool,
sample_range=[int(x) for x in data_args.range_for_sampling.split('-')],
negative_number=data_args.negative_number,
use_gpu=data_args.use_gpu_for_searching
)
if __name__ == "__main__":
parser = HfArgumentParser((
DataArgs,
ModelArgs
))
data_args, model_args = parser.parse_args_into_dataclasses()
data_args: DataArgs
model_args: ModelArgs
main(data_args, model_args)