mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)

update tutorials

This commit is contained in:
parent 5e64baa61e
commit 235f967a01
@@ -4,7 +4,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# BGE-VL"
+    "# BGE-VL-v1&v1.5"
    ]
   },
   {
@@ -25,7 +25,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Install the required packages in your environment."
+    "Install the required packages in your environment.\n",
+    "\n",
+    "- The model works well with transformers==4.45.2; we recommend using this version."
    ]
   },
   {
@@ -340,6 +342,95 @@
    "    scores = torch.matmul(query_embs, candi_embs.T)\n",
    "print(scores)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. BGE-VL-v1.5"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The BGE-VL-v1.5 series is a new version of BGE-VL that brings better performance on both retrieval and multi-modal understanding. It is trained on 30M MegaPairs data plus an extra 10M natural and synthetic samples.\n",
+    "\n",
+    "`bge-vl-v1.5-zs` is a zero-shot model trained only on the data mentioned above. `bge-vl-v1.5-mmeb` is additionally fine-tuned on the MMEB training set."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "| Model | Language | Parameters | Model Size | Description | Base Model |\n",
+    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
+    "| [BAAI/BGE-VL-v1.5-zs](https://huggingface.co/BAAI/BGE-VL-v1.5-zs) | English | 7.57B | 15.14 GB | Better multi-modal retrieval model that performs well in all kinds of tasks | LLaVA-1.6 |\n",
+    "| [BAAI/BGE-VL-v1.5-mmeb](https://huggingface.co/BAAI/BGE-VL-v1.5-mmeb) | English | 7.57B | 15.14 GB | Better multi-modal retrieval model, additionally fine-tuned on the MMEB training set | LLaVA-1.6 |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can use the BGE-VL-v1.5 models in exactly the same way as BGE-VL-MLLM."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[0.3880, 0.1815]], device='cuda:0')\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoModel\n",
+    "from PIL import Image\n",
+    "\n",
+    "MODEL_NAME = \"BAAI/BGE-VL-v1.5-mmeb\"  # or \"BAAI/BGE-VL-v1.5-zs\"\n",
+    "\n",
+    "model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
+    "model.eval()\n",
+    "model.cuda()\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    model.set_processor(MODEL_NAME)\n",
+    "\n",
+    "    query_inputs = model.data_process(\n",
+    "        text=\"Make the background dark, as if the camera has taken the photo at night\", \n",
+    "        images=\"../../imgs/cir_query.png\",\n",
+    "        q_or_c=\"q\",\n",
+    "        task_instruction=\"Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: \"\n",
+    "    )\n",
+    "\n",
+    "    candidate_inputs = model.data_process(\n",
+    "        images=[\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"],\n",
+    "        q_or_c=\"c\",\n",
+    "    )\n",
+    "\n",
+    "    query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]\n",
+    "    candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :]\n",
+    "\n",
+    "    query_embs = torch.nn.functional.normalize(query_embs, dim=-1)\n",
+    "    candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1)\n",
+    "\n",
+    "    scores = torch.matmul(query_embs, candi_embs.T)\n",
+    "print(scores)\n"
+   ]
+  }
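The printed tensor holds one similarity score per candidate, in the order the candidates were passed in. As a quick aside, a minimal sketch of turning that tensor into an explicit ranking; it assumes the `scores` tensor and candidate paths from the cell above and is not part of the committed notebook:

```python
# Minimal sketch: rank the candidate images by similarity to the query.
# Assumes `scores` (shape [1, num_candidates]) from the cell above.
candidates = ["../../imgs/cir_candi_1.png", "../../imgs/cir_candi_2.png"]

ranked = sorted(zip(candidates, scores.squeeze(0).tolist()),
                key=lambda pair: pair[1], reverse=True)
for path, score in ranked:
    print(f"{score:.4f}  {path}")
# With the output above, cir_candi_1.png (0.3880) ranks first.
```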
  ],
  "metadata": {
@@ -358,7 +449,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.10"
+   "version": "3.10.16"
    }
   },
   "nbformat": 4,
Tutorials/1_Embedding/1.2.7_BGE_Code_v1.ipynb (new file, 241 lines)
@@ -0,0 +1,241 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BGE-Code-v1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U FlagEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Model | Language | Parameters | Model Size | Description | Base Model |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-code-v1](https://huggingface.co/BAAI/bge-code-v1) | Multi-lingual | 1.54B | 6.18 GB | LLM-based code embedding model with strong text retrieval and multilingual capabilities. | Qwen-2.5-Coder-1.5B |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**[BGE-Code-v1](https://huggingface.co/BAAI/bge-code-v1)** is an LLM-based code embedding model that supports code retrieval, text retrieval, and multilingual retrieval. It primarily demonstrates the following capabilities:\n",
    "- Superior Code Retrieval Performance: The model demonstrates exceptional code retrieval capabilities, supporting natural language queries in both English and Chinese, as well as 20 programming languages.\n",
    "- Robust Text Retrieval Capabilities: The model maintains strong text retrieval capabilities comparable to text embedding models of similar scale.\n",
    "- Extensive Multilingual Support: BGE-Code-v1 offers comprehensive multilingual retrieval capabilities, excelling in languages such as English, Chinese, Japanese, French, and more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch, os\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"BAAI/bge-code-v1\")\n",
    "raw_model = AutoModel.from_pretrained(\"BAAI/bge-code-v1\")\n",
    "\n",
    "raw_model.eval()"
   ]
  },
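The cell above only loads the raw backbone; the notebook does not call `raw_model` again. For the curious, a minimal sketch of embedding text with it directly, assuming last-token pooling and right padding (assumptions, not confirmed by this notebook; the `FlagLLMModel` wrapper used below handles all of this for you):

```python
import torch

def last_token_embed(texts):
    # Tokenize, run the raw backbone, pool the hidden state of the last
    # non-padding token, and L2-normalize. The pooling strategy and
    # right-padding behavior are assumptions here.
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = raw_model(**inputs).last_hidden_state   # [batch, seq, dim]
    last = inputs["attention_mask"].sum(dim=1) - 1       # index of last real token
    emb = hidden[torch.arange(hidden.size(0)), last]
    return torch.nn.functional.normalize(emb, dim=-1)

print(last_token_embed(["def gcd(a, b): ..."]).shape)    # expect [1, 1536]
```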
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given the following tiny corpus:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [\"\"\"\n",
    "def func_1(arr, target):\n",
    "    low, high = 0, len(arr) - 1\n",
    "    while low <= high:\n",
    "        mid = (low + high) // 2\n",
    "        if arr[mid] == target: return mid\n",
    "        elif arr[mid] < target: low = mid + 1\n",
    "        else: high = mid - 1\n",
    "    return -1\n",
    "\"\"\",\n",
    "\"\"\"\n",
    "def func_2(n, memo={}):\n",
    "    if n <= 1: return n\n",
    "    if n not in memo:\n",
    "        memo[n] = func_2(n-1, memo) + func_2(n-2, memo)\n",
    "    return memo[n]\n",
    "\"\"\",\n",
    "\"\"\"\n",
    "def func_3(a, b):\n",
    "    while b:\n",
    "        a, b = b, a % b\n",
    "    return a\n",
    "\"\"\",\n",
    "\"\"\"\n",
    "def func_4(n):\n",
    "    if n < 2: return False\n",
    "    for i in range(2, int(n**0.5) + 1):\n",
    "        if n % i == 0: return False\n",
    "    return True\n",
    "\"\"\",\n",
    "\"\"\"\n",
    "int func_5(const vector<int>& arr, int target) {\n",
    "    int low = 0, high = arr.size() - 1;\n",
    "    while (low <= high) {\n",
    "        int mid = low + (high - low) / 2;\n",
    "        if (arr[mid] == target) return mid;\n",
    "        else if (arr[mid] < target) low = mid + 1;\n",
    "        else high = mid - 1;\n",
    "    }\n",
    "    return -1;\n",
    "}\n",
    "\"\"\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to find the answer to the following question:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"The fastest way to find an element in a sorted array\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.08it/s]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import FlagLLMModel\n",
    "\n",
    "model = FlagLLMModel('BAAI/bge-code-v1',\n",
    "                     query_instruction_format=\"<instruct>{}\\n<query>{}\",\n",
    "                     query_instruction_for_retrieval=\"Given a question in text, retrieve code snippets that are appropriate responses to the question.\",\n",
    "                     trust_remote_code=True,\n",
    "                     devices=0,\n",
    "                     use_fp16=True)  # use_fp16=True speeds up computation with a slight performance cost"
   ]
  },
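Under the hood, `encode_queries` fills `query_instruction_format` with the retrieval instruction and the query text. Roughly, the string handed to the model looks like the following (an illustrative sketch, not the library's internal code):

```python
# Illustrative: how the instruction template expands for one query.
instruction = ("Given a question in text, retrieve code snippets that "
               "are appropriate responses to the question.")
query = "The fastest way to find an element in a sorted array"
print("<instruct>{}\n<query>{}".format(instruction, query))
```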
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1536,), (5, 1536))"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_emb = model.encode_queries(query)\n",
    "corpus_emb = model.encode_corpus(corpus)\n",
    "query_emb.shape, corpus_emb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.4553 0.2172 0.2277 0.196  0.4355]\n"
     ]
    }
   ],
   "source": [
    "similarity = query_emb @ corpus_emb.T\n",
    "print(similarity)"
   ]
  },
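A small follow-up sketch (not part of the committed notebook) that turns the similarity vector into ranked results, assuming `similarity` and the five-function `corpus` from the cells above:

```python
import numpy as np

# Rank corpus entries by similarity to the query, highest first.
order = np.argsort(-similarity)
for rank, idx in enumerate(order, start=1):
    print(f"#{rank}: func_{idx + 1}, score {similarity[idx]:.4f}")
# With the scores above: func_1 (0.4553) first, then func_5 (0.4355).
```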
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the elements at indices 0 and 4, the Python and C++ implementations of binary search, have conspicuously higher similarity than the other candidates."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}