Mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2026-01-08 13:11:35 +00:00)
update reranker FT
commit 456899a100
parent 98d2621f3c
@@ -8,6 +8,7 @@ from peft import LoraConfig, TaskType, get_peft_model, PeftModel
 
 from FlagEmbedding.finetune.reranker.decoder_only.layerwise.arguments import RerankerModelArguments
+from .modeling_minicpm_reranker import LayerWiseMiniCPMForCausalLM, LayerWiseHead
 from .configuration_minicpm_reranker import LayerWiseMiniCPMConfig
 
 logger = logging.getLogger(__name__)
 
@@ -41,7 +42,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             trust_remote_code=model_args.trust_remote_code,
-            token=model_args,
+            token=model_args.token,
             cache_dir=model_args.cache_dir
         )
     else:
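
The one-line change above repeats across this file: `token=model_args` passed the whole `RerankerModelArguments` object where `from_pretrained` expects a Hugging Face access token (a string, or None), so authenticated downloads of gated or private checkpoints could not work. A minimal sketch of the before and after, with a hypothetical stand-in dataclass in place of the real `RerankerModelArguments`:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArguments:  # hypothetical stand-in for RerankerModelArguments
    model_name_or_path: str = "BAAI/bge-reranker-v2-minicpm-layerwise"
    token: Optional[str] = None        # Hugging Face access token (a string)
    cache_dir: Optional[str] = None
    trust_remote_code: bool = False

model_args = ModelArguments(token="hf_...")  # placeholder value

# Before the fix, the dataclass itself was forwarded, which cannot
# authenticate (and recent huggingface_hub versions reject it outright):
#   AutoConfig.from_pretrained(model_args.model_name_or_path, token=model_args)
# After the fix, only the token string is forwarded:
#   AutoConfig.from_pretrained(model_args.model_name_or_path, token=model_args.token)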
@@ -61,7 +62,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
             trust_remote_code=model_args.trust_remote_code,
             # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
             use_flash_attention_2=True if model_args.use_flash_attn else False,
-            token=model_args,
+            token=model_args.token,
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
@@ -115,7 +116,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
             model_args.model_name_or_path,
             # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
             use_flash_attention_2=True if model_args.use_flash_attn else False,
-            token=model_args,
+            token=model_args.token,
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
@@ -155,14 +156,14 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
         config = AutoConfig.from_pretrained(
             model_args.config_name,
             trust_remote_code=model_args.trust_remote_code,
-            token=model_args,
+            token=model_args.token,
             cache_dir=model_args.cache_dir
         )
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             trust_remote_code=model_args.trust_remote_code,
-            token=model_args,
+            token=model_args.token,
             cache_dir=model_args.cache_dir
         )
     else:
@@ -172,19 +173,19 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
     config.use_cache = False
 
     if model_args.model_type == 'from_raw_model':
-        config = AutoConfig.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise',
-                                            cache_dir=model_args.cache_dir,
-                                            token=model_args,
-                                            trust_remote_code=model_args.trust_remote_code)
+        config = LayerWiseMiniCPMConfig.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise',
+                                                        cache_dir=model_args.cache_dir,
+                                                        token=model_args.token,
+                                                        trust_remote_code=model_args.trust_remote_code)
         config.start_layer = model_args.start_layer
         config.head_multi = model_args.head_multi
         config.head_type = model_args.head_type
 
-        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
-                                                     config=config,
-                                                     cache_dir=model_args.cache_dir,
-                                                     token=model_args,
-                                                     trust_remote_code=model_args.trust_remote_code)
+        model = LayerWiseMiniCPMForCausalLM.from_pretrained(model_args.model_name_or_path,
+                                                            config=config,
+                                                            cache_dir=model_args.cache_dir,
+                                                            token=model_args.token,
+                                                            trust_remote_code=model_args.trust_remote_code)
 
         if model_args.raw_peft is not None:
             for peft_path in model_args.raw_peft:
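
The substantive change in this hunk, enabled by the import added at the top of the diff, is the swap from the generic `AutoConfig` / `AutoModelForCausalLM` loaders to the package's own `LayerWiseMiniCPMConfig` and `LayerWiseMiniCPMForCausalLM`, so `save_merged_model` builds the merged model from the local layerwise implementation rather than whatever `trust_remote_code` resolves at load time. A hedged sketch of the resulting load path; the absolute import paths are inferred from the `arguments` import in the first hunk, and the three attribute values are illustrative, the real ones come from `model_args`:

from FlagEmbedding.finetune.reranker.decoder_only.layerwise.configuration_minicpm_reranker import (
    LayerWiseMiniCPMConfig,
)
from FlagEmbedding.finetune.reranker.decoder_only.layerwise.modeling_minicpm_reranker import (
    LayerWiseMiniCPMForCausalLM,
)

# Load the layerwise config for the checkpoint named in the diff.
config = LayerWiseMiniCPMConfig.from_pretrained(
    'BAAI/bge-reranker-v2-minicpm-layerwise',
    trust_remote_code=True,
)

# The knobs save_merged_model copies from model_args (values illustrative):
config.start_layer = 8        # assumed: first layer whose hidden state is scored
config.head_multi = True      # assumed: a separate scoring head per layer
config.head_type = 'simple'   # assumed: which LayerWiseHead variant to build

model = LayerWiseMiniCPMForCausalLM.from_pretrained(
    'BAAI/bge-reranker-v2-minicpm-layerwise',
    config=config,
)

Pinning the concrete classes keeps the save path deterministic: the merged checkpoint is produced by the same modeling code the fine-tune ran on, not by remote code fetched from the Hub.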
@@ -8,7 +8,7 @@ from FlagEmbedding.abc.inference import AbsReranker
 
 
 def sigmoid(x):
-    return 1 / (1 + np.exp(-x))
+    return float(1 / (1 + np.exp(-x)))
 
 
 class BaseReranker(AbsReranker):
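
The second file in the commit (an inference-side module, per the `AbsReranker` hunk context) changes only `sigmoid`'s return type. `np.exp` propagates NumPy scalar types, so a `float32` logit came back as `np.float32`, which `json.dumps` refuses to serialize; wrapping the result in `float()` normalizes scores to native Python floats. A small check, with an arbitrary example logit:

import json

import numpy as np

def sigmoid(x):
    return float(1 / (1 + np.exp(-x)))

score = sigmoid(np.float32(1.5))       # arbitrary example logit
print(type(score))                     # <class 'float'>
print(json.dumps({"score": score}))    # serializes cleanly

# Without the float() wrapper the result stays np.float32, and
# json.dumps({"score": np.float32(0.8)}) raises TypeError.

Note the wrapper also commits `sigmoid` to scalar inputs; `float()` raises on a NumPy array of more than one element.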