FlagEmbedding/Tutorials/1_Embedding/1.2.6_BGE_VL.ipynb
2025-06-04 17:27:36 +08:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BGE-VL-v1 & v1.5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will go through the BGE-VL series of multimodal retrieval models, which achieved state-of-the-art performance on four popular zero-shot composed image retrieval benchmarks and on the Massive Multimodal Embedding Benchmark (MMEB)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Installation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the required packages in your environment.\n",
"\n",
"- Our models work well with `transformers==4.45.2`, and we recommend using this version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install numpy torch transformers==4.45.2 pillow"
]
},
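{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, check which `transformers` version is active in the current kernel. This is a minimal sketch that only compares the installed version string against the recommended one; it does not enforce anything. After installing, you may need to restart the kernel before the new version is picked up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import transformers\n",
"\n",
"RECOMMENDED = \"4.45.2\"  # version recommended above\n",
"print(\"transformers\", transformers.__version__)\n",
"if transformers.__version__ != RECOMMENDED:\n",
"    print(f\"Note: transformers=={RECOMMENDED} is the recommended version for this tutorial\")"
]
},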
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. BGE-VL-CLIP"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Model | Language | Parameters | Model Size | Description | Base Model |\n",
"|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
"| [BAAI/BGE-VL-base](https://huggingface.co/BAAI/BGE-VL-base) | English | 150M | 299 MB | Lightweight multimodal embedder for images and text | CLIP-base |\n",
"| [BAAI/BGE-VL-large](https://huggingface.co/BAAI/BGE-VL-large) | English | 428M | 855 MB | Large-scale multimodal embedder for images and text | CLIP-large |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"BGE-VL-base and BGE-VL-large are trained from CLIP-base and CLIP-large, respectively. Both contain a vision transformer and a text transformer, as the model structure printed below shows:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
" warnings.warn(\n",
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"CLIPModel(\n",
" (text_model): CLIPTextTransformer(\n",
" (embeddings): CLIPTextEmbeddings(\n",
" (token_embedding): Embedding(49408, 512)\n",
" (position_embedding): Embedding(77, 512)\n",
" )\n",
" (encoder): CLIPEncoder(\n",
" (layers): ModuleList(\n",
" (0-11): 12 x CLIPEncoderLayer(\n",
" (self_attn): CLIPSdpaAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): CLIPMLP(\n",
" (activation_fn): QuickGELUActivation()\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" )\n",
" (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (vision_model): CLIPVisionTransformer(\n",
" (embeddings): CLIPVisionEmbeddings(\n",
" (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)\n",
" (position_embedding): Embedding(197, 768)\n",
" )\n",
" (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (encoder): CLIPEncoder(\n",
" (layers): ModuleList(\n",
" (0-11): 12 x CLIPEncoderLayer(\n",
" (self_attn): CLIPSdpaAttention(\n",
" (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" )\n",
" (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): CLIPMLP(\n",
" (activation_fn): QuickGELUActivation()\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" (post_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (visual_projection): Linear(in_features=768, out_features=512, bias=False)\n",
" (text_projection): Linear(in_features=512, out_features=512, bias=False)\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"from transformers import AutoModel\n",
"\n",
"MODEL_NAME = \"BAAI/BGE-VL-base\" # or \"BAAI/BGE-VL-large\"\n",
"\n",
"model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) # You must set trust_remote_code=True\n",
"model.set_processor(MODEL_NAME)\n",
"model.eval()"
]
},
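{
"cell_type": "markdown",
"metadata": {},
"source": [
"To relate the printout above to the table, you can count the parameters in each tower. This is a minimal sketch that assumes the `model` object from the previous cell; `text_model`, `text_projection`, `vision_model` and `visual_projection` are the submodule names shown in the printout, and `count_params` is a hypothetical helper introduced here. Compare the total with the *Parameters* column of the table above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def count_params(module):\n",
"    # Hypothetical helper: total number of elements across all parameters of a module\n",
"    return sum(p.numel() for p in module.parameters())\n",
"\n",
"text_params = count_params(model.text_model) + count_params(model.text_projection)\n",
"vision_params = count_params(model.vision_model) + count_params(model.visual_projection)\n",
"\n",
"print(f\"text tower:   {text_params / 1e6:.1f}M\")\n",
"print(f\"vision tower: {vision_params / 1e6:.1f}M\")\n",
"print(f\"whole model:  {count_params(model) / 1e6:.1f}M\")"
]
},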
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.2647, 0.1242]])\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" query = model.encode(\n",
" images = \"../../imgs/cir_query.png\", \n",
" text = \"Make the background dark, as if the camera has taken the photo at night\"\n",
" )\n",
"\n",
" candidates = model.encode(\n",
" images = [\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"]\n",
" )\n",
" \n",
" scores = query @ candidates.T\n",
"print(scores)"
]
},
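{
"cell_type": "markdown",
"metadata": {},
"source": [
"`scores` has one row per query and one column per candidate, so retrieving the best match amounts to sorting each row. A minimal sketch, assuming the `scores` tensor from the previous cell; `candidate_paths` is a name introduced here that simply repeats the candidate list used there:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"candidate_paths = [\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"]\n",
"\n",
"# Sort the candidates of the first (and only) query by descending score\n",
"top_scores, top_idx = torch.topk(scores[0], k=len(candidate_paths))\n",
"for rank, (score, idx) in enumerate(zip(top_scores.tolist(), top_idx.tolist()), start=1):\n",
"    print(f\"{rank}. {candidate_paths[idx]} (score={score:.4f})\")"
]
},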
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. BGE-VL-MLLM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Model | Language | Parameters | Model Size | Description | Base Model |\n",
"|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
"| [BAAI/BGE-VL-MLLM-S1](https://huggingface.co/BAAI/BGE-VL-MLLM-S1) | English | 7.57B | 15.14 GB | State-of-the-art composed image retrieval model, trained on the MegaPairs dataset | LLaVA-1.6 |\n",
"| [BAAI/BGE-VL-MLLM-S2](https://huggingface.co/BAAI/BGE-VL-MLLM-S2) | English | 7.57B | 15.14 GB | BGE-VL-MLLM-S1 fine-tuned for one epoch on the MMEB training set | LLaVA-1.6 |"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
" warnings.warn(\n",
"Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00, 1.28it/s]\n"
]
},
{
"data": {
"text/plain": [
"LLaVANextForEmbedding(\n",
" (vision_tower): CLIPVisionModel(\n",
" (vision_model): CLIPVisionTransformer(\n",
" (embeddings): CLIPVisionEmbeddings(\n",
" (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
" (position_embedding): Embedding(577, 1024)\n",
" )\n",
" (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" (encoder): CLIPEncoder(\n",
" (layers): ModuleList(\n",
" (0-23): 24 x CLIPEncoderLayer(\n",
" (self_attn): CLIPSdpaAttention(\n",
" (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
" )\n",
" (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): CLIPMLP(\n",
" (activation_fn): QuickGELUActivation()\n",
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
" )\n",
" (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" (multi_modal_projector): LlavaNextMultiModalProjector(\n",
" (linear_1): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (act): GELUActivation()\n",
" (linear_2): Linear(in_features=4096, out_features=4096, bias=True)\n",
" )\n",
" (language_model): MistralForCausalLM(\n",
" (model): MistralModel(\n",
" (embed_tokens): Embedding(32005, 4096)\n",
" (layers): ModuleList(\n",
" (0-31): 32 x MistralDecoderLayer(\n",
" (self_attn): MistralSdpaAttention(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): MistralRotaryEmbedding()\n",
" )\n",
" (mlp): MistralMLP(\n",
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)\n",
" (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)\n",
" )\n",
" )\n",
" (norm): MistralRMSNorm((4096,), eps=1e-05)\n",
" )\n",
" (lm_head): Linear(in_features=4096, out_features=32005, bias=False)\n",
" )\n",
")"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from transformers import AutoModel\n",
"from PIL import Image\n",
"\n",
"MODEL_NAME = \"BAAI/BGE-VL-MLLM-S1\" # or \"BAAI/BGE-VL-MLLM-S2\"\n",
"\n",
"model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
"model.eval()\n",
"model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.4109, 0.1807]], device='cuda:0')\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" model.set_processor(MODEL_NAME)\n",
"\n",
" query_inputs = model.data_process(\n",
" text=\"Make the background dark, as if the camera has taken the photo at night\", \n",
" images=\"../../imgs/cir_query.png\",\n",
" q_or_c=\"q\",\n",
" task_instruction=\"Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: \"\n",
" )\n",
"\n",
" candidate_inputs = model.data_process(\n",
" images=[\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"],\n",
" q_or_c=\"c\",\n",
" )\n",
"\n",
" query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]\n",
" candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :]\n",
" \n",
" query_embs = torch.nn.functional.normalize(query_embs, dim=-1)\n",
" candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1)\n",
"\n",
" scores = torch.matmul(query_embs, candi_embs.T)\n",
"print(scores)"
]
},
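{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above takes the hidden state of the last token as the embedding, L2-normalizes it, and uses the dot product as the score, so each score is a cosine similarity. The self-contained sketch below illustrates that equivalence on random tensors standing in for the model's hidden states; the sequence lengths and hidden size are arbitrary assumptions (the real hidden size is 4096, per the printout), and it needs neither the model nor a GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"# Random stand-ins for hidden states: (batch, seq_len, hidden_dim)\n",
"q_hidden = torch.randn(1, 7, 16)\n",
"c_hidden = torch.randn(2, 5, 16)\n",
"\n",
"# Last-token pooling + L2 normalization + dot product, as in the cell above\n",
"q_emb = F.normalize(q_hidden[:, -1, :], dim=-1)\n",
"c_emb = F.normalize(c_hidden[:, -1, :], dim=-1)\n",
"demo_scores = q_emb @ c_emb.T  # shape (1, 2)\n",
"\n",
"# Cosine similarity computed directly on the unnormalized last-token states\n",
"cos = F.cosine_similarity(q_hidden[:, -1, :].unsqueeze(1), c_hidden[:, -1, :].unsqueeze(0), dim=-1)\n",
"print(demo_scores.shape, torch.allclose(demo_scores, cos, atol=1e-6))"
]
},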
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. BGE-VL-v1.5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The BGE-VL-v1.5 series is a new version of BGE-VL with better performance on both retrieval and multimodal understanding. It is trained on 30M MegaPairs samples plus an additional 10M natural and synthetic samples.\n",
"\n",
"`BGE-VL-v1.5-zs` is a zero-shot model trained only on the data mentioned above. `BGE-VL-v1.5-mmeb` is additionally fine-tuned on the MMEB training set."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Model | Language | Parameters | Model Size | Description | Base Model |\n",
"|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
"| [BAAI/BGE-VL-v1.5-zs](https://huggingface.co/BAAI/BGE-VL-v1.5-zs) | English | 7.57B | 15.14 GB | Improved multimodal retrieval model that performs well across a wide range of tasks | LLaVA-1.6 |\n",
"| [BAAI/BGE-VL-v1.5-mmeb](https://huggingface.co/BAAI/BGE-VL-v1.5-mmeb) | English | 7.57B | 15.14 GB | Improved multimodal retrieval model, additionally fine-tuned on the MMEB training set | LLaVA-1.6 |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"BGE-VL-v1.5 models are used in exactly the same way as BGE-VL-MLLM:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00, 2.26it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.3880, 0.1815]], device='cuda:0')\n"
]
}
],
"source": [
"import torch\n",
"from transformers import AutoModel\n",
"from PIL import Image\n",
"\n",
"MODEL_NAME = \"BAAI/BGE-VL-v1.5-mmeb\" # or \"BAAI/BGE-VL-v1.5-zs\"\n",
"\n",
"model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
"model.eval()\n",
"model.cuda()\n",
"\n",
"with torch.no_grad():\n",
" model.set_processor(MODEL_NAME)\n",
"\n",
" query_inputs = model.data_process(\n",
" text=\"Make the background dark, as if the camera has taken the photo at night\", \n",
" images=\"../../imgs/cir_query.png\",\n",
" q_or_c=\"q\",\n",
" task_instruction=\"Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: \"\n",
" )\n",
"\n",
" candidate_inputs = model.data_process(\n",
" images=[\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"],\n",
" q_or_c=\"c\",\n",
" )\n",
"\n",
" query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]\n",
" candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :]\n",
" \n",
" query_embs = torch.nn.functional.normalize(query_embs, dim=-1)\n",
" candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1)\n",
"\n",
" scores = torch.matmul(query_embs, candi_embs.T)\n",
"print(scores)\n"
]
}
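,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The BGE-VL-MLLM and BGE-VL-v1.5 cells repeat the same encode-and-score steps, so you may want to wrap them in a function. This is a sketch that calls `data_process` and the forward pass exactly as above; `score_cir` and `CIR_INSTRUCTION` are hypothetical names introduced here, and the function assumes `model.set_processor(...)` has already been called, as in the previous cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"CIR_INSTRUCTION = \"Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: \"\n",
"\n",
"def score_cir(model, query_text, query_image, candidate_images, instruction=CIR_INSTRUCTION):\n",
"    # Hypothetical helper: cosine scores between one composed query and candidate images\n",
"    with torch.no_grad():\n",
"        query_inputs = model.data_process(\n",
"            text=query_text, images=query_image, q_or_c=\"q\", task_instruction=instruction\n",
"        )\n",
"        candidate_inputs = model.data_process(images=candidate_images, q_or_c=\"c\")\n",
"        # Last-token hidden state as the embedding, then L2-normalize\n",
"        q = F.normalize(model(**query_inputs, output_hidden_states=True)[:, -1, :], dim=-1)\n",
"        c = F.normalize(model(**candidate_inputs, output_hidden_states=True)[:, -1, :], dim=-1)\n",
"    return q @ c.T\n",
"\n",
"print(score_cir(\n",
"    model,\n",
"    \"Make the background dark, as if the camera has taken the photo at night\",\n",
"    \"../../imgs/cir_query.png\",\n",
"    [\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"],\n",
"))"
]
}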
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}