mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
123 lines
4.6 KiB
Python
123 lines
4.6 KiB
Python
![]() |
import argparse
|
||
|
import json
|
||
|
import random
|
||
|
import numpy as np
|
||
|
import faiss
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from FlagEmbedding import FlagModel
|
||
|
|
||
|
|
||
|
def get_args():
    """Parse the command-line options of the hard-negative mining script.

    Returns:
        argparse.Namespace holding every option registered below.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this exact order.
    option_specs = (
        ('--model_name_or_path', {'default': "BAAI/bge-base-en", 'type': str}),
        ('--input_file', {'default': None, 'type': str}),
        ('--candidate_pool', {'default': None, 'type': str}),
        ('--output_file', {'default': None, 'type': str}),
        ('--range_for_sampling', {'default': "10-210", 'type': str, 'help': "range to sample negatives"}),
        ('--use_gpu_for_searching', {'action': 'store_true', 'help': 'use faiss-gpu'}),
        ('--negative_number', {'default': 15, 'type': int, 'help': 'the number of negatives'}),
        ('--query_instruction_for_retrieval', {'default': ""}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
||
|
|
||
|
|
||
|
def create_index(embeddings, use_gpu):
    """Build an exact inner-product FAISS index and add ``embeddings`` to it.

    Args:
        embeddings: sequence of equal-length vectors; the length of the first
            vector sets the index dimension.
        use_gpu: when true, the index is cloned onto all visible GPUs
            (sharded, float16 storage) before the vectors are added.

    Returns:
        The populated FAISS index (CPU or multi-GPU).
    """
    dimension = len(embeddings[0])
    index = faiss.IndexFlatIP(dimension)
    vectors = np.asarray(embeddings, dtype=np.float32)

    if use_gpu:
        cloner_options = faiss.GpuMultipleClonerOptions()
        cloner_options.shard = True        # split the vectors across GPUs
        cloner_options.useFloat16 = True   # halve GPU memory usage
        index = faiss.index_cpu_to_all_gpus(index, co=cloner_options)

    index.add(vectors)
    return index
|
||
|
|
||
|
|
||
|
def batch_search(index,
                 query,
                 topk: int = 200,
                 batch_size: int = 64):
    """Search ``index`` for the ``topk`` neighbours of every query vector.

    Queries are sent to ``index.search`` in chunks of ``batch_size``. A
    progress bar is shown only when there are at least 256 queries.

    Returns:
        Two parallel lists ``(scores, ids)`` with one row of length ``topk``
        per query, in query order.
    """
    scores_out, ids_out = [], []
    offsets = range(0, len(query), batch_size)
    for offset in tqdm(offsets, desc="Batches", disable=len(query) < 256):
        chunk = np.asarray(query[offset:offset + batch_size], dtype=np.float32)
        chunk_scores, chunk_ids = index.search(chunk, k=topk)
        scores_out += chunk_scores.tolist()
        ids_out += chunk_ids.tolist()
    return scores_out, ids_out
|
||
|
|
||
|
|
||
|
def get_corpus(candidate_pool):
    """Load passage texts from a JSON-lines candidate-pool file.

    Args:
        candidate_pool: path to a file where each non-blank line is a JSON
            object with a ``'text'`` field.

    Returns:
        List of the ``'text'`` values, in file order.
    """
    corpus = []
    # ``with`` closes the handle; explicit UTF-8 avoids locale-dependent decoding.
    with open(candidate_pool, encoding='utf-8') as f:
        for raw_line in f:
            raw_line = raw_line.strip()
            if not raw_line:
                continue  # tolerate blank/trailing lines instead of failing in json.loads
            corpus.append(json.loads(raw_line)['text'])
    return corpus
|
||
|
|
||
|
|
||
|
def _load_train_data(input_file):
    """Read JSON-lines training examples (UTF-8), skipping blank lines."""
    with open(input_file, encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _fill_random_negatives(data, corpus, negative_number):
    """Top up ``data['neg']`` in place with random corpus texts.

    Sampled texts never equal a positive, an already-chosen negative, or the
    query. If the corpus is too small, fewer than ``negative_number``
    negatives may remain.
    """
    missing = negative_number - len(data['neg'])
    if missing <= 0:
        return
    excluded = set(data['pos']) | set(data['neg']) | {data['query']}
    # ``corpus`` is deduplicated, so each excluded text occurs at most once.
    # Over-sampling by len(excluded) therefore leaves at least ``missing``
    # usable texts whenever the corpus is large enough. min() keeps
    # random.sample from raising ValueError on small corpora.
    k = min(len(corpus), missing + len(excluded))
    samples = [text for text in random.sample(corpus, k) if text not in excluded]
    data['neg'].extend(samples[:missing])


def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu):
    """Mine hard negatives for each training query and write the result.

    Args:
        model: encoder exposing ``encode(texts, batch_size=...)`` for passages
            and ``encode_queries(texts, batch_size=...)`` for queries.
        input_file: JSON-lines file; each example has ``'query'`` and
            ``'pos'`` and, optionally, ``'neg'``.
        candidate_pool: ``None`` to build the corpus from the training
            examples' pos/neg texts, a list of texts, or a path readable by
            ``get_corpus``.
        output_file: destination JSON-lines file (UTF-8).
        sample_range: ``[start, end]``. Neighbours ranked ``start..end-1`` are
            hard-negative candidates; ``sample_range[-1]`` neighbours are
            retrieved.
        negative_number: target number of negatives per example. Hard
            negatives are subsampled down to it, and any shortfall is filled
            with random corpus texts.
        use_gpu: forwarded to ``create_index``.
    """
    train_data = _load_train_data(input_file)
    queries = [data['query'] for data in train_data]

    if candidate_pool is not None:
        if not isinstance(candidate_pool, list):
            candidate_pool = get_corpus(candidate_pool)
        texts = candidate_pool
    else:
        texts = []
        for data in train_data:
            texts.extend(data['pos'])
            if 'neg' in data:
                texts.extend(data['neg'])
    # dict.fromkeys deduplicates while keeping first-seen order. Corpus order,
    # and so index ids, is then reproducible across runs; set() order varies
    # with string hash randomization.
    corpus = list(dict.fromkeys(texts))

    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
    p_vecs = model.encode(corpus, batch_size=256)
    print(f'inferencing embedding for queries (number={len(queries)})--------------')
    q_vecs = model.encode_queries(queries, batch_size=256)

    print('create index and search------------------')
    index = create_index(p_vecs, use_gpu=use_gpu)
    _, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
    assert len(all_inxs) == len(train_data)

    start, end = sample_range[0], sample_range[1]
    for data, inxs in zip(train_data, all_inxs):
        query = data['query']
        positives = set(data['pos'])  # O(1) membership in the inner loop
        filtered_inx = []
        for inx in inxs[start:end]:
            if inx == -1:
                break  # -1 marks the end of the retrieved neighbours
            if corpus[inx] not in positives and corpus[inx] != query:
                filtered_inx.append(inx)

        if len(filtered_inx) > negative_number:
            filtered_inx = random.sample(filtered_inx, negative_number)
        data['neg'] = [corpus[inx] for inx in filtered_inx]

    # UTF-8 explicitly: ensure_ascii=False emits raw non-ASCII characters.
    with open(output_file, 'w', encoding='utf-8') as f:
        for data in train_data:
            if len(data['neg']) < negative_number:
                _fill_random_negatives(data, corpus, negative_number)
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    args = get_args()
    # "start-end" -> [start, end]; parsed before the model loads so a
    # malformed range fails fast.
    sample_range = list(map(int, args.range_for_sampling.split('-')))

    model = FlagModel(
        args.model_name_or_path,
        query_instruction_for_retrieval=args.query_instruction_for_retrieval,
    )

    find_knn_neg(
        model,
        input_file=args.input_file,
        candidate_pool=args.candidate_pool,
        output_file=args.output_file,
        sample_range=sample_range,
        negative_number=args.negative_number,
        use_gpu=args.use_gpu_for_searching,
    )
|