mirror of https://github.com/FlagOpen/FlagEmbedding.git
synced 2026-01-05 11:41:28 +00:00

update Tutorials

This commit is contained in:
parent ddad0f9cb9
commit 4ca58751f1
395 Tutorials/1_Embedding/1.1_Intro&Inference.ipynb Normal file
@@ -0,0 +1,395 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Intro to Embedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For text retrieval, pattern matching is the most intuitive approach: people search with certain characters, words, phrases, or sentence patterns. However, pattern matching between a query and a collection of text files is not only unnatural for humans but also extremely inefficient for computers. \n",
"\n",
"Images have RGB pixels and acoustic waves have digital signals as numerical representations. Similarly, to accomplish more sophisticated natural language tasks such as retrieval, classification, clustering, or semantic search, we need a way to represent text data numerically. That's where text embedding comes onto the stage."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Traditional text embedding methods like one-hot encoding and bag-of-words (BoW) represent words and sentences as sparse vectors based on their statistical features, such as word appearance and frequency within a document. More advanced methods like TF-IDF and BM25 improve on these by considering a word's importance across an entire corpus, while n-gram techniques capture word order in small groups. However, these approaches suffer from the \"curse of dimensionality\" and fail to capture semantic similarities like \"cat\" and \"kitty\" or differences like \"play the watch\" and \"watch the play\"."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# example of bag-of-words\n",
"sentence1 = \"I love basketball\"\n",
"sentence2 = \"I have a basketball match\"\n",
"\n",
"words = ['I', 'love', 'basketball', 'have', 'a', 'match']\n",
"sen1_vec = [1, 1, 1, 0, 0, 0]\n",
"sen2_vec = [1, 0, 1, 1, 1, 1]"
]
},
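{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the dot product of these two sparse vectors simply counts the words the sentences share:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the dot product counts shared words: 'I' and 'basketball' overlap, giving 2\n",
"overlap = sum(a * b for a, b in zip(sen1_vec, sen2_vec))\n",
"print(overlap)"
]
},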
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To overcome these limitations, dense word embeddings were developed, mapping words to vectors in a low-dimensional space that captures semantic and relational information. Early models like Word2Vec demonstrated the power of dense embeddings using neural networks. Subsequent advancements with neural network architectures like RNNs, LSTMs, and Transformers have enabled more sophisticated models such as BERT, RoBERTa, and GPT to excel in capturing complex word relationships and contexts. **BAAI General Embedding (BGE)** provides a series of open-source models that satisfy all kinds of demands."
]
},
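{
"cell_type": "markdown",
"metadata": {},
"source": [
"Throughout this tutorial we compare embeddings with cosine similarity. For L2-normalized vectors $u$ and $v$ it reduces to a dot product:\n",
"\n",
"$$\\cos(u,v)=\\frac{u\\cdot v}{\\|u\\|\\,\\|v\\|}=u\\cdot v \\quad \\text{when } \\|u\\|=\\|v\\|=1$$\n",
"\n",
"which is why the cells below score sentence pairs with a simple matrix product `embeddings @ embeddings.T` (note that the diagonals of the score matrices are all approximately 1)."
]
},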
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Get Embedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first step of modern text retrieval is embedding the text. So let's take a look at how to use the embedding models."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the packages:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"%pip install -U FlagEmbedding sentence_transformers openai voyageai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use the following three sentences as the inputs:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"sentences = [\n",
"    \"That is a happy dog\",\n",
"    \"That is a very happy person\",\n",
"    \"Today is a sunny day\",\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Open-source Models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A huge portion of embedding models come from the open-source community. The advantages of open-source models include:\n",
"- Free of charge, though make sure to check the license and your use case before using them.\n",
"- No rate limits; inference can be accelerated a lot if you have enough GPUs to parallelize.\n",
"- Transparent and possibly reproducible.\n",
"\n",
"Let's take a look at two representatives:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### BGE"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"BGE is a series of embedding models and rerankers published by BAAI. Several of them reached state of the art at the time they were released."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"(3, 768)\n",
"Similarity scores:\n",
"[[1. 0.7900386 0.57525384]\n",
" [0.7900386 0.9999998 0.59190154]\n",
" [0.57525384 0.59190154 0.99999994]]\n"
]
}
],
"source": [
"from FlagEmbedding import FlagModel\n",
"\n",
"# load the BGE model\n",
"model = FlagModel('BAAI/bge-base-en-v1.5')\n",
"\n",
"# encode the sentences\n",
"embeddings = model.encode(sentences)\n",
"print(f\"Embeddings:\\n{embeddings.shape}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Sentence Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sentence Transformers is a library for sentence embeddings with a huge collection of embedding models and datasets for related tasks."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"(3, 384)\n",
"Similarity scores:\n",
"[[0.99999976 0.6210502 0.24906276]\n",
" [0.6210502 0.9999997 0.21061528]\n",
" [0.24906276 0.21061528 0.9999999 ]]\n"
]
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
"\n",
"embeddings = model.encode(sentences, normalize_embeddings=True)\n",
"print(f\"Embeddings:\\n{embeddings.shape}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Commercial Models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are also plenty of commercial models to choose from. Their advantages include:\n",
"- Efficient memory usage and fast inference without the need for GPUs.\n",
"- Systematic support; commercial models integrate closely with their vendors' other products.\n",
"- Better training data; commercial models might be trained on larger, higher-quality datasets than some open-source models."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### OpenAI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Along with the GPT series, OpenAI provides its own embedding models. Make sure to fill in your own API key in the field `\"YOUR_API_KEY\"`."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"YOUR_API_KEY\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then run the following cells to get the embeddings. Check their official [documentation](https://platform.openai.com/docs/guides/embeddings) for more details."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"\n",
"client = OpenAI()\n",
"\n",
"response = client.embeddings.create(input=sentences, model=\"text-embedding-3-small\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"(3, 1536)\n",
"Similarity scores:\n",
"[[1.00000004 0.697673 0.34739798]\n",
" [0.697673 1.00000005 0.31969923]\n",
" [0.34739798 0.31969923 0.99999998]]\n"
]
}
],
"source": [
"embeddings = np.asarray([response.data[i].embedding for i in range(len(sentences))])\n",
"print(f\"Embeddings:\\n{embeddings.shape}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Voyage AI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Voyage AI provides embedding models and rerankers for different purposes and in various fields. Their API can be used for free within certain rate and token limits."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"VOYAGE_API_KEY\"] = \"YOUR_API_KEY\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check their official [documentation](https://docs.voyageai.com/docs/api-key-and-installation) for more details."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import voyageai\n",
"\n",
"vo = voyageai.Client()\n",
"\n",
"result = vo.embed(sentences, model=\"voyage-large-2-instruct\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"(3, 1024)\n",
"Similarity scores:\n",
"[[0.99999997 0.87282517 0.63276503]\n",
" [0.87282517 0.99999998 0.64720015]\n",
" [0.63276503 0.64720015 0.99999999]]\n"
]
}
],
"source": [
"embeddings = np.asarray(result.embeddings)\n",
"print(f\"Embeddings:\\n{embeddings.shape}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -1,69 +1,11 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d34077d7-dd3a-49d7-b87b-2914d52992a6",
"metadata": {},
"source": [
"# Intro"
]
},
{
"cell_type": "markdown",
"id": "cb0755e6-6ef9-4eb9-b3e6-131a0820ffc7",
"metadata": {},
"source": [
"For text retrieval, pattern matching is the most intuitive approach: people search with certain characters, words, phrases, or sentence patterns. However, pattern matching between a query and a collection of text files is not only unnatural for humans but also extremely inefficient for computers. \n",
"\n",
"Images have RGB pixels and acoustic waves have digital signals as numerical representations. Similarly, to accomplish more sophisticated natural language tasks such as retrieval, classification, clustering, or semantic search, we need a way to represent text data numerically. That's where text embedding comes onto the stage."
]
},
{
"cell_type": "markdown",
"id": "e53fb5bc",
"metadata": {},
"source": [
"## 1. Background"
]
},
{
"cell_type": "markdown",
"id": "ca2b21dc",
"metadata": {},
"source": [
"Traditional text embedding methods like one-hot encoding and bag-of-words (BoW) represent words and sentences as sparse vectors based on their statistical features, such as word appearance and frequency within a document. More advanced methods like TF-IDF and BM25 improve on these by considering a word's importance across an entire corpus, while n-gram techniques capture word order in small groups. However, these approaches suffer from the \"curse of dimensionality\" and fail to capture semantic similarities like \"cat\" and \"kitty\" or differences like \"play the watch\" and \"watch the play\"."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7d7ea92b",
"metadata": {},
"outputs": [],
"source": [
"# example of bag-of-words\n",
"sentence1 = \"I love basketball\"\n",
"sentence2 = \"I have a basketball match\"\n",
"\n",
"words = ['I', 'love', 'basketball', 'have', 'a', 'match']\n",
"sen1_vec = [1, 1, 1, 0, 0, 0]\n",
"sen2_vec = [1, 0, 1, 1, 1, 1]"
]
},
{
"cell_type": "markdown",
"id": "9958d001",
"metadata": {},
"source": [
"To overcome these limitations, dense word embeddings were developed, mapping words to vectors in a low-dimensional space that captures semantic and relational information. Early models like Word2Vec demonstrated the power of dense embeddings using neural networks. Subsequent advancements with neural network architectures like RNNs, LSTMs, and Transformers have enabled more sophisticated models such as BERT, RoBERTa, and GPT to excel in capturing complex word relationships and contexts. **BAAI General Embedding (BGE)** provides a series of open-source models that satisfy all kinds of demands."
]
},
{
"cell_type": "markdown",
"id": "06cff9e4",
"metadata": {},
"source": [
"## 2. BAAI General Embedding"
"# BGE Series"
]
},
{
"cell_type": "markdown",
@@ -71,21 +13,62 @@
"id": "880e229d",
"metadata": {},
"source": [
"In this Part, we will walk through the BGE series and introduce how to use those embedding models.\n",
"In this Part, we will walk through the BGE series and introduce how to use the BGE embedding models."
]
},
{
"cell_type": "markdown",
"id": "2516fd49",
"metadata": {},
"source": [
"## 1. BAAI General Embedding"
]
},
{
"cell_type": "markdown",
"id": "2113ee71",
"metadata": {},
"source": [
"BGE stands for BAAI General Embedding; it's a series of embedding models developed and published by the Beijing Academy of Artificial Intelligence (BAAI)."
]
},
{
"cell_type": "markdown",
"id": "16515b99",
"metadata": {},
"source": [
"Full support for the APIs and related usage of BGE is maintained in [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) on GitHub.\n",
"\n",
"First, install the FlagEmbedding in your environment."
"Run the following cell to install FlagEmbedding in your environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af533a6f",
"id": "88095fd0",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"%pip install -U FlagEmbedding"
]
},
{
"cell_type": "markdown",
"id": "bc6e30a0",
"metadata": {},
"source": [
"The collection of BGE models can be found in the [Hugging Face collection](https://huggingface.co/collections/BAAI/bge-66797a74476eb1f085c7446d)."
]
},
{
"cell_type": "markdown",
"id": "67a16ccf",
"metadata": {},
"source": [
"## 2. BGE Series Models"
]
},
{
"cell_type": "markdown",
"id": "2e10034a",
@@ -127,19 +110,10 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "89e07751",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.8888277 0.82843924]\n",
" [0.80761224 0.8892383 ]]\n"
]
}
],
"outputs": [],
"source": [
"from FlagEmbedding import FlagModel\n",
"\n",
@@ -181,7 +155,7 @@
"id": "2c86a5a3",
"metadata": {},
"source": [
"### 2.2 BGE 1.5"
"### 2.2 BGE v1.5"
]
},
{
@@ -217,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 4,
"id": "9b17afcc",
"metadata": {},
"outputs": [
@@ -297,7 +271,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 5,
"id": "5f077420",
"metadata": {},
"outputs": [
@@ -363,7 +337,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 5,
"id": "d4647625",
"metadata": {},
"outputs": [
@@ -371,7 +345,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 216946.76it/s]\n"
"Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 228780.22it/s]\n"
]
}
],
@@ -380,9 +354,7 @@
"\n",
"model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)\n",
"\n",
"sentences_1 = [\"What is BGE M3?\", \"Defination of BM25\"]\n",
"sentences_2 = [\"BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.\", \n",
"               \"BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document\"]"
"sentences = [\"What is BGE M3?\", \"Definition of BM25\"]"
]
},
{
@@ -403,190 +375,80 @@
"It returns a dictionary like:\n",
"```\n",
"{\n",
"    'dense_vecs': array of dense embeddings if return_dense=Ture, otherwise None,\n",
"    'dense_vecs': array of dense embeddings of inputs if return_dense=True, otherwise None,\n",
"    'lexical_weights': array of dictionaries whose keys are token ids and whose values are their corresponding weights if return_sparse=True, otherwise None,\n",
"    'colbert_vecs': \n",
"    'colbert_vecs': array of multi-vector embeddings of inputs if return_colbert_vecs=True, otherwise None,\n",
"}\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "7c06e113",
"metadata": {},
"source": [
"#### 2.4.1 Dense Retrieval"
]
},
{
"cell_type": "markdown",
"id": "e7361f61",
"metadata": {},
"source": [
"Using BGE M3 for dense embedding is almost the same as with the BGE or BGE v1.5 models."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "76f15175",
"execution_count": 6,
"id": "f0b11cf0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.6259035 0.34749585]\n",
" [0.349868 0.6782462 ]]\n"
]
}
],
"outputs": [],
"source": [
"# If you don't need the full length of 8192 input tokens, you can set max_length to a smaller value to speed up encoding.\n",
"embeddings_1 = model.encode(sentences_1, max_length=10)['dense_vecs']\n",
"embeddings_2 = model.encode(sentences_2, max_length=100)['dense_vecs']\n",
"\n",
"# compute the similarity scores\n",
"similarity = embeddings_1 @ embeddings_2.T\n",
"print(similarity)"
]
},
{
"cell_type": "markdown",
"id": "8be70fe2",
"metadata": {},
"source": [
"#### 2.4.2 Sparse Retrieval"
]
},
{
"cell_type": "markdown",
"id": "4d0d91a1",
"metadata": {},
"source": [
"Set `return_sparse` to true to make the model return sparse vector. If a term token appears multiple times in the sentence, we only retain its max weight."
"embeddings = model.encode(\n",
"    sentences, \n",
"    max_length=10,\n",
"    return_dense=True, \n",
"    return_sparse=True, \n",
"    return_colbert_vecs=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5258d5cb",
"execution_count": 8,
"id": "72cba126",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'What': 0.08362077, 'is': 0.081469566, 'B': 0.12964639, 'GE': 0.25186998, 'M': 0.17001738, '3': 0.26957875, '?': 0.040755156}, {'De': 0.050144322, 'fin': 0.13689369, 'ation': 0.045134712, 'of': 0.06342201, 'BM': 0.25167602, '25': 0.33353207}]\n"
"dense embedding:\n",
"[[-0.03411707 -0.04707828 -0.00089447 ... 0.04828531 0.00755427\n",
" -0.02961654]\n",
" [-0.01041734 -0.04479263 -0.02429199 ... -0.00819298 0.01503995\n",
" 0.01113793]]\n",
"sparse embedding:\n",
"[defaultdict(<class 'int'>, {'4865': 0.08362077, '83': 0.081469566, '335': 0.12964639, '11679': 0.25186998, '276': 0.17001738, '363': 0.26957875, '32': 0.040755156}), defaultdict(<class 'int'>, {'262': 0.050144322, '5983': 0.13689369, '2320': 0.045134712, '111': 0.06342201, '90017': 0.25167602, '2588': 0.33353207})]\n",
"multi-vector:\n",
"[array([[-8.6726490e-03, -4.8921868e-02, -3.0449261e-03, ...,\n",
" -2.2082448e-02, 5.7268854e-02, 1.2811369e-02],\n",
" [-8.8765034e-03, -4.6860173e-02, -9.5845405e-03, ...,\n",
" -3.1404708e-02, 5.3911421e-02, 6.8714428e-03],\n",
" [ 1.8445771e-02, -4.2359587e-02, 8.6754939e-04, ...,\n",
" -1.9803897e-02, 3.8384371e-02, 7.6852231e-03],\n",
" ...,\n",
" [-2.5543230e-02, -1.6561864e-02, -4.2125367e-02, ...,\n",
" -4.5030322e-02, 4.4091221e-02, -1.0043185e-02],\n",
" [ 4.9905590e-05, -5.5475257e-02, 8.4884483e-03, ...,\n",
" -2.2911752e-02, 6.0379632e-02, 9.3577225e-03],\n",
" [ 2.5895271e-03, -2.9331330e-02, -1.8961012e-02, ...,\n",
" -8.0389353e-03, 3.2842189e-02, 4.3894034e-02]], dtype=float32), array([[ 0.01715658, 0.03835309, -0.02311821, ..., 0.00146474,\n",
" 0.02993429, -0.05985384],\n",
" [ 0.00996143, 0.039217 , -0.03855301, ..., 0.00599566,\n",
" 0.02722942, -0.06509776],\n",
" [ 0.01777726, 0.03919311, -0.01709837, ..., 0.00805702,\n",
" 0.03988946, -0.05069073],\n",
" ...,\n",
" [ 0.05474931, 0.0075684 , 0.00329455, ..., -0.01651684,\n",
" 0.02397249, 0.00368039],\n",
" [ 0.0093503 , 0.05022853, -0.02385841, ..., 0.02575599,\n",
" 0.00786822, -0.03260205],\n",
" [ 0.01805054, 0.01337725, 0.00016697, ..., 0.01843987,\n",
" 0.01374448, 0.00310114]], dtype=float32)]\n"
]
}
],
"source": [
"output_1 = model.encode(sentences_1, return_sparse=True)\n",
"output_2 = model.encode(sentences_2, return_sparse=True)\n",
"\n",
"# you can see the weight for each token:\n",
"print(model.convert_id_to_token(output_1['lexical_weights']))"
]
},
{
"cell_type": "markdown",
"id": "bffc20c9",
"metadata": {},
"source": [
"Based on the token weights of the query and passage, the relevance score between them is computed from the joint importance of the terms that appear in both:\n",
"\n",
"$$s_{lex} = \\sum_{t\\in q\\cap p}(w_{qt} * w_{pt})$$"
]
},
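{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this formula in plain Python (assuming two `lexical_weights` dictionaries as returned by `encode`, mapping token ids to weights):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lexical_score(q_weights, p_weights):\n",
"    # sum the joint importance of every token that appears in both query and passage\n",
"    return sum(w * p_weights[t] for t, w in q_weights.items() if t in p_weights)"
]
},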
{
"cell_type": "code",
"execution_count": 16,
"id": "867e2148",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.19554448500275612\n",
"0\n"
]
}
],
"source": [
"# compute the scores via lexical matching\n",
"score_1 = model.compute_lexical_matching_score(output_1['lexical_weights'][0], output_2['lexical_weights'][0])\n",
"print(score_1)\n",
"\n",
"score_2 = model.compute_lexical_matching_score(output_1['lexical_weights'][0], output_1['lexical_weights'][1])\n",
"print(score_2)"
]
},
{
"cell_type": "markdown",
"id": "cdd6bbe2",
"metadata": {},
"source": [
"#### 2.4.3 Multi-Vector"
]
},
{
"cell_type": "markdown",
"id": "8c243114",
"metadata": {},
"source": [
"The multi-vector method utilizes the entire output embeddings for the representation of query and passage."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "21b27182",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 1024)\n"
]
}
],
"source": [
"output_1 = model.encode(sentences_1, return_dense=True, return_sparse=True, return_colbert_vecs=True)\n",
"output_2 = model.encode(sentences_2, return_dense=True, return_sparse=True, return_colbert_vecs=True)\n",
"\n",
"print(f\"({len(output_1['colbert_vecs'][0])}, {len(output_1['colbert_vecs'][0][0])})\")"
]
},
{
"cell_type": "markdown",
"id": "dec335e5",
"metadata": {},
"source": [
"Following ColBERT, we use late interaction to compute the fine-grained relevance score:\n",
"$$s_{mul}=\\frac{1}{N}\\sum_{i=1}^N\\max_{j=1}^M E_q[i]\\cdot E_p^T[j]$$"
]
},
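{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal numpy sketch of this late-interaction score (assuming `E_q` and `E_p` are the token-level `colbert_vecs` arrays of a query and a passage):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def late_interaction(E_q, E_p):\n",
"    sim = np.asarray(E_q) @ np.asarray(E_p).T  # token-to-token dot products, shape (N, M)\n",
"    return sim.max(axis=1).mean()              # best passage match per query token, averaged"
]
},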
{
"cell_type": "code",
"execution_count": 18,
"id": "575e38e5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.7796662449836731\n",
"0.4621177911758423\n"
]
}
],
"source": [
"print(model.colbert_score(output_1['colbert_vecs'][0], output_2['colbert_vecs'][0]).item())\n",
"print(model.colbert_score(output_1['colbert_vecs'][0], output_2['colbert_vecs'][1]).item())"
"print(f\"dense embedding:\\n{embeddings['dense_vecs']}\")\n",
"print(f\"sparse embedding:\\n{embeddings['lexical_weights']}\")\n",
"print(f\"multi-vector:\\n{embeddings['colbert_vecs']}\")"
]
}
],
419 Tutorials/1_Embedding/1.2.2_BGE_Explanation.ipynb Normal file
@@ -0,0 +1,419 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BGE Explanation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we will go through BGE and BGE-v1.5's structure and how they generate embeddings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Installation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the required packages in your environment."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"%pip install -U transformers FlagEmbedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Encode sentences"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see exactly how a sentence is encoded, let's first load the tokenizer and model from HF Transformers instead of FlagEmbedding."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
"model = AutoModel.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
"\n",
"sentences = [\"embedding\", \"I love machine learning and nlp\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the following cell to inspect bge-base-en-v1.5. It has exactly the same structure as BERT-base: 12 encoder layers and a hidden dimension of 768.\n",
"\n",
"Note that corresponding BGE and BGE-v1.5 models share the same structure. For example, bge-base-en and bge-base-en-v1.5 have the same architecture."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BertModel(\n",
"  (embeddings): BertEmbeddings(\n",
"    (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
"    (position_embeddings): Embedding(512, 768)\n",
"    (token_type_embeddings): Embedding(2, 768)\n",
"    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"    (dropout): Dropout(p=0.1, inplace=False)\n",
"  )\n",
"  (encoder): BertEncoder(\n",
"    (layer): ModuleList(\n",
"      (0-11): 12 x BertLayer(\n",
"        (attention): BertAttention(\n",
"          (self): BertSelfAttention(\n",
"            (query): Linear(in_features=768, out_features=768, bias=True)\n",
"            (key): Linear(in_features=768, out_features=768, bias=True)\n",
"            (value): Linear(in_features=768, out_features=768, bias=True)\n",
"            (dropout): Dropout(p=0.1, inplace=False)\n",
"          )\n",
"          (output): BertSelfOutput(\n",
"            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
"            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"            (dropout): Dropout(p=0.1, inplace=False)\n",
"          )\n",
"        )\n",
"        (intermediate): BertIntermediate(\n",
"          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
"          (intermediate_act_fn): GELUActivation()\n",
"        )\n",
"        (output): BertOutput(\n",
"          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
"          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"          (dropout): Dropout(p=0.1, inplace=False)\n",
"        )\n",
"      )\n",
"    )\n",
"  )\n",
"  (pooler): BertPooler(\n",
"    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
"    (activation): Tanh()\n",
"  )\n",
")"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.eval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's tokenize the sentences."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[ 101, 7861, 8270, 4667, 102, 0, 0, 0, 0],\n",
" [ 101, 1045, 2293, 3698, 4083, 1998, 17953, 2361, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = tokenizer(\n",
"    sentences, \n",
"    padding=True, \n",
"    truncation=True, \n",
"    return_tensors='pt', \n",
"    max_length=512\n",
")\n",
"inputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the results, we can see that each sentence begins with token 101 and ends with token 102; these are the `[CLS]` and `[SEP]` special tokens used in BERT."
]
},
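{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can verify this by mapping the ids back to tokens:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 101 -> [CLS], 102 -> [SEP]\n",
"tokenizer.convert_ids_to_tokens([101, 102])"
]
},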
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 9, 768])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n",
"last_hidden_state.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we implement the pooling function, with two choices: using `[CLS]`'s last hidden state, or mean pooling of the whole last hidden state."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def pooling(last_hidden_state: torch.Tensor, pooling_method='cls', attention_mask: torch.Tensor = None):\n",
"    if pooling_method == 'cls':\n",
"        # use the final hidden state of the [CLS] token\n",
"        return last_hidden_state[:, 0]\n",
"    elif pooling_method == 'mean':\n",
"        # zero out padding positions, then average over the sequence length\n",
"        s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)\n",
"        d = attention_mask.sum(dim=1, keepdim=True).float()\n",
"        return s / d"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unlike the more commonly used mean pooling, BGE is trained to use the last hidden state of `[CLS]` as the sentence embedding: \n",
"\n",
"`sentence_embeddings = model_output[0][:, 0]`\n",
"\n",
"If you use mean pooling instead, there will be a significant decrease in performance. Therefore, make sure to use the correct method to obtain sentence vectors."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 768])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings = pooling(\n",
"    last_hidden_state, \n",
"    pooling_method='cls', \n",
"    attention_mask=inputs['attention_mask']\n",
")\n",
"embeddings.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assembling them together, we get the whole encoding function:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def _encode(sentences, max_length=512, convert_to_numpy=True):\n",
"\n",
"    # handle both a single sentence and a list of sentences\n",
"    input_was_string = False\n",
"    if isinstance(sentences, str):\n",
"        sentences = [sentences]\n",
"        input_was_string = True\n",
"\n",
"    inputs = tokenizer(\n",
"        sentences, \n",
"        padding=True, \n",
"        truncation=True, \n",
"        return_tensors='pt', \n",
"        max_length=max_length\n",
"    )\n",
"\n",
"    last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n",
"    \n",
"    embeddings = pooling(\n",
"        last_hidden_state, \n",
"        pooling_method='cls', \n",
"        attention_mask=inputs['attention_mask']\n",
"    )\n",
"\n",
"    # normalize the embedding vectors\n",
"    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)\n",
"\n",
"    # convert to numpy if needed\n",
"    if convert_to_numpy:\n",
"        embeddings = embeddings.detach().numpy()\n",
"\n",
"    return embeddings[0] if input_was_string else embeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Comparison"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's run the function we wrote to get the embeddings of the two sentences:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04\n",
" 2.8417887e-02 6.3214332e-02]\n",
" [ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03\n",
" 1.8721525e-02 -2.0371782e-02]]\n",
"Similarity scores:\n",
"[[0.9999997 0.6077381]\n",
" [0.6077381 0.9999999]]\n"
]
}
],
"source": [
"embeddings = _encode(sentences)\n",
"print(f\"Embeddings:\\n{embeddings}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, run the API provided in FlagEmbedding:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embeddings:\n",
"[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04\n",
" 2.8417887e-02 6.3214332e-02]\n",
" [ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03\n",
" 1.8721525e-02 -2.0371782e-02]]\n",
"Similarity scores:\n",
"[[0.9999997 0.6077381]\n",
" [0.6077381 0.9999999]]\n"
]
}
],
"source": [
"from FlagEmbedding import FlagModel\n",
"\n",
"model = FlagModel('BAAI/bge-base-en-v1.5')\n",
"\n",
"embeddings = model.encode(sentences)\n",
"print(f\"Embeddings:\\n{embeddings}\")\n",
"\n",
"scores = embeddings @ embeddings.T\n",
"print(f\"Similarity scores:\\n{scores}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected, the two encoding functions return exactly the same results. The full implementation in FlagEmbedding handles large datasets by batching and includes GPU support and parallelization. Feel free to check the [source code](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/flag_models.py#L370) for more details."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
574 Tutorials/5_Reranking/reranker.ipynb Normal file
@@ -0,0 +1,574 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reranker"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rerankers are designed with a cross-encoder architecture: they take the query and the text at the same time and directly output a similarity score. They are more capable at scoring query-text relevance, with the tradeoff of slower speed. Thus, a complete retrieval system usually uses retrievers in the first stage for large-scale candidate retrieval, followed by rerankers to rerank the results more precisely.\n",
"\n",
"In this tutorial, we will go through a text retrieval pipeline with a reranker and evaluate the results before and after reranking.\n",
"\n",
"Note: Steps 1-4 are identical to the tutorial on [evaluation](https://github.com/FlagOpen/FlagEmbedding/tree/master/Tutorials/4_Evaluation). We suggest going through that first if you are not familiar with retrieval."
]
},
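{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before diving in, here is a condensed sketch of that retrieve-then-rerank pattern using the same models as this tutorial (the full pipeline below adds a faiss index and proper evaluation):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from FlagEmbedding import FlagModel, FlagReranker\n",
"\n",
"retriever = FlagModel('BAAI/bge-base-en-v1.5')\n",
"reranker = FlagReranker('BAAI/bge-reranker-large')\n",
"\n",
"docs = [\"The giant panda is a bear species endemic to China.\", \"Today is a sunny day.\"]\n",
"query = \"what is panda?\"\n",
"\n",
"# stage 1: fast dense retrieval to collect candidates\n",
"q_emb, d_emb = retriever.encode_queries([query]), retriever.encode(docs)\n",
"candidates = (q_emb @ d_emb.T)[0].argsort()[::-1][:2]\n",
"\n",
"# stage 2: slower but more precise cross-encoder scoring of the candidates\n",
"scores = reranker.compute_score([[query, docs[i]] for i in candidates])\n",
"print(scores)"
]
},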
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the dependencies in the environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -U FlagEmbedding faiss-cpu"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Download and preprocess the MS MARCO dataset:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"import numpy as np\n",
"\n",
"data = load_dataset(\"namespace-Pt/msmarco\", split=\"dev\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# use the first 100 queries, and the positive passages of the first 5000 samples as the corpus\n",
"queries = np.array(data[:100][\"query\"])\n",
"corpus = sum(data[:5000][\"positive\"], [])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Embedding"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inference Embeddings: 100%|██████████| 21/21 [01:59<00:00, 5.68s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape of the corpus embeddings: (5331, 768)\n",
"data type of the embeddings: float32\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from FlagEmbedding import FlagModel\n",
"\n",
"# get the BGE embedding model\n",
"model = FlagModel('BAAI/bge-base-en-v1.5',\n",
"                  query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
"                  use_fp16=True)\n",
"\n",
"# get the embedding of the corpus\n",
"corpus_embeddings = model.encode(corpus)\n",
"\n",
"print(\"shape of the corpus embeddings:\", corpus_embeddings.shape)\n",
"print(\"data type of the embeddings: \", corpus_embeddings.dtype)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Indexing"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total number of vectors: 5331\n"
]
}
],
"source": [
"import faiss\n",
"\n",
"# get the length of our embedding vectors, vectors by bge-base-en-v1.5 have length 768\n",
"dim = corpus_embeddings.shape[-1]\n",
"\n",
"# create the faiss index and store the corpus embeddings into the vector space\n",
"index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)\n",
"corpus_embeddings = corpus_embeddings.astype(np.float32)\n",
"index.train(corpus_embeddings)\n",
"index.add(corpus_embeddings)\n",
"\n",
"print(f\"total number of vectors: {index.ntotal}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Retrieval"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"query_embeddings = model.encode_queries(queries)\n",
"ground_truths = [d[\"positive\"] for d in data]\n",
"corpus = np.asarray(corpus)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Searching: 100%|██████████| 1/1 [00:00<00:00, 22.35it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"res_scores, res_ids, res_text = [], [], []\n",
"query_size = len(query_embeddings)\n",
"batch_size = 256\n",
"# the cutoffs we will use during evaluation; set k to the maximum of the cutoffs\n",
"cut_offs = [1, 10]\n",
"k = max(cut_offs)\n",
"\n",
"for i in tqdm(range(0, query_size, batch_size), desc=\"Searching\"):\n",
"    q_embedding = query_embeddings[i: min(i+batch_size, query_size)].astype(np.float32)\n",
"    # search the top k answers for each of the queries\n",
"    score, idx = index.search(q_embedding, k=k)\n",
"    res_scores += list(score)\n",
"    res_ids += list(idx)\n",
"    res_text += list(corpus[idx])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Reranking"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we will use a reranker to rerank the list of answers we retrieved using our index. Hopefully, this will lead to better results."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following table lists the available BGE rerankers. Feel free to try them out to see their differences!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Model | Language | Parameters | Description | Base Model |\n",
"|:-------|:--------:|:----:|:-----------------:|:--------------------------------------:|\n",
"| [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) | Multilingual | 568M | a lightweight cross-encoder model, possesses strong multilingual capabilities, easy to deploy, with fast inference. | XLM-RoBERTa-Large |\n",
"| [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma) | Multilingual | 2.51B | a cross-encoder model which is suitable for multilingual contexts, performs well in both English proficiency and multilingual capabilities. | Gemma2-2B |\n",
"| [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) | Multilingual | 2.72B | a cross-encoder model which is suitable for multilingual contexts, performs well in both English and Chinese proficiency, allows freedom to select layers for output, facilitating accelerated inference. | MiniCPM |\n",
"| [BAAI/bge-reranker-v2.5-gemma2-lightweight](https://huggingface.co/BAAI/bge-reranker-v2.5-gemma2-lightweight) | Multilingual | 9.24B | a cross-encoder model which is suitable for multilingual contexts, performs well in both English and Chinese proficiency, allows freedom to select layers, compress ratio and compress layers for output, facilitating accelerated inference. | Gemma2-9B |\n",
"| [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | Chinese and English | 560M | a cross-encoder model which is more accurate but less efficient | XLM-RoBERTa-Large |\n",
"| [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | Chinese and English | 278M | a cross-encoder model which is more accurate but less efficient | XLM-RoBERTa-Base |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's use a small example to see how the reranker works:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-9.474676132202148, -2.823843240737915, 5.76226806640625]\n"
]
}
],
"source": [
"from FlagEmbedding import FlagReranker\n",
"\n",
"reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) \n",
"# Setting use_fp16 to True speeds up computation with a slight performance degradation\n",
"\n",
"# use the compute_score() function to calculate scores for each input sentence pair\n",
"scores = reranker.compute_score([\n",
"    ['what is panda?', 'Today is a sunny day'], \n",
"    ['what is panda?', 'The tiger (Panthera tigris) is a member of the genus Panthera and the largest living cat species native to Asia.'],\n",
"    ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']\n",
"])\n",
"print(scores)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's use the reranker to rerank our previously retrieved results:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"new_ids, new_scores, new_text = [], [], []\n",
"for i in range(len(queries)):\n",
"    # get the new scores of the previously retrieved results\n",
"    new_score = reranker.compute_score([[queries[i], text] for text in res_text[i]])\n",
"    # sort the lists of ids and scores by the new scores\n",
"    new_id = [tup[1] for tup in sorted(list(zip(new_score, res_ids[i])), reverse=True)]\n",
"    new_scores.append(sorted(new_score, reverse=True))\n",
"    new_ids.append(new_id)\n",
"    new_text.append(corpus[new_id])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Evaluate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For details of these metrics, please check out the tutorial on [evaluation](https://github.com/FlagOpen/FlagEmbedding/tree/master/Tutorials/4_Evaluation)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6.1 Recall"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def calc_recall(preds, truths, cutoffs):\n",
"    recalls = np.zeros(len(cutoffs))\n",
"    for text, truth in zip(preds, truths):\n",
"        for i, c in enumerate(cutoffs):\n",
"            # ground-truth passages found in the top c predictions\n",
"            recall = np.intersect1d(truth, text[:c])\n",
"            recalls[i] += len(recall) / max(min(len(recall), len(truth)), 1)\n",
"    # average over all queries\n",
"    recalls /= len(preds)\n",
"    return recalls"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before reranking:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"recall@1:\t0.97\n",
"recall@10:\t1.0\n"
]
}
],
"source": [
"recalls_init = calc_recall(res_text, ground_truths, cut_offs)\n",
"for i, c in enumerate(cut_offs):\n",
"    print(f\"recall@{c}:\\t{recalls_init[i]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After reranking:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"recall@1:\t0.99\n",
"recall@10:\t1.0\n"
]
}
],
"source": [
"recalls_rerank = calc_recall(new_text, ground_truths, cut_offs)\n",
"for i, c in enumerate(cut_offs):\n",
"    print(f\"recall@{c}:\\t{recalls_rerank[i]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6.2 MRR"
]
},
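{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick reference, MRR (mean reciprocal rank) averages the inverse rank of the first relevant passage over all queries:\n",
"\n",
"$$\\text{MRR@}k=\\frac{1}{|Q|}\\sum_{i=1}^{|Q|}\\frac{1}{\\text{rank}_i}$$\n",
"\n",
"where $\\text{rank}_i$ is the position of the first relevant passage in the top $k$ results of query $i$ (a query with no relevant passage in the top $k$ contributes 0). The function below implements exactly this:"
]
},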
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"def MRR(preds, truth, cutoffs):\n",
"    mrr = [0 for _ in range(len(cutoffs))]\n",
"    for pred, t in zip(preds, truth):\n",
"        for i, c in enumerate(cutoffs):\n",
"            for j, p in enumerate(pred):\n",
"                # reciprocal rank of the first relevant passage within the top c\n",
"                if j < c and p in t:\n",
"                    mrr[i] += 1/(j+1)\n",
"                    break\n",
"    # average over all queries\n",
"    mrr = [k/len(preds) for k in mrr]\n",
"    return mrr"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before reranking:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MRR@1:\t0.97\n",
"MRR@10:\t0.9825\n"
]
}
],
"source": [
"mrr_init = MRR(res_text, ground_truths, cut_offs)\n",
"for i, c in enumerate(cut_offs):\n",
"    print(f\"MRR@{c}:\\t{mrr_init[i]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After reranking:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MRR@1:\t0.99\n",
"MRR@10:\t0.995\n"
]
}
],
"source": [
"mrr_rerank = MRR(new_text, ground_truths, cut_offs)\n",
"for i, c in enumerate(cut_offs):\n",
"    print(f\"MRR@{c}:\\t{mrr_rerank[i]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6.3 nDCG"
]
},
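{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, with binary relevance labels nDCG@$k$ normalizes the discounted cumulative gain of the ranking by that of the ideal ranking:\n",
"\n",
"$$\\text{DCG@}k=\\sum_{j=1}^{k}\\frac{rel_j}{\\log_2(j+1)},\\qquad \\text{nDCG@}k=\\frac{\\text{DCG@}k}{\\text{IDCG@}k}$$\n",
"\n",
"Below we build the binary relevance labels with `np.isin` and let scikit-learn's `ndcg_score` handle the computation."
]
},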
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before reranking:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nDCG@1: 0.97\n",
"nDCG@10: 0.9869253606521631\n"
]
}
],
"source": [
"from sklearn.metrics import ndcg_score\n",
"\n",
"pred_hard_encodings = []\n",
"for pred, label in zip(res_text, ground_truths):\n",
"    pred_hard_encoding = list(np.isin(pred, label).astype(int))\n",
"    pred_hard_encodings.append(pred_hard_encoding)\n",
"\n",
"for i, c in enumerate(cut_offs):\n",
"    nDCG = ndcg_score(pred_hard_encodings, res_scores, k=c)\n",
"    print(f\"nDCG@{c}: {nDCG}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After reranking:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nDCG@1: 0.99\n",
"nDCG@10: 0.9963092975357145\n"
]
}
],
"source": [
"pred_hard_encodings_rerank = []\n",
"for pred, label in zip(new_text, ground_truths):\n",
"    pred_hard_encoding = list(np.isin(pred, label).astype(int))\n",
"    pred_hard_encodings_rerank.append(pred_hard_encoding)\n",
"\n",
"for i, c in enumerate(cut_offs):\n",
"    nDCG = ndcg_score(pred_hard_encodings_rerank, new_scores, k=c)\n",
"    print(f\"nDCG@{c}: {nDCG}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}