{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BGE Explanation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we will walk through the structure of BGE and BGE-v1.5 and see how they generate embeddings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Installation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the required packages in your environment."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"%pip install -U transformers FlagEmbedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Encode sentences"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see exactly how a sentence is encoded, let's first load the tokenizer and model from HF Transformers instead of FlagEmbedding."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
"model = AutoModel.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
"\n",
"sentences = [\"embedding\", \"I love machine learning and nlp\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the following cell to inspect the structure of bge-base-en-v1.5. It has exactly the same structure as BERT-base: 12 encoder layers and a hidden dimension of 768.\n",
"\n",
"Note that corresponding BGE and BGE-v1.5 models share the same structure; for example, bge-base-en and bge-base-en-v1.5 have the same architecture."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BertModel(\n",
"  (embeddings): BertEmbeddings(\n",
"    (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
"    (position_embeddings): Embedding(512, 768)\n",
"    (token_type_embeddings): Embedding(2, 768)\n",
"    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"    (dropout): Dropout(p=0.1, inplace=False)\n",
"  )\n",
"  (encoder): BertEncoder(\n",
"    (layer): ModuleList(\n",
"      (0-11): 12 x BertLayer(\n",
"        (attention): BertAttention(\n",
"          (self): BertSelfAttention(\n",
"            (query): Linear(in_features=768, out_features=768, bias=True)\n",
"            (key): Linear(in_features=768, out_features=768, bias=True)\n",
"            (value): Linear(in_features=768, out_features=768, bias=True)\n",
"            (dropout): Dropout(p=0.1, inplace=False)\n",
"          )\n",
"          (output): BertSelfOutput(\n",
"            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
"            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"            (dropout): Dropout(p=0.1, inplace=False)\n",
"          )\n",
"        )\n",
"        (intermediate): BertIntermediate(\n",
"          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
"          (intermediate_act_fn): GELUActivation()\n",
"        )\n",
"        (output): BertOutput(\n",
"          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
"          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"          (dropout): Dropout(p=0.1, inplace=False)\n",
"        )\n",
"      )\n",
"    )\n",
"  )\n",
"  (pooler): BertPooler(\n",
"    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
"    (activation): Tanh()\n",
"  )\n",
")"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.eval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's tokenize the sentences."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[ 101, 7861, 8270, 4667, 102, 0, 0, 0, 0],\n",
" [ 101, 1045, 2293, 3698, 4083, 1998, 17953, 2361, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = tokenizer(\n",
"    sentences, \n",
"    padding=True, \n",
"    truncation=True, \n",
"    return_tensors='pt', \n",
"    max_length=512\n",
")\n",
"inputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the results, we can see that each sentence begins with token 101 and ends with token 102; these are the `[CLS]` and `[SEP]` special tokens used in BERT. The shorter sentence is padded with token 0 (`[PAD]`) to match the length of the longer one, and the attention mask marks those padded positions with 0."
]
},
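{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this explicit, the cell below (a small illustrative addition, not part of the original walkthrough) converts the token IDs back to their string tokens, so the `[CLS]`, `[SEP]`, and `[PAD]` positions are easy to see."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decode each row of input_ids back to tokens to show the special tokens directly.\n",
"for ids in inputs['input_ids']:\n",
"    print(tokenizer.convert_ids_to_tokens(ids.tolist()))"
]
},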
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 9, 768])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n",
"last_hidden_state.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we implement the pooling function, with two choices: using `[CLS]`'s last hidden state, or mean pooling over the whole last hidden state."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def pooling(last_hidden_state: torch.Tensor, pooling_method='cls', attention_mask: torch.Tensor = None):\n",
"    if pooling_method == 'cls':\n",
"        return last_hidden_state[:, 0]\n",
"    elif pooling_method == 'mean':\n",
"        s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)\n",
"        d = attention_mask.sum(dim=1, keepdim=True).float()\n",
"        return s / d"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Different from the more commonly used mean pooling, BGE is trained to use the last hidden state of `[CLS]` as the sentence embedding:\n",
"\n",
"`sentence_embeddings = model_output[0][:, 0]`\n",
"\n",
"If you use mean pooling instead, there will be a significant drop in performance, so make sure to use the correct method to obtain sentence vectors."
]
},
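{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate the difference (an added sketch, not part of the original tutorial), the cell below computes both `cls` and `mean` pooled embeddings from the same hidden states and prints the resulting similarity matrices. The two methods give different numbers; only the `cls` variant matches how BGE was trained."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the two pooling choices on the same last hidden states.\n",
"for method in ['cls', 'mean']:\n",
"    emb = pooling(last_hidden_state, pooling_method=method, attention_mask=inputs['attention_mask'])\n",
"    emb = torch.nn.functional.normalize(emb, dim=-1)\n",
"    print(f\"{method} pooling similarity:\\n{emb @ emb.T}\\n\")"
]
},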
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 768])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings = pooling(\n",
"    last_hidden_state, \n",
"    pooling_method='cls', \n",
"    attention_mask=inputs['attention_mask']\n",
")\n",
"embeddings.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assembling them together, we get the whole encoding function:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def _encode(sentences, max_length=512, convert_to_numpy=True):\n",
"\n",
"    # handle the case of single sentence and a list of sentences\n",
"    input_was_string = False\n",
"    if isinstance(sentences, str):\n",
"        sentences = [sentences]\n",
"        input_was_string = True\n",
"\n",
"    inputs = tokenizer(\n",
"        sentences, \n",
"        padding=True, \n",
"        truncation=True, \n",
"        return_tensors='pt', \n",
"        max_length=max_length\n",
"    )\n",
"\n",
"    last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n",
"\n",
"    embeddings = pooling(\n",
"        last_hidden_state, \n",
"        pooling_method='cls', \n",
"        attention_mask=inputs['attention_mask']\n",
"    )\n",
"\n",
"    # normalize the embedding vectors\n",
"    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)\n",
"\n",
"    # convert to numpy if needed\n",
"    if convert_to_numpy:\n",
"        embeddings = embeddings.detach().numpy()\n",
"\n",
"    return embeddings[0] if input_was_string else embeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Comparison"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's run the function we wrote to get the embeddings of the two sentences:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04\n",
" 2.8417887e-02 6.3214332e-02]\n",
" [ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03\n",
" 1.8721525e-02 -2.0371782e-02]]\n",
"Similarity scores:\n",
"[[0.9999997 0.6077381]\n",
" [0.6077381 0.9999999]]\n"
]
}
],
"source": [
"embeddings = _encode(sentences)\n",
"print(f\"Embeddings:\\n{embeddings}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, encode the same sentences using the API provided by FlagEmbedding:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04\n",
" 2.8417887e-02 6.3214332e-02]\n",
" [ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03\n",
" 1.8721525e-02 -2.0371782e-02]]\n",
"Similarity scores:\n",
"[[0.9999997 0.6077381]\n",
" [0.6077381 0.9999999]]\n"
]
}
],
"source": [
"from FlagEmbedding import FlagModel\n",
"\n",
"model = FlagModel('BAAI/bge-base-en-v1.5')\n",
"\n",
"embeddings = model.encode(sentences)\n",
"print(f\"Embeddings:\\n{embeddings}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected, the two encoding approaches return exactly the same results. The full implementation in FlagEmbedding additionally handles large datasets by batching and adds GPU support and parallelization. Feel free to check the [source code](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/flag_models.py#L370) for more details."
]
}
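,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final sanity check (an added sketch, not part of the original tutorial), the cell below reloads the HF backbone under a separate name `hf_model` (since `model` was re-bound to `FlagModel` above), recomputes the manual CLS-pooled, normalized embeddings, and compares them numerically against the FlagModel output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Reload the HF model under a new name so the FlagModel bound to `model` is untouched.\n",
"hf_model = AutoModel.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
"hf_model.eval()\n",
"\n",
"with torch.no_grad():\n",
"    hf_inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
"    hf_hidden = hf_model(**hf_inputs, return_dict=True).last_hidden_state\n",
"    manual_embeddings = torch.nn.functional.normalize(hf_hidden[:, 0], dim=-1).numpy()\n",
"\n",
"# `embeddings` still holds the FlagModel output from the cell above.\n",
"# Loosen the tolerance if FlagModel runs in fp16 or on a GPU.\n",
"print(np.allclose(manual_embeddings, embeddings, atol=1e-5))"
]
}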
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}