mirror of https://github.com/FlagOpen/FlagEmbedding.git, synced 2025-06-27 02:39:58 +00:00
hard negative mining

This commit is contained in:
parent 3915d173b8
commit ab4c941ce9
.gitignore (vendored): 2 changed lines

@@ -132,5 +132,3 @@ model_card.md
 # Pyre type checker
 .pyre/
-/FlagEmbedding/baai_general_embedding/hn_mine.py
-/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py
@@ -66,7 +66,25 @@ Noted that use your instruction as the value of argument `query_instruction_for_

See [examples/finetune](../../examples/finetune) for a toy data and training example.

**Hard Negatives**

Hard negative mining is a widely used method to improve the quality of sentence embeddings.
You can mine hard negatives with the following command:
```bash
python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
--model_name_or_path BAAI/bge-base-en \
--input_file toy_finetune_data.jsonl \
--output_file toy_finetune_data_minedHN.jsonl \
--range_for_sampling 2-200
```

- `input_file`: JSON data for finetuning. The script retrieves the top-k documents for each query and randomly samples negatives from them (excluding the positive documents).
- `output_file`: path to save the JSON data with mined hard negatives for finetuning.
- `range_for_sampling`: the range from which to sample negatives. For example, `2-200` means sampling negatives from the top-2 to top-200 retrieved documents.
- `candidate_pool`: the pool to retrieve from. The default is None, in which case the script retrieves from the combination of all `neg` entries in `input_file`. The format of this file is the same as the pretraining data. If a candidate pool is given, the script retrieves negatives from this file instead.
- `use_gpu_for_searching`: whether to use faiss-gpu to retrieve negatives.
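For reference, each line of `input_file` is a standalone JSON object with `query`, `pos`, and `neg` fields, which are exactly the keys the mining script reads (see `hn_mine.py` below). A minimal sketch that writes one such toy line; the strings are made-up placeholders, not real dataset content:

```python
import json

# One toy training example in the format hn_mine.py consumes:
# a query string, a list of positive passages, and a list of
# (initial) negative passages. All strings here are placeholders.
example = {
    "query": "what is hard negative mining?",
    "pos": ["Hard negative mining selects difficult negatives for contrastive finetuning."],
    "neg": ["Pandas eat bamboo for most of the day."],
}

with open("toy_finetune_data.jsonl", "w") as f:
    f.write(json.dumps(example) + "\n")
```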

#### 2. Train
```
139 FlagEmbedding/baai_general_embedding/finetune/hn_mine.py (new file)

@@ -0,0 +1,139 @@
import argparse
import json
import random

import faiss
from tqdm import tqdm

from FlagEmbedding import FlagModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str)
    parser.add_argument('--input_file', default=None, type=str)
    parser.add_argument('--candidate_pool', default=None, type=str)
    parser.add_argument('--output_file', default=None, type=str)
    parser.add_argument('--range_for_sampling', default=None, type=str, help="range to sample negatives")
    parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu')
    parser.add_argument('--negative_number', default=15, type=int, help='the number of negatives to mine per query')

    return parser.parse_args()


def create_index(embeddings, use_gpu):
    # Exact inner-product index; with normalized embeddings this is cosine similarity.
    index = faiss.IndexFlatIP(len(embeddings[0]))
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(embeddings)
    return index


def batch_search(index,
                 query,
                 topk: int = 200,
                 batch_size: int = 64):
    all_scores, all_inxs = [], []
    for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
        batch_query = query[start_index:start_index + batch_size]
        batch_scores, batch_inxs = index.search(batch_query, k=topk)
        all_scores.extend(batch_scores.tolist())
        all_inxs.extend(batch_inxs.tolist())
    return all_scores, all_inxs


def get_corpus(candidate_pool):
    corpus = []
    for line in open(candidate_pool):
        line = json.loads(line.strip())
        corpus.append(line['text'])
    return corpus


def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu):
    corpus = []
    queries = []
    train_data = []
    for line in open(input_file):
        line = json.loads(line.strip())
        train_data.append(line)
        corpus.extend(line['neg'])
        queries.append(line['query'])

    if candidate_pool is not None:
        corpus = get_corpus(candidate_pool)
    corpus = list(set(corpus))

    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
    p_vecs = model.encode(corpus, batch_size=256)
    print(f'inferencing embedding for queries (number={len(queries)})--------------')
    q_vecs = model.encode(queries, batch_size=256)

    print('create index and search------------------')
    index = create_index(p_vecs, use_gpu=use_gpu)
    _, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
    assert len(all_inxs) == len(train_data)

    for i, data in enumerate(train_data):
        query = data['query']
        inxs = all_inxs[i][sample_range[0]:sample_range[1]]
        filtered_inx = []
        for inx in inxs:
            if inx == -1:
                break
            # Skip candidates that are positives for this query, or the query itself.
            if corpus[inx] not in data['pos'] and corpus[inx] != query:
                filtered_inx.append(inx)

        if len(filtered_inx) > negative_number:
            filtered_inx = random.sample(filtered_inx, negative_number)
        data['neg'] = [corpus[inx] for inx in filtered_inx]

    with open(output_file, 'w') as f:
        for data in train_data:
            # Pad with random corpus samples when too few hard negatives were found.
            if len(data['neg']) < negative_number:
                data['neg'].extend(random.sample(corpus, negative_number - len(data['neg'])))
            f.write(json.dumps(data) + '\n')


if __name__ == '__main__':
    args = get_args()
    sample_range = args.range_for_sampling.split('-')
    sample_range = [int(x) for x in sample_range]

    model = FlagModel(args.model_name_or_path)

    find_knn_neg(model,
                 input_file=args.input_file,
                 candidate_pool=args.candidate_pool,
                 output_file=args.output_file,
                 sample_range=sample_range,
                 negative_number=args.negative_number,
                 use_gpu=args.use_gpu_for_searching)


"""
**Hard Negatives**

```bash
python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
--model_name_or_path BAAI/bge-base-en \
--input_file toy_finetune_data.jsonl \
--output_file toy_finetune_data_minedHN.jsonl \
--range_for_sampling 0-10

python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
--model_name_or_path BAAI/bge-base-zh \
--input_file /share/dataset/for_embeddings/finetune/zh/post-finetune-prompt/dureader_retrieval-data.jsonl \
--output_file /share/dataset/for_embeddings/finetune/zh/post-finetune-prompt/dureader_retrieval-data-hn.jsonl \
--range_for_sampling 2-200 \
--use_gpu_for_searching

python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
--model_name_or_path BAAI/bge-base-zh \
--input_file /share/dataset/for_embeddings/finetune/zh/post-finetune-prompt/cMedQAv2.jsonl \
--output_file /share/dataset/for_embeddings/finetune/zh/post-finetune-prompt/cMedQAv2.jsonl \
--range_for_sampling 2-200 \
--use_gpu_for_searching
```
"""
@@ -37,6 +37,26 @@ Noted that use your instruction as the value of argument `query_instruction_for_

See [toy_finetune_data.jsonl]() for a toy data file.

**Hard Negatives**

Hard negative mining is a widely used method to improve the quality of sentence embeddings.
You can mine hard negatives with the following command:
```bash
python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
--model_name_or_path BAAI/bge-base-en \
--input_file toy_finetune_data.jsonl \
--output_file toy_finetune_data_minedHN.jsonl \
--range_for_sampling 2-200
```

- `input_file`: JSON data for finetuning. The script retrieves the top-k documents for each query and randomly samples negatives from them (excluding the positive documents).
- `output_file`: path to save the JSON data with mined hard negatives for finetuning.
- `range_for_sampling`: the range from which to sample negatives. For example, `2-200` means sampling negatives from the top-2 to top-200 retrieved documents.
- `candidate_pool`: the pool to retrieve from. The default is None, in which case the script retrieves from the combination of all `neg` entries in `input_file`. The format of this file is the same as the pretraining data. If a candidate pool is given, the script retrieves negatives from this file instead.
- `use_gpu_for_searching`: whether to use faiss-gpu to retrieve negatives.
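As a quick sanity check after mining: every output line should carry exactly `negative_number` negatives, since the script downsamples when it finds too many candidates and pads with random corpus samples when it finds too few. A small sketch, assuming the toy file names above and the default `negative_number` of 15:

```python
import json

with open("toy_finetune_data_minedHN.jsonl") as f:
    for line in f:
        data = json.loads(line)
        # query and pos are unchanged; neg now holds the mined hard negatives.
        assert len(data["neg"]) == 15  # default --negative_number
        print(data["query"], "->", len(data["neg"]), "negatives")
```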

## Train
```