{
"cells": [
{
"cell_type": "markdown",
"id": "06cff9e4",
"metadata": {},
"source": [
"# BGE Series"
]
},
{
"cell_type": "markdown",
"id": "880e229d",
"metadata": {},
"source": [
"In this Part, we will walk through the BGE series and introduce how to use the BGE embedding models."
]
},
{
"cell_type": "markdown",
"id": "2516fd49",
"metadata": {},
"source": [
"## 1. BAAI General Embedding"
]
},
{
"cell_type": "markdown",
"id": "2113ee71",
"metadata": {},
"source": [
"BGE stands for BAAI General Embedding, it's a series of embeddings models developed and published by Beijing Academy of Artificial Intelligence (BAAI)."
]
},
{
"cell_type": "markdown",
"id": "16515b99",
"metadata": {},
"source": [
"A full support of APIs and related usages of BGE is maintained in [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) on GitHub.\n",
"\n",
"Run the following cell to install FlagEmbedding in your environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88095fd0",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"%pip install -U FlagEmbedding"
]
},
{
"cell_type": "markdown",
"id": "bc6e30a0",
"metadata": {},
"source": [
"The collection of BGE models can be found in [Huggingface collection](https://huggingface.co/collections/BAAI/bge-66797a74476eb1f085c7446d)."
]
},
{
"cell_type": "markdown",
"id": "67a16ccf",
"metadata": {},
"source": [
"## 2. BGE Series Models"
]
},
{
"cell_type": "markdown",
"id": "2e10034a",
"metadata": {},
"source": [
"### 2.1 BGE"
]
},
{
"cell_type": "markdown",
"id": "0cdc6702",
"metadata": {},
"source": [
"The very first version of BGE has 6 models, with 'large', 'base', and 'small' for English and Chinese. "
]
},
{
"cell_type": "markdown",
"id": "04b75f72",
"metadata": {},
"source": [
"| Model | Language | Parameters | Model Size | Description | Base Model |\n",
"|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
"| [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en) | English | 335M | 1.34 GB | Embedding Model which map text into vector | BERT |\n",
"| [BAAI/bge-base-en](https://huggingface.co/BAAI/bge-base-en) | English | 109M | 438 MB | a base-scale model but with similar ability to `bge-large-en` | BERT |\n",
"| [BAAI/bge-small-en](https://huggingface.co/BAAI/bge-small-en) | English | 33.4M | 133 MB | a small-scale model but with competitive performance | BERT |\n",
"| [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) | Chinese | 326M | 1.3 GB | Embedding Model which map text into vector | BERT |\n",
"| [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh) | Chinese | 102M | 409 MB | a base-scale model but with similar ability to `bge-large-zh` | BERT |\n",
"| [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh) | Chinese | 24M | 95.8 MB | a small-scale model but with competitive performance | BERT |"
]
},
{
"cell_type": "markdown",
"id": "c9c45d17",
"metadata": {},
"source": [
"For inference, import FlagModel from FlagEmbedding and initialize the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89e07751",
"metadata": {},
"outputs": [],
"source": [
"from FlagEmbedding import FlagModel\n",
"\n",
"# Load BGE model\n",
"model = FlagModel('BAAI/bge-base-en',\n",
" query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
" use_fp16=True)\n",
"\n",
"queries = [\"query 1\", \"query 2\"]\n",
"corpus = [\"passage 1\", \"passage 2\"]\n",
"\n",
"# encode the queries and corpus\n",
"q_embeddings = model.encode(queries)\n",
"p_embeddings = model.encode(corpus)\n",
"\n",
"# compute the similarity scores\n",
"scores = q_embeddings @ p_embeddings.T\n",
"print(scores)"
]
},
{
"cell_type": "markdown",
"id": "6c8e69ed",
"metadata": {},
"source": [
"To use `FlagModel`:\n",
"```\n",
"FlagModel.encode(sentences, batch_size=256, max_length=512, convert_to_numpy=True)\n",
"```\n",
"The *encode()* function directly encode the input sentences to embedding vectors.\n",
"```\n",
"FlagModel.encode_queries(sentences, batch_size=256, max_length=512, convert_to_numpy=True)\n",
"```\n",
"The *encode_queries()* function concatenate the `query_instruction_for_retrieval` with each of the input query, and then call `encode()`."
]
},
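{
"cell_type": "markdown",
"id": "1a2b3c4d",
"metadata": {},
"source": [
"As a minimal sketch reusing the `model`, `queries`, and `corpus` defined above, *encode_queries()* prepends the query instruction automatically, while *encode()* embeds the passages as-is:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d3c2b1a",
"metadata": {},
"outputs": [],
"source": [
"# encode_queries() prepends query_instruction_for_retrieval to each query\n",
"q_embeddings = model.encode_queries(queries)\n",
"# encode() embeds the passages without any instruction\n",
"p_embeddings = model.encode(corpus)\n",
"\n",
"# inner product of the embeddings as the similarity score\n",
"scores = q_embeddings @ p_embeddings.T\n",
"print(scores)"
]
},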
{
"cell_type": "markdown",
"id": "2c86a5a3",
"metadata": {},
"source": [
"### 2.2 BGE v1.5"
]
},
{
"cell_type": "markdown",
"id": "454ff7aa",
"metadata": {},
"source": [
"BGE 1.5 alleviate the issue of the similarity distribution, and enhance retrieval ability without instruction."
]
},
{
"cell_type": "markdown",
"id": "30b1f897",
"metadata": {},
"source": [
"| Model | Language | Parameters | Model Size | Description | Base Model |\n",
"|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
"| [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5) | English | 335M | 1.34 GB | version 1.5 with more reasonable similarity distribution | BERT |\n",
"| [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | English | 109M | 438 MB | version 1.5 with more reasonable similarity distribution | BERT |\n",
"| [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) | English | 33.4M | 133 MB | version 1.5 with more reasonable similarity distribution | BERT |\n",
"| [BAAI/bge-large-zh-v1.5](https://huggingface.co/BAAI/bge-large-zh-v1.5) | Chinese | 326M | 1.3 GB | version 1.5 with more reasonable similarity distribution | BERT |\n",
"| [BAAI/bge-base-zh-v1.5](https://huggingface.co/BAAI/bge-base-zh-v1.5) | Chinese | 102M | 409 MB | version 1.5 with more reasonable similarity distribution | BERT |\n",
"| [BAAI/bge-small-zh-v1.5](https://huggingface.co/BAAI/bge-small-zh-v1.5) | Chinese | 24M | 95.8 MB | version 1.5 with more reasonable similarity distribution | BERT |"
]
},
{
"cell_type": "markdown",
"id": "ed00c504",
"metadata": {},
"source": [
"BGE 1.5 models shares the same API of `FlagModel` with BGE models."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9b17afcc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.736794 0.5989914]\n",
" [0.5684842 0.7461165]]\n"
]
}
],
"source": [
"model = FlagModel('BAAI/bge-base-en-v1.5',\n",
" query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
" use_fp16=True)\n",
"\n",
"queries = [\"query 1\", \"query 2\"]\n",
"corpus = [\"passage 1\", \"passage 2\"]\n",
"\n",
"# encode the queries and corpus\n",
"q_embeddings = model.encode(queries)\n",
"p_embeddings = model.encode(corpus)\n",
"\n",
"# compute the similarity scores\n",
"scores = q_embeddings @ p_embeddings.T\n",
"print(scores)"
]
},
{
"cell_type": "markdown",
"id": "38c3ce1c",
"metadata": {},
"source": [
"### 2.3 LLM-Embedder"
]
},
{
"cell_type": "markdown",
"id": "1bc3fee0",
"metadata": {},
"source": [
"LLM-Embedder is a unified embedding model supporting diverse retrieval augmentation needs for LLMs. It is fine-tuned over 6 tasks:\n",
"- Question Answering (qa)\n",
"- Conversational Search (convsearch)\n",
"- Long Conversation (chat)\n",
"- Long-Rnage Language Modeling (lrlm)\n",
"- In-Context Learning (icl)\n",
"- Tool Learning (tool)"
]
},
{
"cell_type": "markdown",
"id": "13b926e9",
"metadata": {},
"source": [
"| Model | Language | Parameters | Model Size | Description | Base Model |\n",
"|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
"| [BAAI/llm-embedder](https://huggingface.co/BAAI/llm-embedder) | English | 109M | 438 MB | a unified embedding model to support diverse retrieval augmentation needs for LLMs | BERT |"
]
},
{
"cell_type": "markdown",
"id": "a7b3f109",
"metadata": {},
"source": [
"To use `LLMEmbedder`:\n",
"```python\n",
"LLMEmbedder.encode_queries(\n",
" queries, \n",
" batch_size=256, \n",
" max_length=256, \n",
" task='qa'\n",
")\n",
"```\n",
"The *encode_queries()* will call the *_encode()* functions (similar to the *encode()* in `FlagModel`) and add the corresponding query instruction of the given *task* in front of each of the input *queries*.\n",
"```python\n",
"LLMEmbedder.encode_keys(\n",
" keys, \n",
" batch_size=256, \n",
" max_length=512, \n",
" task='qa'\n",
")\n",
"```\n",
"Similarly, *encode_keys()* also calls *_encode()* and automatically add instructions according to given task."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5f077420",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.89705944 0.85341793]\n",
" [0.8462474 0.90914035]]\n"
]
}
],
"source": [
"from FlagEmbedding import LLMEmbedder\n",
"\n",
"# load the LLMEmbedder model\n",
"model = LLMEmbedder('BAAI/llm-embedder', use_fp16=False)\n",
"\n",
"# Define queries and keys\n",
"queries = [\"test query 1\", \"test query 2\"]\n",
"keys = [\"test key 1\", \"test key 2\"]\n",
"\n",
"# Encode for a specific task (qa, icl, chat, lrlm, tool, convsearch)\n",
"task = \"qa\"\n",
"query_embeddings = model.encode_queries(queries, task=task)\n",
"key_embeddings = model.encode_keys(keys, task=task)\n",
"\n",
"# compute the similarity scores\n",
"similarity = query_embeddings @ key_embeddings.T\n",
"print(similarity)"
]
},
{
"cell_type": "markdown",
"id": "dcf2a82b",
"metadata": {},
"source": [
"### 2.4 BGE M3"
]
},
{
"cell_type": "markdown",
"id": "cc5b5a5e",
"metadata": {},
"source": [
"BGE-M3 is the new version of BGE models that is distinguished for its versatility in:\n",
"- Multi-Functionality: Simultaneously perform the three common retrieval functionalities of embedding model: dense retrieval, multi-vector retrieval, and sparse retrieval.\n",
"- Multi-Linguality: Supports more than 100 working languages.\n",
"- Multi-Granularity: Can proces inputs with different granularityies, spanning from short sentences to long documents of up to 8192 tokens.\n",
"\n",
"For more details, feel free to check out the [paper](https://arxiv.org/pdf/2402.03216)."
]
},
{
"cell_type": "markdown",
"id": "41348e03",
"metadata": {},
"source": [
"| Model | Language | Parameters | Model Size | Description | Base Model |\n",
"|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
"| [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | Multilingual | 568M | 2.27 GB | Multi-Functionality(dense retrieval, sparse retrieval, multi-vector(colbert)), Multi-Linguality, and Multi-Granularity(8192 tokens) | XLM-RoBERTa |"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d4647625",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 228780.22it/s]\n"
]
}
],
"source": [
"from FlagEmbedding import BGEM3FlagModel\n",
"\n",
"model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)\n",
"\n",
"sentences = [\"What is BGE M3?\", \"Defination of BM25\"]"
]
},
{
"cell_type": "markdown",
"id": "1f89f1a9",
"metadata": {},
"source": [
"```python\n",
"BGEM3FlagModel.encode(\n",
" sentences, \n",
" batch_size=12, \n",
" max_length=8192, \n",
" return_dense=True, \n",
" return_sparse=False, \n",
" return_colbert_vecs=False\n",
")\n",
"```\n",
"It returns a dictionary like:\n",
"```python\n",
"{\n",
" 'dense_vecs': 'array of dense embeddings of inputs if return_dense=True, otherwise None,'\n",
" 'lexical_weights': 'array of dictionaries with keys and values are ids of tokens and their corresponding weights if return_sparse=True, otherwise None,'\n",
" 'colbert_vecs': 'array of multi-vector embeddings of inputs if return_cobert_vecs=True, otherwise None,'\n",
"}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f0b11cf0",
"metadata": {},
"outputs": [],
"source": [
"# If you don't need such a long length of 8192 input tokens, you can set max_length to a smaller value to speed up encoding.\n",
"embeddings = model.encode(\n",
" sentences, \n",
" max_length=10,\n",
" return_dense=True, \n",
" return_sparse=True, \n",
" return_colbert_vecs=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "72cba126",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dense embedding:\n",
"[[-0.03411707 -0.04707828 -0.00089447 ... 0.04828531 0.00755427\n",
" -0.02961654]\n",
" [-0.01041734 -0.04479263 -0.02429199 ... -0.00819298 0.01503995\n",
" 0.01113793]]\n",
"sparse embedding:\n",
"[defaultdict(<class 'int'>, {'4865': 0.08362077, '83': 0.081469566, '335': 0.12964639, '11679': 0.25186998, '276': 0.17001738, '363': 0.26957875, '32': 0.040755156}), defaultdict(<class 'int'>, {'262': 0.050144322, '5983': 0.13689369, '2320': 0.045134712, '111': 0.06342201, '90017': 0.25167602, '2588': 0.33353207})]\n",
"multi-vector:\n",
"[array([[-8.6726490e-03, -4.8921868e-02, -3.0449261e-03, ...,\n",
" -2.2082448e-02, 5.7268854e-02, 1.2811369e-02],\n",
" [-8.8765034e-03, -4.6860173e-02, -9.5845405e-03, ...,\n",
" -3.1404708e-02, 5.3911421e-02, 6.8714428e-03],\n",
" [ 1.8445771e-02, -4.2359587e-02, 8.6754939e-04, ...,\n",
" -1.9803897e-02, 3.8384371e-02, 7.6852231e-03],\n",
" ...,\n",
" [-2.5543230e-02, -1.6561864e-02, -4.2125367e-02, ...,\n",
" -4.5030322e-02, 4.4091221e-02, -1.0043185e-02],\n",
" [ 4.9905590e-05, -5.5475257e-02, 8.4884483e-03, ...,\n",
" -2.2911752e-02, 6.0379632e-02, 9.3577225e-03],\n",
" [ 2.5895271e-03, -2.9331330e-02, -1.8961012e-02, ...,\n",
" -8.0389353e-03, 3.2842189e-02, 4.3894034e-02]], dtype=float32), array([[ 0.01715658, 0.03835309, -0.02311821, ..., 0.00146474,\n",
" 0.02993429, -0.05985384],\n",
" [ 0.00996143, 0.039217 , -0.03855301, ..., 0.00599566,\n",
" 0.02722942, -0.06509776],\n",
" [ 0.01777726, 0.03919311, -0.01709837, ..., 0.00805702,\n",
" 0.03988946, -0.05069073],\n",
" ...,\n",
" [ 0.05474931, 0.0075684 , 0.00329455, ..., -0.01651684,\n",
" 0.02397249, 0.00368039],\n",
" [ 0.0093503 , 0.05022853, -0.02385841, ..., 0.02575599,\n",
" 0.00786822, -0.03260205],\n",
" [ 0.01805054, 0.01337725, 0.00016697, ..., 0.01843987,\n",
" 0.01374448, 0.00310114]], dtype=float32)]\n"
]
}
],
"source": [
"print(f\"dense embedding:\\n{embeddings['dense_vecs']}\")\n",
"print(f\"sparse embedding:\\n{embeddings['lexical_weights']}\")\n",
"print(f\"multi-vector:\\n{embeddings['colbert_vecs']}\")"
]
}
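,
{
"cell_type": "markdown",
"id": "aa11bb22",
"metadata": {},
"source": [
"Each of the three representations supports its own similarity computation. Below is a minimal sketch using the `embeddings` produced above: dense vectors use the inner product, while the `compute_lexical_matching_score()` and `colbert_score()` methods of `BGEM3FlagModel` score the sparse and multi-vector outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc33dd44",
"metadata": {},
"outputs": [],
"source": [
"# dense score: inner product of the two dense vectors\n",
"dense_score = embeddings['dense_vecs'][0] @ embeddings['dense_vecs'][1].T\n",
"\n",
"# sparse score: overlap of the lexical weights of the two sentences\n",
"lexical_score = model.compute_lexical_matching_score(\n",
"    embeddings['lexical_weights'][0], embeddings['lexical_weights'][1]\n",
")\n",
"\n",
"# multi-vector score: late interaction over the ColBERT token vectors\n",
"colbert_score = model.colbert_score(\n",
"    embeddings['colbert_vecs'][0], embeddings['colbert_vecs'][1]\n",
")\n",
"\n",
"print(dense_score, lexical_score, colbert_score)"
]
}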
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}