diff --git a/.gitignore b/.gitignore
index a8b7ed1..fe1214f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,5 +132,3 @@ model_card.md
 
 # Pyre type checker
 .pyre/
-/FlagEmbedding/baai_general_embedding/hn_mine.py
-/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py
diff --git a/FlagEmbedding/baai_general_embedding/README.md b/FlagEmbedding/baai_general_embedding/README.md
index 68cb9b9..ea6499a 100644
--- a/FlagEmbedding/baai_general_embedding/README.md
+++ b/FlagEmbedding/baai_general_embedding/README.md
@@ -66,7 +66,25 @@ Noted that use your instruction as the value of argument `query_instruction_for_retrieval`
 
 See [examples/finetune](../../examples/finetune) for a toy data and training example.
 
+**Hard Negatives**
+Hard negative mining is a widely used method to improve the quality of sentence embeddings.
+You can mine hard negatives with the following command:
+```bash
+python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
+--model_name_or_path BAAI/bge-base-en \
+--input_file toy_finetune_data.jsonl \
+--output_file toy_finetune_data_minedHN.jsonl \
+--range_for_sampling 2-200
+```
+
+- `input_file`: json data for finetuning. This script will retrieve the top-k documents for each query,
+and randomly sample negatives from the top-k documents (excluding the positive documents).
+- `output_file`: path to save the json data with mined hard negatives for finetuning.
+- `range_for_sampling`: where to sample negatives. For example, `2-200` means sampling negatives from the top-2 to top-200 retrieved documents.
+- `candidate_pool`: the pool to retrieve from. The default value is None, and this script will retrieve from the combination of all `neg` in `input_file`.
+The format of this file is the same as the pretrain data. If a candidate_pool is given, this script will retrieve negatives from this file.
+- `use_gpu_for_searching`: whether to use faiss-gpu to retrieve negatives.
 
 #### 2. Train
 
 ```
diff --git a/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py b/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py
new file mode 100644
index 0000000..9b1d7d9
--- /dev/null
+++ b/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py
@@ -0,0 +1,113 @@
+import argparse
+import json
+import random
+
+import faiss
+from tqdm import tqdm
+
+from FlagEmbedding import FlagModel
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str)
+    parser.add_argument('--input_file', default=None, type=str)
+    parser.add_argument('--candidate_pool', default=None, type=str)
+    parser.add_argument('--output_file', default=None, type=str)
+    parser.add_argument('--range_for_sampling', default=None, type=str, help="range to sample negatives")
+    parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu')
+    parser.add_argument('--negative_number', default=15, type=int, help='number of hard negatives to mine for each query')
+
+    return parser.parse_args()
+
+
+def create_index(embeddings, use_gpu):
+    index = faiss.IndexFlatIP(len(embeddings[0]))  # exact inner-product search
+    if use_gpu:
+        co = faiss.GpuMultipleClonerOptions()
+        co.shard = True  # shard the index across all available GPUs
+        co.useFloat16 = True
+        index = faiss.index_cpu_to_all_gpus(index, co=co)
+    index.add(embeddings)
+    return index
+
+
+def batch_search(index,
+                 query,
+                 topk: int = 200,
+                 batch_size: int = 64):
+    all_scores, all_inxs = [], []
+    for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
+        batch_query = query[start_index:start_index + batch_size]
+        batch_scores, batch_inxs = index.search(batch_query, k=topk)
+        all_scores.extend(batch_scores.tolist())
+        all_inxs.extend(batch_inxs.tolist())
+    return all_scores, all_inxs
+
+
+def get_corpus(candidate_pool):
+    corpus = []
+    for line in open(candidate_pool):
+        line = json.loads(line.strip())
+        corpus.append(line['text'])
+    return corpus
+
+
+def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu):
+    corpus = []
+    queries = []
+    train_data = []
+    for line in open(input_file):
+        line = json.loads(line.strip())
+        train_data.append(line)
+        corpus.extend(line['neg'])
+        queries.append(line['query'])
+
+    if candidate_pool is not None:
+        corpus = get_corpus(candidate_pool)
+    corpus = list(set(corpus))
+
+    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
+    p_vecs = model.encode(corpus, batch_size=256)
+    print(f'inferencing embedding for queries (number={len(queries)})--------------')
+    q_vecs = model.encode(queries, batch_size=256)
+
+    print('create index and search------------------')
+    index = create_index(p_vecs, use_gpu=use_gpu)
+    _, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
+    assert len(all_inxs) == len(train_data)
+
+    for i, data in enumerate(train_data):
+        query = data['query']
+        inxs = all_inxs[i][sample_range[0]:sample_range[1]]
+        filtered_inx = []
+        for inx in inxs:
+            if inx == -1: break  # faiss pads with -1 when fewer than topk results exist
+            if corpus[inx] not in data['pos'] and corpus[inx] != query:
+                filtered_inx.append(inx)
+
+        if len(filtered_inx) > negative_number:
+            filtered_inx = random.sample(filtered_inx, negative_number)
+        data['neg'] = [corpus[inx] for inx in filtered_inx]
+
+    with open(output_file, 'w') as f:
+        for data in train_data:
+            if len(data['neg']) < negative_number:
+                data['neg'].extend(random.sample(corpus, negative_number - len(data['neg'])))  # pad with random negatives
+            f.write(json.dumps(data) + '\n')
+
+
+if __name__ == '__main__':
+    args = get_args()
+    sample_range = args.range_for_sampling.split('-')
+    sample_range = [int(x) for x in sample_range]
+
+    model = FlagModel(args.model_name_or_path)
+
+    find_knn_neg(model,
+                 input_file=args.input_file,
+                 candidate_pool=args.candidate_pool,
+                 output_file=args.output_file,
+                 sample_range=sample_range,
+                 negative_number=args.negative_number,
+                 use_gpu=args.use_gpu_for_searching)
diff --git a/examples/finetune/README.md b/examples/finetune/README.md
index 98d00fc..dbd6f32 100644
--- a/examples/finetune/README.md
+++ b/examples/finetune/README.md
@@ -37,6 +37,26 @@ Noted that use your instruction as the value of argument `query_instruction_for_retrieval`
 
 See [toy_finetune_data.jsonl]() for a toy data file.
 
+**Hard Negatives**
+
+Hard negative mining is a widely used method to improve the quality of sentence embeddings.
+You can mine hard negatives with the following command:
+```bash
+python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
+--model_name_or_path BAAI/bge-base-en \
+--input_file toy_finetune_data.jsonl \
+--output_file toy_finetune_data_minedHN.jsonl \
+--range_for_sampling 2-200
+```
+
+- `input_file`: json data for finetuning. This script will retrieve the top-k documents for each query,
+and randomly sample negatives from the top-k documents (excluding the positive documents).
+- `output_file`: path to save the json data with mined hard negatives for finetuning.
+- `range_for_sampling`: where to sample negatives. For example, `2-200` means sampling negatives from the top-2 to top-200 retrieved documents.
+- `candidate_pool`: the pool to retrieve from. The default value is None, and this script will retrieve from the combination of all `neg` in `input_file`.
+The format of this file is the same as the pretrain data. If a candidate_pool is given, this script will retrieve negatives from this file.
+- `use_gpu_for_searching`: whether to use faiss-gpu to retrieve negatives.
+
 ## Train
 
 ```
diff --git a/setup.py b/setup.py
index 759afea..6fedc60 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("README.md", mode="r", encoding="utf-8") as readme_file:
 
 setup(
     name='FlagEmbedding',
-    version='1.0.4',
+    version='1.0.5',
     description='FlagEmbedding',
     long_description=readme,
     long_description_content_type="text/markdown",
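
For reference, `hn_mine` expects each line of `input_file` to be a JSON object with `query`, `pos`, and `neg` fields (the schema `find_knn_neg` reads above), and the output keeps the same schema with `neg` replaced by the mined hard negatives. A minimal sketch of building a toy input and inspecting the result, assuming the file names from the README example:

```python
import json

# One training example per line, in the query/pos/neg schema read by find_knn_neg.
examples = [
    {"query": "what is faiss?",
     "pos": ["Faiss is a library for efficient similarity search of dense vectors."],
     "neg": ["Paris is the capital of France."]},
]
with open("toy_finetune_data.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")

# After running hn_mine, each line keeps the same fields; `neg` now holds
# up to `negative_number` (default 15) mined hard negatives.
with open("toy_finetune_data_minedHN.jsonl") as f:
    for line in f:
        example = json.loads(line)
        print(example["query"], "->", len(example["neg"]), "negatives")
```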
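Likewise, when `--candidate_pool` is supplied, `get_corpus` reads one JSON object per line with a single `text` field. A minimal sketch of building such a pool; the file name and passages are placeholders:

```python
import json

# Each line is {"text": ...}, the format get_corpus expects
# (the same format as the pretrain data).
passages = [
    "Faiss is a library for efficient similarity search of dense vectors.",
    "BGE maps sentences to dense embeddings for retrieval.",
]
with open("candidate_pool.jsonl", "w") as f:
    for text in passages:
        f.write(json.dumps({"text": text}) + "\n")
```

Passing `--candidate_pool candidate_pool.jsonl` makes retrieval run over this pool instead of the union of the existing `neg` fields in `input_file`.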