FlagEmbedding/scripts/hn_mine.py
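
# Hard-negative mining for FlagEmbedding fine-tuning data: for every query the
# script retrieves top-ranked passages with a dense retriever and keeps those
# that are not listed as positives. A typical invocation looks like the sketch
# below (file names are illustrative, not taken from the repository):
#
#   python hn_mine.py \
#       --model_name_or_path BAAI/bge-base-en \
#       --input_file train_data.jsonl \
#       --output_file train_data_minedHN.jsonl \
#       --range_for_sampling 10-210 \
#       --negative_number 15
#
# Each input line is a JSON object such as
#   {"query": "...", "pos": ["..."], "neg": ["..."]}
# where "neg" is optional; the script rewrites every record with a mined "neg" list.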

import argparse
import json
import random

import faiss
import numpy as np
from tqdm import tqdm

from FlagEmbedding import FlagModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str)
    parser.add_argument('--input_file', default=None, type=str)
    parser.add_argument('--candidate_pool', default=None, type=str)
    parser.add_argument('--output_file', default=None, type=str)
    parser.add_argument('--range_for_sampling', default="10-210", type=str, help="range to sample negatives")
    parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu')
    parser.add_argument('--negative_number', default=15, type=int, help='the number of negatives')
    parser.add_argument('--query_instruction_for_retrieval', default="", type=str,
                        help='instruction prepended to each query before encoding')
    return parser.parse_args()


def create_index(embeddings, use_gpu):
    """Build a flat inner-product FAISS index over the passage embeddings."""
    index = faiss.IndexFlatIP(len(embeddings[0]))
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(embeddings)
    return index
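
# A minimal sketch of create_index in isolation (illustrative values only):
#
#   vecs = np.random.rand(1000, 768).astype(np.float32)
#   index = create_index(vecs, use_gpu=False)
#   scores, ids = index.search(vecs[:4], k=5)   # ids[i] are the nearest passages to vecs[i]
#
# IndexFlatIP ranks by raw inner product, so the scores coincide with cosine
# similarity only when the embeddings are L2-normalised (BGE models typically
# normalise their outputs, but verify this for the model you pass in).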


def batch_search(index,
                 query,
                 topk: int = 200,
                 batch_size: int = 64):
    """Search the index in batches and return the top-k scores and indices per query."""
    all_scores, all_inxs = [], []
    for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
        batch_query = query[start_index:start_index + batch_size]
        batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk)
        all_scores.extend(batch_scores.tolist())
        all_inxs.extend(batch_inxs.tolist())
    return all_scores, all_inxs
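
# Illustrative use of batch_search (assumes `index` and `q_vecs` exist, as in
# find_knn_neg below); it returns plain Python lists, one row per query:
#
#   scores, inxs = batch_search(index, q_vecs, topk=210)
#   # len(inxs) == len(q_vecs); inxs[i] holds up to 210 corpus indices,
#   # padded with -1 when the index contains fewer passages than topk.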


def get_corpus(candidate_pool):
    """Read candidate passages from a JSONL file with a 'text' field per line."""
    corpus = []
    for line in open(candidate_pool):
        line = json.loads(line.strip())
        corpus.append(line['text'])
    return corpus
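
# The optional --candidate_pool file is also JSONL, one passage per line, e.g.
#   {"text": "A passage that may serve as a negative for any query."}
# (the 'text' field name comes from the loader above; the file name is up to you).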


def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu):
    """Mine hard negatives for every query by dense retrieval over the corpus."""
    corpus = []
    queries = []
    train_data = []
    for line in open(input_file):
        line = json.loads(line.strip())
        train_data.append(line)
        corpus.extend(line['pos'])
        if 'neg' in line:
            corpus.extend(line['neg'])
        queries.append(line['query'])

    # Prefer an external candidate pool when one is given; otherwise the
    # deduplicated passages from the training file form the search corpus.
    if candidate_pool is not None:
        if not isinstance(candidate_pool, list):
            candidate_pool = get_corpus(candidate_pool)
        corpus = list(set(candidate_pool))
    else:
        corpus = list(set(corpus))

    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
    p_vecs = model.encode(corpus, batch_size=256)
    print(f'inferencing embedding for queries (number={len(queries)})--------------')
    q_vecs = model.encode_queries(queries, batch_size=256)

    print('create index and search------------------')
    index = create_index(p_vecs, use_gpu=use_gpu)
    _, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
    assert len(all_inxs) == len(train_data)

    for i, data in enumerate(train_data):
        query = data['query']
        inxs = all_inxs[i][sample_range[0]:sample_range[1]]
        filtered_inx = []
        for inx in inxs:
            if inx == -1:
                break
            # Drop candidates that are labelled positives or the query itself.
            if corpus[inx] not in data['pos'] and corpus[inx] != query:
                filtered_inx.append(inx)

        if len(filtered_inx) > negative_number:
            filtered_inx = random.sample(filtered_inx, negative_number)
        data['neg'] = [corpus[inx] for inx in filtered_inx]

    with open(output_file, 'w') as f:
        for data in train_data:
            # Pad with random passages when retrieval yielded too few negatives.
            if len(data['neg']) < negative_number:
                samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
                samples = [sent for sent in samples if sent not in data['pos']]
                data['neg'].extend(samples[: negative_number - len(data['neg'])])
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
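
# Note on --range_for_sampling: with the default "10-210" the top-10 retrieved
# passages are skipped (they may be unlabelled positives) and hard negatives
# are drawn from ranks 11-210. A direct call would look like this sketch
# (paths are illustrative):
#
#   find_knn_neg(model,
#                input_file="train_data.jsonl",
#                candidate_pool=None,
#                output_file="train_data_minedHN.jsonl",
#                sample_range=[10, 210],
#                negative_number=15,
#                use_gpu=False)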


if __name__ == '__main__':
    args = get_args()

    sample_range = args.range_for_sampling.split('-')
    sample_range = [int(x) for x in sample_range]

    model = FlagModel(args.model_name_or_path, query_instruction_for_retrieval=args.query_instruction_for_retrieval)

    find_knn_neg(model,
                 input_file=args.input_file,
                 candidate_pool=args.candidate_pool,
                 output_file=args.output_file,
                 sample_range=sample_range,
                 negative_number=args.negative_number,
                 use_gpu=args.use_gpu_for_searching)