mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
update environment
This commit is contained in:
parent
b51706d63a
commit
bf410d51e9
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import sys
|
||||
import warnings
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
@ -10,8 +11,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from FlagEmbedding.abc.inference import AbsReranker
|
||||
from FlagEmbedding.inference.reranker.encoder_only.base import sigmoid
|
||||
|
||||
from .models.gemma_model import CostWiseGemmaForCausalLM
|
||||
|
||||
|
||||
def last_logit_pool_lightweight(logits: Tensor,
|
||||
attention_mask: Tensor) -> Tensor:
|
||||
@ -144,6 +143,15 @@ class LightweightLLMReranker(AbsReranker):
|
||||
normalize: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
try:
|
||||
from .models.gemma_model import CostWiseGemmaForCausalLM
|
||||
except:
|
||||
print('*') * 20
|
||||
print('*') * 20
|
||||
print('error for load lightweight reranker, please install transformers==4.46.0')
|
||||
print('*') * 20
|
||||
print('*') * 20
|
||||
sys.exit()
|
||||
|
||||
super().__init__(
|
||||
model_name_or_path=model_name_or_path,
|
||||
|
@ -7,12 +7,12 @@ func_timeout==4.3.5
|
||||
pandas==2.2.1
|
||||
sqlglot==22.1.1
|
||||
rank_bm25==0.2.2
|
||||
peft==0.10.0
|
||||
transformers==4.41.1
|
||||
jinja2
|
||||
datasets
|
||||
sentencepiece
|
||||
flash-attn
|
||||
modelscope
|
||||
peft
|
||||
deepspeed
|
||||
bitsandbytes
|
Loading…
x
Reference in New Issue
Block a user