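"""
Mine hard negatives for embedding-model fine-tuning: encode the queries and a
candidate corpus, retrieve the top-ranked passages for each query with FAISS,
and sample negatives from a configurable rank window.

Illustrative invocation (a sketch: the script name, model, and paths are
placeholders; the flag names follow the dataclass fields below, as generated
by HfArgumentParser):

    python hn_mine.py \
        --embedder_name_or_path BAAI/bge-base-en-v1.5 \
        --input_file toy_data.jsonl \
        --output_file toy_data_minedHN.jsonl \
        --range_for_sampling 10-210 \
        --negative_number 15
"""
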
import json
import random
import numpy as np
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field

import faiss
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoModel
from FlagEmbedding.abc.inference import AbsEmbedder


@dataclass
class DataArgs:
    """
    Data arguments for hard negative mining.
    """
    input_file: str = field(
        metadata={"help": "The input file for hard negative mining."}
    )
    output_file: str = field(
        metadata={"help": "The output file for hard negative mining."}
    )
    candidate_pool: Optional[str] = field(
        default=None, metadata={"help": "The candidate pool for hard negative mining. If provided, it should be a jsonl file, where each line is a dict with a key 'text'."}
    )
    range_for_sampling: str = field(
        default="10-210", metadata={"help": "The rank range to sample negatives from."}
    )
    negative_number: int = field(
        default=15, metadata={"help": "The number of negatives to mine per query."}
    )
    use_gpu_for_searching: bool = field(
        default=False, metadata={"help": "Whether to use faiss-gpu for searching."}
    )
    search_batch_size: int = field(
        default=64, metadata={"help": "The batch size for searching."}
    )


@dataclass
class ModelArgs:
    """
    Model arguments for the embedder.
    """
    embedder_name_or_path: str = field(
        metadata={"help": "The embedder name or path.", "required": True}
    )
    embedder_model_class: Optional[str] = field(
        default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For a custom model, you need to specify the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
    )
    normalize_embeddings: bool = field(
        default=True, metadata={"help": "Whether to normalize the embeddings."}
    )
    pooling_method: str = field(
        default="cls", metadata={"help": "The pooling method for the embedder."}
    )
    use_fp16: bool = field(
        default=True, metadata={"help": "Whether to use fp16 for inference."}
    )
    devices: Optional[str] = field(
        default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
    )
    query_instruction_for_retrieval: Optional[str] = field(
        default=None, metadata={"help": "Instruction for queries."}
    )
    query_instruction_format_for_retrieval: str = field(
        default="{}{}", metadata={"help": "Format for the query instruction."}
    )
    examples_for_task: Optional[str] = field(
        default=None, metadata={"help": "Examples for the task."}
    )
    examples_instruction_format: str = field(
        default="{}{}", metadata={"help": "Format for the examples instruction."}
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Whether to trust remote code."}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Cache directory for models."}
    )
    # ================ for inference ===============
    batch_size: int = field(
        default=3000, metadata={"help": "Batch size for inference."}
    )
    embedder_query_max_length: int = field(
        default=512, metadata={"help": "Max length for queries."}
    )
    embedder_passage_max_length: int = field(
        default=512, metadata={"help": "Max length for passages."}
    )

    def __post_init__(self):
        # replace "\\n" with "\n"
        if "\\n" in self.query_instruction_format_for_retrieval:
            self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n")
        if "\\n" in self.examples_instruction_format:
            self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n")
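
    # How these format strings are consumed is up to the FlagEmbedding model
    # classes; as a reading of the library (an assumption, not enforced here),
    # the query text is built roughly as:
    #   query_instruction_format_for_retrieval.format(query_instruction_for_retrieval, query)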


def create_index(embeddings: np.ndarray, use_gpu: bool = False):
    # flat inner-product index; with normalized embeddings this is cosine similarity
    index = faiss.IndexFlatIP(len(embeddings[0]))
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if use_gpu:
        # shard the index across all visible GPUs and use fp16 to save memory
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(embeddings)
    return index


def batch_search(
    index: faiss.Index,
    query: np.ndarray,
    topk: int = 200,
    batch_size: int = 64
):
    all_scores, all_inxs = [], []
    for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
        batch_query = query[start_index:start_index + batch_size]
        batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk)
        all_scores.extend(batch_scores.tolist())
        all_inxs.extend(batch_inxs.tolist())
    return all_scores, all_inxs
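

# A minimal sanity check for the two helpers above (illustrative only: the
# vectors are random and the dimension is made up):
#
#   p_vecs = np.random.rand(1000, 128).astype(np.float32)
#   q_vecs = np.random.rand(8, 128).astype(np.float32)
#   index = create_index(p_vecs)
#   scores, inxs = batch_search(index, q_vecs, topk=10)
#   # -> 8 score lists and 8 index lists, each of length 10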


def get_corpus(candidate_pool: str):
    corpus = []
    with open(candidate_pool, "r", encoding="utf-8") as f:
        for line in f.readlines():
            line = json.loads(line.strip())
            corpus.append(line['text'])
    return corpus
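

# Each line of the candidate pool file is a JSON dict with a single required
# "text" key (per the --candidate_pool help text; the sentence here is a
# made-up example):
#
#   {"text": "FAISS builds indexes for fast similarity search over dense vectors."}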


def find_knn_neg(
    model: AbsEmbedder,
    input_file: str,
    output_file: str,
    candidate_pool: Optional[str] = None,
    sample_range: tuple = (10, 210),
    negative_number: int = 15,
    use_gpu: bool = False
):
    corpus = []
    queries = []
    train_data = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            line = json.loads(line.strip())
            train_data.append(line)
            corpus.extend(line['pos'])
            if 'neg' in line:
                corpus.extend(line['neg'])
            queries.append(line['query'])

    if candidate_pool is not None:
        if not isinstance(candidate_pool, list):
            candidate_pool = get_corpus(candidate_pool)
        corpus = list(set(candidate_pool))
    else:
        corpus = list(set(corpus))

    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
    p_vecs = model.encode(corpus)

    print(f'inferencing embedding for queries (number={len(queries)})--------------')
    q_vecs = model.encode_queries(queries)

    # check if the embeddings are in dictionary format: M3Embedder
    if isinstance(p_vecs, dict):
        p_vecs = p_vecs["dense_vecs"]
    if isinstance(q_vecs, dict):
        q_vecs = q_vecs["dense_vecs"]

    print('create index and search------------------')
    index = create_index(p_vecs, use_gpu=use_gpu)
    _, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
    assert len(all_inxs) == len(train_data)

    for i, data in enumerate(train_data):
        query = data['query']
        inxs = all_inxs[i][sample_range[0]:sample_range[1]]
        filtered_inx = []
        for inx in inxs:
            if inx == -1:  # faiss pads missing results with -1
                break
            # drop candidates that are positives for this query, or the query itself
            if corpus[inx] not in data['pos'] and corpus[inx] != query:
                filtered_inx.append(inx)

        if len(filtered_inx) > negative_number:
            filtered_inx = random.sample(filtered_inx, negative_number)
        data['neg'] = [corpus[inx] for inx in filtered_inx]

    with open(output_file, 'w') as f:
        for data in train_data:
            if len(data['neg']) < negative_number:
                # pad with random corpus samples; oversample by len(pos) so that
                # enough remain after filtering out the positives
                samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
                samples = [sent for sent in samples if sent not in data['pos']]
                data['neg'].extend(samples[:negative_number - len(data['neg'])])
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
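

# Expected data layout, derived from find_knn_neg above (the strings are
# made-up examples):
#
#   input line : {"query": "what is faiss?", "pos": ["FAISS is a library for ..."], "neg": ["optional", "..."]}
#   output line: the same dict, with "neg" replaced by the mined hard negatives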


def load_model(model_args: ModelArgs):
    model = FlagAutoModel.from_finetuned(
        model_name_or_path=model_args.embedder_name_or_path,
        model_class=model_args.embedder_model_class,
        normalize_embeddings=model_args.normalize_embeddings,
        pooling_method=model_args.pooling_method,
        use_fp16=model_args.use_fp16,
        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
        query_instruction_format=model_args.query_instruction_format_for_retrieval,
        devices=model_args.devices,
        examples_for_task=model_args.examples_for_task,
        examples_instruction_format=model_args.examples_instruction_format,
        trust_remote_code=model_args.trust_remote_code,
        cache_dir=model_args.cache_dir,
        batch_size=model_args.batch_size,
        query_max_length=model_args.embedder_query_max_length,
        passage_max_length=model_args.embedder_passage_max_length,
    )
    return model


def main(data_args: DataArgs, model_args: ModelArgs):
    model = load_model(model_args)
    find_knn_neg(
        model=model,
        input_file=data_args.input_file,
        output_file=data_args.output_file,
        candidate_pool=data_args.candidate_pool,
        sample_range=[int(x) for x in data_args.range_for_sampling.split('-')],
        negative_number=data_args.negative_number,
        use_gpu=data_args.use_gpu_for_searching
    )


if __name__ == "__main__":
    parser = HfArgumentParser((
        DataArgs,
        ModelArgs
    ))
    data_args, model_args = parser.parse_args_into_dataclasses()
    data_args: DataArgs
    model_args: ModelArgs
    main(data_args, model_args)