"""Mine hard negatives for retrieval fine-tuning: embed the corpus and queries
with a FlagModel, search a flat inner-product FAISS index, and sample negatives
for each query from a configurable rank window of the retrieval results."""
import argparse
import json
import random

import faiss
import numpy as np
from tqdm import tqdm

from FlagEmbedding import FlagModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str)
    parser.add_argument('--input_file', default=None, type=str,
                        help="JSONL training file; each line has 'query', 'pos', and optionally 'neg'")
    parser.add_argument('--candidate_pool', default=None, type=str,
                        help="optional JSONL corpus file; each line has a 'text' field")
    parser.add_argument('--output_file', default=None, type=str)
    parser.add_argument('--range_for_sampling', default="10-210", type=str,
                        help="rank range to sample negatives from, e.g. '10-210'")
    parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu')
    parser.add_argument('--negative_number', default=15, type=int, help='the number of negatives')
    parser.add_argument('--query_instruction_for_retrieval', default="")
    return parser.parse_args()


def create_index(embeddings, use_gpu):
    """Build a flat inner-product FAISS index, optionally sharded across all GPUs."""
    index = faiss.IndexFlatIP(len(embeddings[0]))
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(embeddings)
    return index


def batch_search(index, query, topk: int = 200, batch_size: int = 64):
    """Search the index in batches; return per-query score and index lists."""
    all_scores, all_inxs = [], []
    for start_index in tqdm(range(0, len(query), batch_size), desc="Batches",
                            disable=len(query) < 256):
        batch_query = query[start_index:start_index + batch_size]
        batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk)
        all_scores.extend(batch_scores.tolist())
        all_inxs.extend(batch_inxs.tolist())
    return all_scores, all_inxs


def get_corpus(candidate_pool):
    """Load candidate passages from a JSONL file with a 'text' field per line."""
    corpus = []
    with open(candidate_pool) as f:
        for line in f:
            line = json.loads(line.strip())
            corpus.append(line['text'])
    return corpus


def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range,
                 negative_number, use_gpu):
    """Retrieve top passages for each query and sample hard negatives from the
    given rank window, excluding the query's positives and the query itself."""
    corpus = []
    queries = []
    train_data = []
    with open(input_file) as f:
        for line in f:
            line = json.loads(line.strip())
            train_data.append(line)
            corpus.extend(line['pos'])
            if 'neg' in line:
                corpus.extend(line['neg'])
            queries.append(line['query'])

    # Use the external candidate pool as the corpus if provided; otherwise the
    # deduplicated union of all positives and existing negatives.
    if candidate_pool is not None:
        if not isinstance(candidate_pool, list):
            candidate_pool = get_corpus(candidate_pool)
        corpus = list(set(candidate_pool))
    else:
        corpus = list(set(corpus))

    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
    p_vecs = model.encode(corpus, batch_size=256)
    print(f'inferencing embedding for queries (number={len(queries)})--------------')
    q_vecs = model.encode_queries(queries, batch_size=256)

    print('create index and search------------------')
    index = create_index(p_vecs, use_gpu=use_gpu)
    _, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
    assert len(all_inxs) == len(train_data)

    for i, data in enumerate(train_data):
        query = data['query']
        inxs = all_inxs[i][sample_range[0]:sample_range[1]]
        filtered_inx = []
        for inx in inxs:
            if inx == -1:
                break  # FAISS pads with -1 when fewer than topk results exist
            if corpus[inx] not in data['pos'] and corpus[inx] != query:
                filtered_inx.append(inx)

        if len(filtered_inx) > negative_number:
            filtered_inx = random.sample(filtered_inx, negative_number)
        data['neg'] = [corpus[inx] for inx in filtered_inx]

    with open(output_file, 'w') as f:
        for data in train_data:
            if len(data['neg']) < negative_number:
                # Top up with random corpus samples; oversample by len(pos) so
                # enough candidates remain after filtering out positives.
                samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
                samples = [sent for sent in samples if sent not in data['pos']]
                data['neg'].extend(samples[:negative_number - len(data['neg'])])
            f.write(json.dumps(data, ensure_ascii=False) + '\n')


if __name__ == '__main__':
    args = get_args()
    sample_range = [int(x) for x in args.range_for_sampling.split('-')]

    model = FlagModel(args.model_name_or_path,
                      query_instruction_for_retrieval=args.query_instruction_for_retrieval)

    find_knn_neg(model,
                 input_file=args.input_file,
                 candidate_pool=args.candidate_pool,
                 output_file=args.output_file,
                 sample_range=sample_range,
                 negative_number=args.negative_number,
                 use_gpu=args.use_gpu_for_searching)
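
# Example invocation (the script and file names below are illustrative
# placeholders, not files shipped with this repo). Mines up to 15 negatives
# per query, sampled from retrieval ranks 10-210:
#
#   python hn_mine.py \
#       --model_name_or_path BAAI/bge-base-en \
#       --input_file train.jsonl \
#       --candidate_pool candidates.jsonl \
#       --output_file train_with_hn.jsonl \
#       --range_for_sampling 10-210 \
#       --negative_number 15
#
# Each input line is expected to look like:
#   {"query": "...", "pos": ["relevant passage"], "neg": ["optional existing negative"]}
# The output rewrites the "neg" field of every example with the mined negatives.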