"""
Hard negative mining: embed the training queries and the candidate corpus,
retrieve the top-k passages for each query with FAISS, and sample negatives
from a configurable rank range.
"""
import json
import random
from dataclasses import dataclass, field
from typing import Optional, Sequence

import faiss
import numpy as np
from tqdm import tqdm
from transformers import HfArgumentParser

from FlagEmbedding import FlagAutoModel
from FlagEmbedding.abc.inference import AbsEmbedder


@dataclass
class DataArgs:
    """
    Data arguments for hard negative mining.
    """
    input_file: str = field(
        metadata={"help": "The input file for hard negative mining."}
    )
    output_file: str = field(
        metadata={"help": "The output file for hard negative mining."}
    )
    candidate_pool: Optional[str] = field(
        default=None, metadata={"help": "The candidate pool for hard negative mining. If provided, it should be a jsonl file, where each line is a dict with a key 'text'."}
    )
    range_for_sampling: str = field(
        default="10-210", metadata={"help": "The rank range to sample negatives from, e.g. '10-210'."}
    )
    negative_number: int = field(
        default=15, metadata={"help": "The number of negatives to sample per query."}
    )
    use_gpu_for_searching: bool = field(
        default=False, metadata={"help": "Whether to use faiss-gpu for searching."}
    )
    search_batch_size: int = field(
        default=64, metadata={"help": "The batch size for searching."}
    )


@dataclass
class ModelArgs:
    """
    Model arguments for the embedder.
    """
    embedder_name_or_path: str = field(
        metadata={"help": "The embedder name or path.", "required": True}
    )
    embedder_model_class: Optional[str] = field(
        default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For a custom model, you need to specify the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
    )
    normalize_embeddings: bool = field(
        default=True, metadata={"help": "Whether to normalize the embeddings."}
    )
    pooling_method: str = field(
        default="cls", metadata={"help": "The pooling method for the embedder."}
    )
    use_fp16: bool = field(
        default=True, metadata={"help": "Whether to use fp16 for inference."}
    )
    devices: Optional[str] = field(
        default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
    )
    query_instruction_for_retrieval: Optional[str] = field(
        default=None, metadata={"help": "Instruction for query."}
    )
    query_instruction_format_for_retrieval: str = field(
        default="{}{}", metadata={"help": "Format for query instruction."}
    )
    examples_for_task: Optional[str] = field(
        default=None, metadata={"help": "Examples for task."}
    )
    examples_instruction_format: str = field(
        default="{}{}", metadata={"help": "Format for examples instruction."}
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code."}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Cache directory for models."}
    )
    # ================ for inference ===============
    batch_size: int = field(
        default=3000, metadata={"help": "Batch size for inference."}
    )
    embedder_query_max_length: int = field(
        default=512, metadata={"help": "Max length for query."}
    )
    embedder_passage_max_length: int = field(
        default=512, metadata={"help": "Max length for passage."}
    )

    def __post_init__(self):
        # replace the literal "\\n" with a real newline in the instruction formats
        if "\\n" in self.query_instruction_format_for_retrieval:
            self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n")
        if "\\n" in self.examples_instruction_format:
            self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n")

def create_index(embeddings: np.ndarray, use_gpu: bool = False):
    """Build a flat inner-product FAISS index over the passage embeddings."""
    index = faiss.IndexFlatIP(len(embeddings[0]))
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(embeddings)
    return index


def batch_search(
    index: faiss.Index,
    query: np.ndarray,
    topk: int = 200,
    batch_size: int = 64
):
    """Search the index in batches and return the scores and indices of the top-k hits."""
    all_scores, all_inxs = [], []
    for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
        batch_query = query[start_index:start_index + batch_size]
        batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk)
        all_scores.extend(batch_scores.tolist())
        all_inxs.extend(batch_inxs.tolist())
    return all_scores, all_inxs


def get_corpus(candidate_pool: str):
    """Load the candidate passages from a jsonl file with a 'text' field per line."""
    corpus = []
    with open(candidate_pool, "r", encoding="utf-8") as f:
        for line in f:
            line = json.loads(line.strip())
            corpus.append(line['text'])
    return corpus


def find_knn_neg(
    model: AbsEmbedder,
    input_file: str,
    output_file: str,
    candidate_pool: Optional[str] = None,
    sample_range: Sequence[int] = (10, 210),
    negative_number: int = 15,
    search_batch_size: int = 64,
    use_gpu: bool = False
):
    """Mine hard negatives: retrieve top-ranked passages for each query and sample
    negatives from `sample_range`, excluding positives and the query itself."""
    corpus = []
    queries = []
    train_data = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            line = json.loads(line.strip())
            train_data.append(line)
            corpus.extend(line['pos'])
            if 'neg' in line:
                corpus.extend(line['neg'])
            queries.append(line['query'])

    # if a candidate pool is given, negatives are drawn from it instead of the training passages
    if candidate_pool is not None:
        if not isinstance(candidate_pool, list):
            candidate_pool = get_corpus(candidate_pool)
        corpus = list(set(candidate_pool))
    else:
        corpus = list(set(corpus))

    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
    p_vecs = model.encode(corpus)
    print(f'inferencing embedding for queries (number={len(queries)})--------------')
    q_vecs = model.encode_queries(queries)

    # check if the embeddings are in dictionary format (e.g. the M3 embedder)
    if isinstance(p_vecs, dict):
        p_vecs = p_vecs["dense_vecs"]
    if isinstance(q_vecs, dict):
        q_vecs = q_vecs["dense_vecs"]

    print('create index and search------------------')
    index = create_index(p_vecs, use_gpu=use_gpu)
    _, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1], batch_size=search_batch_size)
    assert len(all_inxs) == len(train_data)

    for i, data in enumerate(train_data):
        query = data['query']
        inxs = all_inxs[i][sample_range[0]:sample_range[1]]
        filtered_inx = []
        for inx in inxs:
            if inx == -1:
                break
            # skip passages that are positives for this query or the query itself
            if corpus[inx] not in data['pos'] and corpus[inx] != query:
                filtered_inx.append(inx)

        if len(filtered_inx) > negative_number:
            filtered_inx = random.sample(filtered_inx, negative_number)
        data['neg'] = [corpus[inx] for inx in filtered_inx]

    with open(output_file, 'w', encoding="utf-8") as f:
        for data in train_data:
            # if too few hard negatives were found, pad with random passages from the corpus
            if len(data['neg']) < negative_number:
                samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
                samples = [sent for sent in samples if sent not in data['pos']]
                data['neg'].extend(samples[: negative_number - len(data['neg'])])
            f.write(json.dumps(data, ensure_ascii=False) + '\n')

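
# Illustration (not part of the original script): the jsonl shape that
# find_knn_neg reads and writes, inferred from how the fields 'query',
# 'pos' and 'neg' are used above. The texts are placeholders.
#
#   input line : {"query": "<query text>", "pos": ["<relevant passage>"], "neg": ["<optional existing negative>"]}
#   output line: {"query": "<query text>", "pos": ["<relevant passage>"], "neg": ["<mined negative 1>", ..., "<mined negative N>"]}
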
def load_model(model_args: ModelArgs):
    """Load the embedder via FlagAutoModel according to the model arguments."""
    model = FlagAutoModel.from_finetuned(
        model_name_or_path=model_args.embedder_name_or_path,
        model_class=model_args.embedder_model_class,
        normalize_embeddings=model_args.normalize_embeddings,
        pooling_method=model_args.pooling_method,
        use_fp16=model_args.use_fp16,
        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
        query_instruction_format=model_args.query_instruction_format_for_retrieval,
        devices=model_args.devices,
        examples_for_task=model_args.examples_for_task,
        examples_instruction_format=model_args.examples_instruction_format,
        trust_remote_code=model_args.trust_remote_code,
        cache_dir=model_args.cache_dir,
        batch_size=model_args.batch_size,
        query_max_length=model_args.embedder_query_max_length,
        passage_max_length=model_args.embedder_passage_max_length,
    )
    return model


def main(data_args: DataArgs, model_args: ModelArgs):
    model = load_model(model_args)

    find_knn_neg(
        model=model,
        input_file=data_args.input_file,
        output_file=data_args.output_file,
        candidate_pool=data_args.candidate_pool,
        sample_range=[int(x) for x in data_args.range_for_sampling.split('-')],
        negative_number=data_args.negative_number,
        search_batch_size=data_args.search_batch_size,
        use_gpu=data_args.use_gpu_for_searching
    )


if __name__ == "__main__":
    parser = HfArgumentParser((
        DataArgs,
        ModelArgs
    ))
    data_args, model_args = parser.parse_args_into_dataclasses()
    data_args: DataArgs
    model_args: ModelArgs
    main(data_args, model_args)
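
# Example invocation (a sketch, not part of the original file): assumes the script
# is saved as hn_mine.py and that BAAI/bge-base-en-v1.5 is the embedder you want
# to mine with; substitute your own paths and model.
#
#   python hn_mine.py \
#       --embedder_name_or_path BAAI/bge-base-en-v1.5 \
#       --input_file train.jsonl \
#       --output_file train_minedHN.jsonl \
#       --range_for_sampling 10-210 \
#       --negative_number 15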