From bf410d51e9e685e7f1b2e3816f09f3dc64ad8292 Mon Sep 17 00:00:00 2001 From: cfli <545999961@qq.com> Date: Wed, 28 May 2025 14:55:54 +0800 Subject: [PATCH] update environment --- .../inference/reranker/decoder_only/lightweight.py | 12 ++++++++++-- .../finetune/compensation/__init__.py | 0 research/Matroyshka_reranker/requirements.txt | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) create mode 100644 research/Matroyshka_reranker/finetune/compensation/__init__.py diff --git a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py index 61118b7..000478a 100644 --- a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py +++ b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py @@ -1,4 +1,5 @@ import torch +import sys import warnings import numpy as np from tqdm import trange @@ -10,8 +11,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from FlagEmbedding.abc.inference import AbsReranker from FlagEmbedding.inference.reranker.encoder_only.base import sigmoid -from .models.gemma_model import CostWiseGemmaForCausalLM - def last_logit_pool_lightweight(logits: Tensor, attention_mask: Tensor) -> Tensor: @@ -144,6 +143,15 @@ class LightweightLLMReranker(AbsReranker): normalize: bool = False, **kwargs: Any, ) -> None: + try: + from .models.gemma_model import CostWiseGemmaForCausalLM + except: + print('*') * 20 + print('*') * 20 + print('error for load lightweight reranker, please install transformers==4.46.0') + print('*') * 20 + print('*') * 20 + sys.exit() super().__init__( model_name_or_path=model_name_or_path, diff --git a/research/Matroyshka_reranker/finetune/compensation/__init__.py b/research/Matroyshka_reranker/finetune/compensation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/research/Matroyshka_reranker/requirements.txt b/research/Matroyshka_reranker/requirements.txt index b5de2fc..8e7ba34 100644 --- a/research/Matroyshka_reranker/requirements.txt +++ b/research/Matroyshka_reranker/requirements.txt @@ -7,12 +7,12 @@ func_timeout==4.3.5 pandas==2.2.1 sqlglot==22.1.1 rank_bm25==0.2.2 +peft==0.10.0 transformers==4.41.1 jinja2 datasets sentencepiece flash-attn modelscope -peft deepspeed bitsandbytes \ No newline at end of file