This commit is contained in:
ZiyiXia 2024-11-04 11:19:46 +00:00
commit ee3437f4eb
33 changed files with 23 additions and 28 deletions

View File

@ -47,8 +47,6 @@ class BaseLLMEmbedder(AbsEmbedder):
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Defaults to :data:`None`.
instruction_format (str, optional): Instruction format when using :attr:`instruction`. Defaults to :data:`"{}{}"`.
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
Defaults to :data:`True`.
@ -72,8 +70,6 @@ class BaseLLMEmbedder(AbsEmbedder):
batch_size: int = 256,
query_max_length: int = 512,
passage_max_length: int = 512,
instruction: Optional[str] = None,
instruction_format: str = "{}{}",
convert_to_numpy: bool = True,
**kwargs: Any,
):
@ -87,8 +83,6 @@ class BaseLLMEmbedder(AbsEmbedder):
batch_size=batch_size,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
instruction=instruction,
instruction_format=instruction_format,
convert_to_numpy=convert_to_numpy,
**kwargs
)

View File

@ -54,6 +54,8 @@ class ICLLLMEmbedder(AbsEmbedder):
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
Defaults to :data:`True`.
Attributes:
DEFAULT_POOLING_METHOD: The default pooling method when running the model.
@ -77,8 +79,6 @@ class ICLLLMEmbedder(AbsEmbedder):
batch_size: int = 256,
query_max_length: int = 512,
passage_max_length: int = 512,
instruction: Optional[str] = None,
instruction_format: str = "{}{}",
convert_to_numpy: bool = True,
**kwargs: Any,
):
@ -92,10 +92,8 @@ class ICLLLMEmbedder(AbsEmbedder):
batch_size=batch_size,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
instruction=instruction,
instruction_format=instruction_format,
convert_to_numpy=convert_to_numpy,
kwargs=kwargs
**kwargs
)
self.tokenizer = AutoTokenizer.from_pretrained(

View File

@ -28,8 +28,6 @@ class BaseEmbedder(AbsEmbedder):
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Defaults to :data:`None`.
instruction_format (str, optional): Instruction format when using :attr:`instruction`. Defaults to :data:`"{}{}"`.
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
Defaults to :data:`True`.
@ -55,8 +53,6 @@ class BaseEmbedder(AbsEmbedder):
batch_size: int = 256,
query_max_length: int = 512,
passage_max_length: int = 512,
instruction: Optional[str] = None,
instruction_format: str = "{}{}",
convert_to_numpy: bool = True,
**kwargs: Any,
):
@ -70,8 +66,6 @@ class BaseEmbedder(AbsEmbedder):
batch_size=batch_size,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
instruction=instruction,
instruction_format=instruction_format,
convert_to_numpy=convert_to_numpy,
**kwargs
)
@ -201,9 +195,6 @@ class BaseEmbedder(AbsEmbedder):
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
self.model.to(device)
self.model.eval()

View File

@ -38,8 +38,6 @@ class M3Embedder(AbsEmbedder):
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Defaults to :data:`None`.
instruction_format (str, optional): Instruction format when using :attr:`instruction`. Defaults to :data:`"{}{}"`.
return_dense (bool, optional): If true, will return the dense embedding. Defaults to :data:`True`.
return_sparse (bool, optional): If true, will return the sparse embedding. Defaults to :data:`False`.
return_colbert_vecs (bool, optional): If true, will return the colbert vectors. Defaults to :data:`False`.
@ -66,8 +64,6 @@ class M3Embedder(AbsEmbedder):
batch_size: int = 256,
query_max_length: int = 512,
passage_max_length: int = 512,
instruction: Optional[str] = None,
instruction_format: str = "{}{}",
return_dense: bool = True,
return_sparse: bool = False,
return_colbert_vecs: bool = False,
@ -83,8 +79,6 @@ class M3Embedder(AbsEmbedder):
batch_size=batch_size,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
instruction=instruction,
instruction_format=instruction_format,
return_dense=return_dense,
return_sparse=return_sparse,
return_colbert_vecs=return_colbert_vecs,

View File

@ -221,7 +221,7 @@ class BGEM3Model(nn.Module):
if teacher_scores is not None:
# print("Use soft-label distillation...")
teacher_targets = F.softmax(teacher_scores, dim=-1) # B N
group_size = p_sparse_vecs.size(0) // q_sparse_vecs.size(0)
group_size = p_dense_vecs.size(0) // q_dense_vecs.size(0)
# dense loss
dense_scores = self.dense_score(q_dense_vecs, p_dense_vecs) # B, B * N

View File

@ -0,0 +1,18 @@
"""Packaging script for the ``visual_bge`` research package."""
from pathlib import Path

from setuptools import setup, find_packages

# ``long_description`` must be the README *contents*; passing the path string
# would publish the literal text "./README.md" as the project description.
_readme_path = Path(__file__).resolve().parent / "README.md"
if _readme_path.is_file():
    _long_description = _readme_path.read_text(encoding="utf-8")
else:
    # Fall back to the short description so a build without the README
    # (e.g. a trimmed source tree) does not crash.
    _long_description = "visual_bge"

setup(
    name="visual_bge",
    version="0.1.0",
    description='visual_bge',
    long_description=_long_description,
    long_description_content_type="text/markdown",
    url='https://github.com/FlagOpen/FlagEmbedding/tree/master/research/visual_bge',
    packages=find_packages(),
    install_requires=[
        'torchvision',
        'timm',
        'einops',
        'ftfy'
    ],
    python_requires='>=3.6',
)

View File

@ -4,7 +4,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from FlagEmbedding.visual.eva_clip.utils import freeze_batch_norm_2d
from visual_bge.eva_clip.utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):