{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# BGE-VL-v1&v1.5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we will go through the multimodel retrieval models BGE-VL series, which achieved state-of-the-art performance on four popular zero-shot composed image retrieval benchmarks and the massive multimodal embedding benchmark (MMEB)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 0. Installation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the required packages in your environment.\n", "\n", "- Our model works well on transformers==4.45.2, and we recommend using this version." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install numpy torch transformers pillow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. BGE-VL-CLIP" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| Model | Language | Parameters | Model Size | Description | Base Model |\n", "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n", "| [BAAI/bge-vl-base](https://huggingface.co/BAAI/BGE-VL-base) | English | 150M | 299 MB | Light weight multimodel embedder among image and text | CLIP-base |\n", "| [BAAI/bge-vl-large](https://huggingface.co/BAAI/BGE-VL-large) | English | 428M | 855 MB | Large scale multimodel embedder among image and text | CLIP-large |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "BGE-VL-base and BGE-VL-large are trained based on CLIP base and CLIP large, which both contain a vision transformer and a text transformer:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n", " warnings.warn(\n", "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "CLIPModel(\n", " (text_model): CLIPTextTransformer(\n", " (embeddings): CLIPTextEmbeddings(\n", " (token_embedding): Embedding(49408, 512)\n", " (position_embedding): Embedding(77, 512)\n", " )\n", " (encoder): CLIPEncoder(\n", " (layers): ModuleList(\n", " (0-11): 12 x CLIPEncoderLayer(\n", " (self_attn): CLIPSdpaAttention(\n", " (k_proj): Linear(in_features=512, out_features=512, bias=True)\n", " (v_proj): Linear(in_features=512, out_features=512, bias=True)\n", " (q_proj): Linear(in_features=512, out_features=512, bias=True)\n", " (out_proj): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (mlp): CLIPMLP(\n", " (activation_fn): QuickGELUActivation()\n", " (fc1): Linear(in_features=512, out_features=2048, bias=True)\n", " (fc2): Linear(in_features=2048, out_features=512, bias=True)\n", " )\n", " (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (vision_model): CLIPVisionTransformer(\n", " (embeddings): CLIPVisionEmbeddings(\n", " (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)\n", " (position_embedding): Embedding(197, 768)\n", " )\n", " (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (encoder): CLIPEncoder(\n", " (layers): ModuleList(\n", " (0-11): 12 x CLIPEncoderLayer(\n", " (self_attn): CLIPSdpaAttention(\n", " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (mlp): CLIPMLP(\n", " (activation_fn): QuickGELUActivation()\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " )\n", " (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " (post_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (visual_projection): Linear(in_features=768, out_features=512, bias=False)\n", " (text_projection): Linear(in_features=512, out_features=512, bias=False)\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import torch\n", "from transformers import AutoModel\n", "\n", "MODEL_NAME = \"BAAI/BGE-VL-base\" # or \"BAAI/BGE-VL-base\"\n", "\n", "model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) # You must set trust_remote_code=True\n", "model.set_processor(MODEL_NAME)\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.2647, 0.1242]])\n" ] } ], "source": [ "with torch.no_grad():\n", " query = model.encode(\n", " images = \"../../imgs/cir_query.png\", \n", " text = \"Make the background dark, as if the camera has taken the photo at night\"\n", " )\n", "\n", " candidates = model.encode(\n", " images = [\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"]\n", " )\n", " \n", " scores = query @ candidates.T\n", "print(scores)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. BGE-VL-MLLM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| Model | Language | Parameters | Model Size | Description | Base Model |\n", "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n", "| [BAAI/bge-vl-MLLM-S1](https://huggingface.co/BAAI/BGE-VL-MLLM-S1) | English | 7.57B | 15.14 GB | SOTA in composed image retrieval, trained on MegaPairs dataset | LLaVA-1.6 |\n", "| [BAAI/bge-vl-MLLM-S2](https://huggingface.co/BAAI/BGE-VL-MLLM-S2) | English | 7.57B | 15.14 GB | Finetune BGE-VL-MLLM-S1 with one epoch on MMEB training set | LLaVA-1.6 |" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n", " warnings.warn(\n", "Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00, 1.28it/s]\n" ] }, { "data": { "text/plain": [ "LLaVANextForEmbedding(\n", " (vision_tower): CLIPVisionModel(\n", " (vision_model): CLIPVisionTransformer(\n", " (embeddings): CLIPVisionEmbeddings(\n", " (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n", " (position_embedding): Embedding(577, 1024)\n", " )\n", " (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", " (encoder): CLIPEncoder(\n", " (layers): ModuleList(\n", " (0-23): 24 x CLIPEncoderLayer(\n", " (self_attn): CLIPSdpaAttention(\n", " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " )\n", " (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", " (mlp): CLIPMLP(\n", " (activation_fn): QuickGELUActivation()\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " )\n", " (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " (multi_modal_projector): LlavaNextMultiModalProjector(\n", " (linear_1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELUActivation()\n", " (linear_2): Linear(in_features=4096, out_features=4096, bias=True)\n", " )\n", " (language_model): MistralForCausalLM(\n", " (model): MistralModel(\n", " (embed_tokens): Embedding(32005, 4096)\n", " (layers): ModuleList(\n", " (0-31): 32 x MistralDecoderLayer(\n", " (self_attn): MistralSdpaAttention(\n", " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", " (rotary_emb): MistralRotaryEmbedding()\n", " )\n", " (mlp): MistralMLP(\n", " (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", " (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", " (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n", " (act_fn): SiLU()\n", " )\n", " (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)\n", " (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)\n", " )\n", " )\n", " (norm): MistralRMSNorm((4096,), eps=1e-05)\n", " )\n", " (lm_head): Linear(in_features=4096, out_features=32005, bias=False)\n", " )\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from transformers import AutoModel\n", "from PIL import Image\n", "\n", "MODEL_NAME= \"BAAI/BGE-VL-MLLM-S1\"\n", "\n", "model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)\n", "model.eval()\n", "model.cuda()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.4109, 0.1807]], device='cuda:0')\n" ] } ], "source": [ "with torch.no_grad():\n", " model.set_processor(MODEL_NAME)\n", "\n", " query_inputs = model.data_process(\n", " text=\"Make the background dark, as if the camera has taken the photo at night\", \n", " images=\"../../imgs/cir_query.png\",\n", " q_or_c=\"q\",\n", " task_instruction=\"Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: \"\n", " )\n", "\n", " candidate_inputs = model.data_process(\n", " images=[\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"],\n", " q_or_c=\"c\",\n", " )\n", "\n", " query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]\n", " candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :]\n", " \n", " query_embs = torch.nn.functional.normalize(query_embs, dim=-1)\n", " candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1)\n", "\n", " scores = torch.matmul(query_embs, candi_embs.T)\n", "print(scores)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. BGE-VL-v1.5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "BGE-VL-v1.5 series is a new version of BGE-VL, bringing better performance on both retrieval and multi-modal understanding. It is trained on 30M MegaPairs data and extra 10M natural and synthetic data.\n", "\n", "`bge-vl-v1.5-zs` is a zero-shot model, only trained on the data mentioned above. `bge-vl-v1.5-mmeb` is the fine-tuned version on MMEB training set." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| Model | Language | Parameters | Model Size | Description | Base Model |\n", "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n", "| [BAAI/BGE-VL-v1.5-zs](https://huggingface.co/BAAI/BGE-VL-v1.5-zs) | English | 7.57B | 15.14 GB | Better multi-modal retrieval model with performs well in all kinds of tasks | LLaVA-1.6 |\n", "| [BAAI/BGE-VL-v1.5-mmeb](https://huggingface.co/BAAI/BGE-VL-v1.5-mmeb) | English | 7.57B | 15.14 GB | Better multi-modal retrieval model, additionally fine-tuned on MMEB training set | LLaVA-1.6 |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can use BGE-VL-v1.5 models in the exact same way as BGE-VL-MLLM." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00, 2.26it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.3880, 0.1815]], device='cuda:0')\n" ] } ], "source": [ "import torch\n", "from transformers import AutoModel\n", "from PIL import Image\n", "\n", "MODEL_NAME= \"BAAI/BGE-VL-v1.5-mmeb\" # \"BAAI/BGE-VL-v1.5-zs\"\n", "\n", "model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)\n", "model.eval()\n", "model.cuda()\n", "\n", "with torch.no_grad():\n", " model.set_processor(MODEL_NAME)\n", "\n", " query_inputs = model.data_process(\n", " text=\"Make the background dark, as if the camera has taken the photo at night\", \n", " images=\"../../imgs/cir_query.png\",\n", " q_or_c=\"q\",\n", " task_instruction=\"Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: \"\n", " )\n", "\n", " candidate_inputs = model.data_process(\n", " images=[\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"],\n", " q_or_c=\"c\",\n", " )\n", "\n", " query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]\n", " candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :]\n", " \n", " query_embs = torch.nn.functional.normalize(query_embs, dim=-1)\n", " candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1)\n", "\n", " scores = torch.matmul(query_embs, candi_embs.T)\n", "print(scores)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 2 }