Mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)
fix transformers 4.48.0
commit 99dfb3dfab (parent 3c406233cd)
@@ -53,7 +53,7 @@ from transformers.utils import (
 )
 from .gemma_config import CostWiseGemmaConfig
 from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2RotaryEmbedding, rotate_half, apply_rotary_pos_emb
-from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2FlashAttention2, Gemma2SdpaAttention, GEMMA2_ATTENTION_CLASSES, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
+from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
 from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
 
 if is_flash_attn_2_available():
@@ -105,12 +105,6 @@ class CostWiseGemma2PreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
 
-GEMMA2_ATTENTION_CLASSES = {
-    "eager": Gemma2Attention,
-    "flash_attention_2": Gemma2FlashAttention2,
-    "sdpa": Gemma2SdpaAttention,
-}
-
 
 _CONFIG_FOR_DOC = "CostWiseGemmaConfig"
 
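The first two hunks track the attention refactor in transformers 4.48.0: the per-backend Gemma2FlashAttention2 and Gemma2SdpaAttention classes (and the GEMMA2_ATTENTION_CLASSES mapping) are no longer exported, and a single Gemma2Attention handles every backend. A minimal compatibility sketch, not part of this commit, that keeps the old dispatch dict available on either side of the change:

# Compatibility sketch (assumption: not part of this commit) for importing the
# Gemma2 attention classes across the transformers 4.48.0 refactor.
try:
    # transformers < 4.48: separate classes per attention backend
    from transformers.models.gemma2.modeling_gemma2 import (
        Gemma2Attention,
        Gemma2FlashAttention2,
        Gemma2SdpaAttention,
    )
    GEMMA2_ATTENTION_CLASSES = {
        "eager": Gemma2Attention,
        "flash_attention_2": Gemma2FlashAttention2,
        "sdpa": Gemma2SdpaAttention,
    }
except ImportError:
    # transformers >= 4.48: one Gemma2Attention class serves every backend
    from transformers.models.gemma2.modeling_gemma2 import Gemma2Attention
    GEMMA2_ATTENTION_CLASSES = {
        "eager": Gemma2Attention,
        "flash_attention_2": Gemma2Attention,
        "sdpa": Gemma2Attention,
    }

The commit itself takes the simpler route of dropping the unused names rather than keeping a dual-path import.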
@@ -41,7 +41,7 @@ from transformers.modeling_attn_mask_utils import (
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
     SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -63,6 +63,9 @@ except:
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
+from packaging import version
+parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
+is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 if is_torch_fx_available():
     if not is_torch_greater_or_equal_than_1_13:
         import torch.fx
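The last two hunks address transformers 4.48.0 no longer exporting is_torch_greater_or_equal_than_1_13 from transformers.pytorch_utils; the patch recomputes the flag directly from torch.__version__ using packaging. A hedged fallback sketch, not part of this commit, that imports the helper where it still exists and rebuilds it otherwise:

# Compatibility sketch (assumption: not part of this commit) for the torch
# version flag that transformers.pytorch_utils stopped exporting in 4.48.0.
try:
    # transformers < 4.48 still provides the helper
    from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
except ImportError:
    # transformers >= 4.48: recompute it locally, as the patch does
    import torch
    from packaging import version

    parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
    is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")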