2024-11-08 15:34:28 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BGE Explanation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we walk through the structure of BGE and BGE-v1.5 and how they generate embeddings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Installation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the required packages in your environment."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"%pip install -U transformers FlagEmbedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Encode sentences"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see exactly how a sentence is encoded, let's load the tokenizer and model directly with HF Transformers instead of using FlagEmbedding."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
"model = AutoModel.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
"\n",
"sentences = [\"embedding\", \"I love machine learning and nlp\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the following cell to inspect the architecture of bge-base-en-v1.5. It has exactly the same structure as BERT-base: 12 encoder layers and a hidden dimension of 768.\n",
"\n",
"Note that corresponding BGE and BGE-v1.5 models share the same architecture. For example, bge-base-en and bge-base-en-v1.5 have the same structure."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BertModel(\n",
"  (embeddings): BertEmbeddings(\n",
"    (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
"    (position_embeddings): Embedding(512, 768)\n",
"    (token_type_embeddings): Embedding(2, 768)\n",
"    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"    (dropout): Dropout(p=0.1, inplace=False)\n",
"  )\n",
"  (encoder): BertEncoder(\n",
"    (layer): ModuleList(\n",
"      (0-11): 12 x BertLayer(\n",
"        (attention): BertAttention(\n",
"          (self): BertSelfAttention(\n",
"            (query): Linear(in_features=768, out_features=768, bias=True)\n",
"            (key): Linear(in_features=768, out_features=768, bias=True)\n",
"            (value): Linear(in_features=768, out_features=768, bias=True)\n",
"            (dropout): Dropout(p=0.1, inplace=False)\n",
"          )\n",
"          (output): BertSelfOutput(\n",
"            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
"            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"            (dropout): Dropout(p=0.1, inplace=False)\n",
"          )\n",
"        )\n",
"        (intermediate): BertIntermediate(\n",
"          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
"          (intermediate_act_fn): GELUActivation()\n",
"        )\n",
"        (output): BertOutput(\n",
"          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
"          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"          (dropout): Dropout(p=0.1, inplace=False)\n",
"        )\n",
"      )\n",
"    )\n",
"  )\n",
"  (pooler): BertPooler(\n",
"    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
"    (activation): Tanh()\n",
"  )\n",
")"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.eval()"
]
},
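{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sanity check on the printed architecture, the cell below recomputes the parameter count by hand from the layer shapes: a weight matrix plus a bias for every `Linear`, and a scale plus a shift for every `LayerNorm`. This is a sketch; the helper and constant names are our own, and it only counts the modules shown in the printout. Running `sum(p.numel() for p in model.parameters())` on the loaded model should report the same total, the familiar ~110M parameters of BERT-base."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hand count of BertModel parameters from the shapes printed above (helper names are our own)\n",
"def linear(n_in, n_out):\n",
"    # weight matrix (n_out x n_in) plus a bias vector of size n_out\n",
"    return n_in * n_out + n_out\n",
"\n",
"def layer_norm(dim):\n",
"    # elementwise_affine=True: one scale and one shift per feature\n",
"    return 2 * dim\n",
"\n",
"VOCAB, MAX_POS, TOKEN_TYPES = 30522, 512, 2\n",
"HIDDEN, FFN, NUM_LAYERS = 768, 3072, 12\n",
"\n",
"embeddings_params = (VOCAB + MAX_POS + TOKEN_TYPES) * HIDDEN + layer_norm(HIDDEN)\n",
"per_layer_params = (\n",
"    4 * linear(HIDDEN, HIDDEN)  # query, key, value and the attention output dense\n",
"    + linear(HIDDEN, FFN)       # intermediate dense\n",
"    + linear(FFN, HIDDEN)       # output dense\n",
"    + 2 * layer_norm(HIDDEN)    # LayerNorm after attention and after the feed-forward block\n",
")\n",
"pooler_params = linear(HIDDEN, HIDDEN)\n",
"\n",
"total_params = embeddings_params + NUM_LAYERS * per_layer_params + pooler_params\n",
"print(f'{total_params:,}')  # 109,482,240"
]
},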
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's tokenize the sentences."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[ 101, 7861, 8270, 4667, 102, 0, 0, 0, 0],\n",
" [ 101, 1045, 2293, 3698, 4083, 1998, 17953, 2361, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = tokenizer(\n",
"    sentences,\n",
"    padding=True,\n",
"    truncation=True,\n",
"    return_tensors='pt',\n",
"    max_length=512\n",
")\n",
"inputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
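"From the results, we can see that each sentence begins with token 101 and ends with token 102; these are the `[CLS]` and `[SEP]` special tokens used in BERT. The shorter sentence is padded with token 0 up to the length of the longer one, and its `attention_mask` is 0 at the padded positions."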
]
},
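{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the padding explicit, the cell below is a simplified pure-Python re-creation of what `padding=True` produced for this particular batch: wrap each sequence in `[CLS]` ... `[SEP]`, right-pad with id 0 up to the longest sequence, and mark real tokens with 1 in the attention mask. The word-piece ids are copied from the output above; `pad_batch` is a hypothetical helper for illustration, not how the tokenizer is implemented internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CLS_ID, SEP_ID, PAD_ID = 101, 102, 0\n",
"\n",
"# word-piece ids from the tokenizer output above, without the special tokens\n",
"word_piece_ids = [\n",
"    [7861, 8270, 4667],                           # 'embedding'\n",
"    [1045, 2293, 3698, 4083, 1998, 17953, 2361],  # 'I love machine learning and nlp'\n",
"]\n",
"\n",
"def pad_batch(batch):\n",
"    # hypothetical helper: add [CLS] and [SEP] around every sequence\n",
"    seqs = [[CLS_ID] + ids + [SEP_ID] for ids in batch]\n",
"    width = max(len(s) for s in seqs)\n",
"    # right-pad to the longest sequence; the mask is 1 for real tokens and 0 for padding\n",
"    input_ids = [s + [PAD_ID] * (width - len(s)) for s in seqs]\n",
"    attention_mask = [[1] * len(s) + [0] * (width - len(s)) for s in seqs]\n",
"    return input_ids, attention_mask\n",
"\n",
"padded_ids, padded_mask = pad_batch(word_piece_ids)\n",
"for ids, mask in zip(padded_ids, padded_mask):\n",
"    print(ids, mask)"
]
},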
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 9, 768])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n",
"last_hidden_state.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we implement the pooling function with two options: take the last hidden state of `[CLS]`, or mean-pool the last hidden states over all non-padding tokens."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
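"def pooling(last_hidden_state: torch.Tensor, pooling_method='cls', attention_mask: torch.Tensor = None):\n",
"    if pooling_method == 'cls':\n",
"        # the last hidden state of the first token, [CLS]\n",
"        return last_hidden_state[:, 0]\n",
"    elif pooling_method == 'mean':\n",
"        # average over non-padding tokens only, using the attention mask\n",
"        s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)\n",
"        d = attention_mask.sum(dim=1, keepdim=True).float()\n",
"        return s / d\n",
"    else:\n",
"        raise ValueError(f'Unknown pooling method: {pooling_method}')"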
]
},
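{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why the attention mask matters for mean pooling, here is a toy example with made-up numbers: two sequences of length 3 with hidden size 2, where the last position of the first sequence is padding (given a deliberately large value). CLS pooling simply takes position 0, masked mean pooling ignores the padding, and a naive `mean(dim=1)` does not. The tensors are hypothetical and only mirror the shapes that `pooling` receives."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# toy last_hidden_state: batch of 2, sequence length 3, hidden size 2 (made-up values)\n",
"toy_hidden = torch.tensor([\n",
"    [[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]],  # last position is padding\n",
"    [[2.0, 0.0], [4.0, 0.0], [6.0, 0.0]],\n",
"])\n",
"toy_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])\n",
"\n",
"# CLS pooling: the hidden state at position 0 of each sequence, values [[1, 1], [2, 0]]\n",
"cls_vec = toy_hidden[:, 0]\n",
"# masked mean pooling: sum over real tokens divided by their count, values [[2, 2], [4, 0]]\n",
"masked_mean = (toy_hidden * toy_mask.unsqueeze(-1).float()).sum(dim=1) / toy_mask.sum(dim=1, keepdim=True).float()\n",
"# naive mean also averages in the padding position of the first sequence\n",
"naive_mean = toy_hidden.mean(dim=1)\n",
"\n",
"print(cls_vec)\n",
"print(masked_mean)\n",
"print(naive_mean)"
]
},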
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unlike the more commonly used mean pooling, BGE is trained to use the last hidden state of `[CLS]` as the sentence embedding:\n",
"\n",
"`sentence_embeddings = model_output[0][:, 0]`\n",
"\n",
"Using mean pooling instead leads to a significant drop in performance, so make sure to use the correct method to obtain sentence vectors."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 768])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings = pooling(\n",
"    last_hidden_state,\n",
"    pooling_method='cls',\n",
"    attention_mask=inputs['attention_mask']\n",
")\n",
"embeddings.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting these pieces together, we get the complete encoding function:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def _encode(sentences, max_length=512, convert_to_numpy=True):\n",
"\n",
"    # handle both a single sentence and a list of sentences\n",
"    input_was_string = False\n",
"    if isinstance(sentences, str):\n",
"        sentences = [sentences]\n",
"        input_was_string = True\n",
"\n",
"    inputs = tokenizer(\n",
"        sentences,\n",
"        padding=True,\n",
"        truncation=True,\n",
"        return_tensors='pt',\n",
"        max_length=max_length\n",
"    )\n",
"\n",
"    # inference only, so gradients are not needed\n",
"    with torch.no_grad():\n",
"        last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n",
"\n",
"    embeddings = pooling(\n",
"        last_hidden_state,\n",
"        pooling_method='cls',\n",
"        attention_mask=inputs['attention_mask']\n",
"    )\n",
"\n",
"    # L2-normalize the embedding vectors so that dot products are cosine similarities\n",
"    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)\n",
"\n",
"    # convert to numpy if needed\n",
"    if convert_to_numpy:\n",
"        embeddings = embeddings.detach().numpy()\n",
"\n",
"    return embeddings[0] if input_was_string else embeddings"
]
},
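{
"cell_type": "markdown",
"metadata": {},
"source": [
"`_encode` sends every sentence through the model in a single forward pass, which can run out of memory on a large corpus. A common remedy is to encode fixed-size batches and concatenate the results. The cell below is a minimal sketch of that idea, not FlagEmbedding's implementation; `encode_in_batches` and its `batch_size` default are our own choices, and `encode_fn` is assumed to be any function mapping a list of strings to a 2-D NumPy array, such as `_encode`. Because padding positions are masked out of attention, the `[CLS]` embedding of a sentence should not depend on which batch it lands in, apart from tiny floating-point differences."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def encode_in_batches(sentences, encode_fn, batch_size=32):\n",
"    # hypothetical helper: encode a list of sentences in chunks of at most batch_size\n",
"    if not sentences:\n",
"        raise ValueError('sentences must be a non-empty list')\n",
"    chunks = []\n",
"    for start in range(0, len(sentences), batch_size):\n",
"        batch = sentences[start:start + batch_size]\n",
"        # each call is expected to return an array of shape (len(batch), hidden_dim)\n",
"        chunks.append(encode_fn(batch))\n",
"    # stack the per-batch results back into one (len(sentences), hidden_dim) array\n",
"    return np.concatenate(chunks, axis=0)\n",
"\n",
"# usage with the encoder defined above (requires the HF model to be loaded):\n",
"# all_embeddings = encode_in_batches(sentences, _encode, batch_size=1)"
]
},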
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Comparison"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's run the function we wrote to get the embeddings of the two sentences:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04\n",
" 2.8417887e-02 6.3214332e-02]\n",
" [ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03\n",
" 1.8721525e-02 -2.0371782e-02]]\n",
"Similarity scores:\n",
"[[0.9999997 0.6077381]\n",
" [0.6077381 0.9999999]]\n"
]
}
],
"source": [
"embeddings = _encode(sentences)\n",
"print(f\"Embeddings:\\n{embeddings}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
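{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `_encode` L2-normalizes every embedding, the matrix product `embeddings @ embeddings.T` is the matrix of pairwise cosine similarities, and its diagonal is 1 up to floating-point rounding (hence values such as 0.9999997). The cell below checks this identity on random stand-in vectors; the random data is only an illustration, not real BGE embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"raw = torch.randn(2, 768)        # stand-in for two un-normalized [CLS] vectors\n",
"unit = F.normalize(raw, dim=-1)  # divide each row by its L2 norm\n",
"\n",
"dot_scores = unit @ unit.T  # the same computation as the scores above\n",
"# pairwise cosine similarity of the raw vectors via broadcasting: (2, 1, 768) vs (1, 2, 768)\n",
"cos_scores = F.cosine_similarity(raw.unsqueeze(1), raw.unsqueeze(0), dim=-1)\n",
"\n",
"print(torch.allclose(dot_scores, cos_scores, atol=1e-5))\n",
"print(dot_scores.diagonal())"
]
},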
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, run the same sentences through the `FlagModel` API provided by FlagEmbedding:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04\n",
" 2.8417887e-02 6.3214332e-02]\n",
" [ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03\n",
" 1.8721525e-02 -2.0371782e-02]]\n",
"Similarity scores:\n",
"[[0.9999997 0.6077381]\n",
" [0.6077381 0.9999999]]\n"
]
}
],
"source": [
"from FlagEmbedding import FlagModel\n",
"\n",
"# a separate name keeps the HF `model` used by `_encode` intact\n",
"flag_model = FlagModel('BAAI/bge-base-en-v1.5')\n",
"\n",
"embeddings = flag_model.encode(sentences)\n",
"print(f\"Embeddings:\\n{embeddings}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
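{
"cell_type": "markdown",
"metadata": {},
"source": [
"The printouts above show only a few entries of each vector. To compare two sets of embeddings over every dimension, one option is to look at the largest element-wise difference, or to use `np.allclose` with a small tolerance. `max_abs_diff` below is a hypothetical helper, and `hf_emb` / `flag_emb` in the usage comment are placeholder names for the two outputs kept in separate variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def max_abs_diff(a, b):\n",
"    # largest element-wise difference between two embedding matrices of the same shape\n",
"    a, b = np.asarray(a), np.asarray(b)\n",
"    if a.shape != b.shape:\n",
"        raise ValueError(f'shape mismatch: {a.shape} vs {b.shape}')\n",
"    return float(np.max(np.abs(a - b)))\n",
"\n",
"# usage with placeholder names for the two encoders' outputs:\n",
"# print(max_abs_diff(hf_emb, flag_emb), np.allclose(hf_emb, flag_emb, atol=1e-5))"
]
},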
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected, the two encoding functions return exactly the same results. The full implementation in FlagEmbedding additionally handles large datasets through batching and supports GPUs and parallelization. Feel free to check the [source code](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/flag_models.py#L370) for more details."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}