mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-08-11 18:27:23 +00:00
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d34077d7-dd3a-49d7-b87b-2914d52992a6",
   "metadata": {},
   "source": [
    "# Intro"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb0755e6-6ef9-4eb9-b3e6-131a0820ffc7",
   "metadata": {},
   "source": [
    "For text retrieval, pattern matching is the most intuitive approach: people search using specific characters, words, phrases, or sentence patterns. However, pattern matching between a query and a collection of text files is extremely inefficient, not only for humans but also for computers.\n",
    "\n",
    "Images are represented by RGB pixels and audio by digital signals. Similarly, to accomplish more sophisticated natural language tasks such as retrieval, classification, clustering, or semantic search, we need a way to represent text data. This is where text embeddings come in."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e53fb5bc",
   "metadata": {},
   "source": [
    "## 1. Background"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca2b21dc",
   "metadata": {},
   "source": [
    "Traditional text embedding methods like one-hot encoding and bag-of-words (BoW) represent words and sentences as sparse vectors based on statistical features, such as whether and how often a word appears in a document. More advanced methods like TF-IDF and BM25 improve on these by weighting each word by its importance across an entire corpus, while n-gram techniques capture word order within small groups of words. However, these approaches suffer from the \"curse of dimensionality\" and fail to capture semantic similarity (e.g., \"cat\" and \"kitty\") as well as differences in meaning caused by word order (e.g., \"play the watch\" vs. \"watch the play\")."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7d7ea92b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# example of bag-of-words\n",
    "sentence1 = \"I love basketball\"\n",
    "sentence2 = \"I have a basketball match\"\n",
    "\n",
    "# vocabulary collected from both sentences\n",
    "words = ['I', 'love', 'basketball', 'have', 'a', 'match']\n",
    "\n",
    "# each position counts how many times the corresponding vocabulary word appears in the sentence\n",
    "sen1_vec = [1, 1, 1, 0, 0, 0]\n",
    "sen2_vec = [1, 0, 1, 1, 1, 1]"
   ]
  },
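  {
   "cell_type": "markdown",
   "id": "bow-limits-md",
   "metadata": {},
   "source": [
    "The hand-written vectors above can also be computed programmatically. The sketch below uses only the Python standard library, with lowercasing and whitespace tokenization as simplifying assumptions; `bow_vector` and `cosine` are illustrative helpers, not part of FlagEmbedding. It reproduces both weaknesses described earlier: \"play the watch\" and \"watch the play\" get identical vectors, while \"cat\" and \"kitty\" share no dimensions, so their similarity is zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bow-limits-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "def bow_vector(sentence, vocab):\n",
    "    # count how often each vocabulary word occurs (lowercased, whitespace tokenization)\n",
    "    counts = Counter(sentence.lower().split())\n",
    "    return [counts[w] for w in vocab]\n",
    "\n",
    "def cosine(u, v):\n",
    "    # cosine similarity of two equal-length vectors (0.0 if either is all zeros)\n",
    "    dot = sum(a * b for a, b in zip(u, v))\n",
    "    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))\n",
    "    return dot / norm if norm else 0.0\n",
    "\n",
    "vocab = ['play', 'the', 'watch', 'cat', 'kitty']\n",
    "\n",
    "# different meanings, identical bag-of-words vectors: word order is lost\n",
    "v1 = bow_vector('play the watch', vocab)\n",
    "v2 = bow_vector('watch the play', vocab)\n",
    "print(v1, v2, v1 == v2)  # [1, 1, 1, 0, 0] [1, 1, 1, 0, 0] True\n",
    "\n",
    "# similar meanings, but 'cat' and 'kitty' occupy unrelated dimensions\n",
    "print(cosine(bow_vector('cat', vocab), bow_vector('kitty', vocab)))  # 0.0"
   ]
  },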
  {
   "cell_type": "markdown",
   "id": "9958d001",
   "metadata": {},
   "source": [
    "To overcome these limitations, dense word embeddings were developed, mapping words to vectors in a low-dimensional space that captures semantic and relational information. Early models like Word2Vec demonstrated the power of dense embeddings learned by neural networks. Subsequent advances in neural network architectures such as RNNs, LSTMs, and Transformers enabled more sophisticated models like BERT, RoBERTa, and GPT to capture complex word relationships and context. **BAAI General Embedding (BGE)** provides a series of open-source embedding models that cover a wide range of use cases."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06cff9e4",
   "metadata": {},
   "source": [
    "## 2. BAAI General Embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "880e229d",
   "metadata": {},
   "source": [
    "In this part, we walk through the BGE series and show how to use these embedding models.\n",
    "\n",
    "First, install FlagEmbedding in your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af533a6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U FlagEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e10034a",
   "metadata": {},
   "source": [
    "### 2.1 BGE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cdc6702",
   "metadata": {},
   "source": [
    "The first version of BGE consists of 6 models: 'large', 'base', and 'small' variants for both English and Chinese."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04b75f72",
   "metadata": {},
   "source": [
    "| Model | Language | Parameters | Model Size | Description | Base Model |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en) | English | 335M | 1.34 GB | Embedding model which maps text into vectors | BERT |\n",
    "| [BAAI/bge-base-en](https://huggingface.co/BAAI/bge-base-en) | English | 109M | 438 MB | a base-scale model with ability similar to `bge-large-en` | BERT |\n",
    "| [BAAI/bge-small-en](https://huggingface.co/BAAI/bge-small-en) | English | 33.4M | 133 MB | a small-scale model with competitive performance | BERT |\n",
    "| [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) | Chinese | 326M | 1.3 GB | Embedding model which maps text into vectors | BERT |\n",
    "| [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh) | Chinese | 102M | 409 MB | a base-scale model with ability similar to `bge-large-zh` | BERT |\n",
    "| [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh) | Chinese | 24M | 95.8 MB | a small-scale model with competitive performance | BERT |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9c45d17",
   "metadata": {},
   "source": [
    "For inference, import `FlagModel` from FlagEmbedding and initialize the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "89e07751",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.8888277 0.82843924]\n",
      " [0.80761224 0.8892383 ]]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import FlagModel\n",
    "\n",
    "# Load BGE model\n",
    "model = FlagModel('BAAI/bge-base-en',\n",
    "                  query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "                  use_fp16=True)\n",
    "\n",
    "queries = [\"query 1\", \"query 2\"]\n",
    "corpus = [\"passage 1\", \"passage 2\"]\n",
    "\n",
    "# encode the queries and corpus\n",
    "# note: encode() does not prepend query_instruction_for_retrieval; encode_queries() does\n",
    "q_embeddings = model.encode(queries)\n",
    "p_embeddings = model.encode(corpus)\n",
    "\n",
    "# compute the similarity scores\n",
    "scores = q_embeddings @ p_embeddings.T\n",
    "print(scores)"
   ]
  },
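  {
   "cell_type": "markdown",
   "id": "dot-vs-cosine-md",
   "metadata": {},
   "source": [
    "The cell above scores query-passage pairs with a plain matrix product. That product equals cosine similarity only when the embeddings are L2-normalized; our understanding is that `FlagModel` normalizes its outputs by default, but treat this as an assumption and check the version you have installed. The NumPy sketch below uses random vectors as stand-ins for real BGE embeddings to show that, once rows are normalized, the inner product and the cosine similarity coincide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dot-vs-cosine-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# random stand-ins for 2 query and 2 passage embeddings (768 dimensions, like bge-base)\n",
    "q = rng.normal(size=(2, 768))\n",
    "p = rng.normal(size=(2, 768))\n",
    "\n",
    "# L2-normalize each row\n",
    "q_unit = q / np.linalg.norm(q, axis=1, keepdims=True)\n",
    "p_unit = p / np.linalg.norm(p, axis=1, keepdims=True)\n",
    "\n",
    "# inner product of the normalized vectors, as in q_embeddings @ p_embeddings.T\n",
    "dot_scores = q_unit @ p_unit.T\n",
    "\n",
    "# cosine similarity computed explicitly from the raw vectors\n",
    "cos_scores = (q @ p.T) / np.outer(np.linalg.norm(q, axis=1), np.linalg.norm(p, axis=1))\n",
    "\n",
    "print(np.allclose(dot_scores, cos_scores))  # True"
   ]
  },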
  {
   "cell_type": "markdown",
   "id": "6c8e69ed",
   "metadata": {},
   "source": [
    "To use `FlagModel`:\n",
    "```\n",
    "FlagModel.encode(sentences, batch_size=256, max_length=512, convert_to_numpy=True)\n",
    "```\n",
    "The *encode()* function directly encodes the input sentences into embedding vectors.\n",
    "```\n",
    "FlagModel.encode_queries(sentences, batch_size=256, max_length=512, convert_to_numpy=True)\n",
    "```\n",
    "The *encode_queries()* function prepends `query_instruction_for_retrieval` to each input query and then calls *encode()*."
   ]
  },
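  {
   "cell_type": "markdown",
   "id": "query-instruction-md",
   "metadata": {},
   "source": [
    "The sketch below mimics the preprocessing step that *encode_queries()* adds on top of *encode()*. The helper `add_query_instruction` is hypothetical, and plain string concatenation with no separator is our reading of FlagEmbedding's default behavior rather than a documented guarantee."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "query-instruction-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_query_instruction(instruction, queries):\n",
    "    # hypothetical helper: prepend the retrieval instruction to every query\n",
    "    return [instruction + q for q in queries]\n",
    "\n",
    "instruction = 'Represent this sentence for searching relevant passages:'\n",
    "queries = ['query 1', 'query 2']\n",
    "\n",
    "for text in add_query_instruction(instruction, queries):\n",
    "    print(text)\n",
    "\n",
    "# with a loaded FlagModel, model.encode(add_query_instruction(instruction, queries))\n",
    "# would then stand in for model.encode_queries(queries), under the assumption above"
   ]
  },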
  {
   "cell_type": "markdown",
   "id": "2c86a5a3",
   "metadata": {},
   "source": [
    "### 2.2 BGE 1.5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "454ff7aa",
   "metadata": {},
   "source": [
    "BGE 1.5 alleviates the similarity distribution issue of the original BGE models, whose similarity scores tend to concentrate in a narrow, high range (so even unrelated texts can receive fairly high scores), and enhances retrieval ability when no instruction is used."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30b1f897",
   "metadata": {},
   "source": [
    "| Model | Language | Parameters | Model Size | Description | Base Model |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5) | English | 335M | 1.34 GB | version 1.5 with more reasonable similarity distribution | BERT |\n",
    "| [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | English | 109M | 438 MB | version 1.5 with more reasonable similarity distribution | BERT |\n",
    "| [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) | English | 33.4M | 133 MB | version 1.5 with more reasonable similarity distribution | BERT |\n",
    "| [BAAI/bge-large-zh-v1.5](https://huggingface.co/BAAI/bge-large-zh-v1.5) | Chinese | 326M | 1.3 GB | version 1.5 with more reasonable similarity distribution | BERT |\n",
    "| [BAAI/bge-base-zh-v1.5](https://huggingface.co/BAAI/bge-base-zh-v1.5) | Chinese | 102M | 409 MB | version 1.5 with more reasonable similarity distribution | BERT |\n",
    "| [BAAI/bge-small-zh-v1.5](https://huggingface.co/BAAI/bge-small-zh-v1.5) | Chinese | 24M | 95.8 MB | version 1.5 with more reasonable similarity distribution | BERT |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed00c504",
   "metadata": {},
   "source": [
    "BGE 1.5 models share the same `FlagModel` API as the BGE models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9b17afcc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.736794 0.5989914]\n",
      " [0.5684842 0.7461165]]\n"
     ]
    }
   ],
   "source": [
    "model = FlagModel('BAAI/bge-base-en-v1.5',\n",
    "                  query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "                  use_fp16=True)\n",
    "\n",
    "queries = [\"query 1\", \"query 2\"]\n",
    "corpus = [\"passage 1\", \"passage 2\"]\n",
    "\n",
    "# encode the queries and corpus\n",
    "q_embeddings = model.encode(queries)\n",
    "p_embeddings = model.encode(corpus)\n",
    "\n",
    "# compute the similarity scores\n",
    "scores = q_embeddings @ p_embeddings.T\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38c3ce1c",
   "metadata": {},
   "source": [
    "### 2.3 LLM-Embedder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bc3fee0",
   "metadata": {},
   "source": [
    "LLM-Embedder is a unified embedding model that supports diverse retrieval augmentation needs for LLMs. It is fine-tuned on 6 tasks:\n",
    "- Question Answering (qa)\n",
    "- Conversational Search (convsearch)\n",
    "- Long Conversation (chat)\n",
    "- Long-Range Language Modeling (lrlm)\n",
    "- In-Context Learning (icl)\n",
    "- Tool Learning (tool)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13b926e9",
   "metadata": {},
   "source": [
    "| Model | Language | Parameters | Model Size | Description | Base Model |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/llm-embedder](https://huggingface.co/BAAI/llm-embedder) | English | 109M | 438 MB | a unified embedding model to support diverse retrieval augmentation needs for LLMs | BERT |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7b3f109",
   "metadata": {},
   "source": [
    "To use `LLMEmbedder`:\n",
    "```\n",
    "LLMEmbedder.encode_queries(queries, batch_size=256, max_length=256, task='qa')\n",
    "```\n",
    "*encode_queries()* prepends the query instruction of the given *task* to each input query and then calls *_encode()* (similar to *encode()* in `FlagModel`).\n",
    "```\n",
    "LLMEmbedder.encode_keys(keys, batch_size=256, max_length=512, task='qa')\n",
    "```\n",
    "Similarly, *encode_keys()* automatically adds the instruction for the given *task* to each key and then calls *_encode()*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5f077420",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.89705944 0.85341793]\n",
      " [0.8462474 0.90914035]]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import LLMEmbedder\n",
    "\n",
    "# load the LLMEmbedder model\n",
    "model = LLMEmbedder('BAAI/llm-embedder', use_fp16=False)\n",
    "\n",
    "# Define queries and keys\n",
    "queries = [\"test query 1\", \"test query 2\"]\n",
    "keys = [\"test key 1\", \"test key 2\"]\n",
    "\n",
    "# Encode for a specific task (qa, icl, chat, lrlm, tool, convsearch)\n",
    "task = \"qa\"\n",
    "query_embeddings = model.encode_queries(queries, task=task)\n",
    "key_embeddings = model.encode_keys(keys, task=task)\n",
    "\n",
    "# compute the similarity scores\n",
    "similarity = query_embeddings @ key_embeddings.T\n",
    "print(similarity)"
   ]
  },
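  {
   "cell_type": "markdown",
   "id": "top-k-md",
   "metadata": {},
   "source": [
    "Once a query-key similarity matrix is available, retrieval amounts to ranking the keys for each query by score. The NumPy sketch below illustrates this; `top_k` is an illustrative helper, and the matrix is copied from the output printed above rather than recomputed, so the cell runs without loading the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "top-k-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def top_k(similarity, k):\n",
    "    # column indices of the k highest scores in each row, best first\n",
    "    return np.argsort(-similarity, axis=1)[:, :k]\n",
    "\n",
    "# similarity matrix printed by the LLM-Embedder cell above\n",
    "similarity = np.array([[0.89705944, 0.85341793],\n",
    "                       [0.8462474, 0.90914035]])\n",
    "keys = ['test key 1', 'test key 2']\n",
    "\n",
    "for i, idx in enumerate(top_k(similarity, k=1)):\n",
    "    print(f'query {i + 1}:', [keys[j] for j in idx])\n",
    "# query 1: ['test key 1']\n",
    "# query 2: ['test key 2']"
   ]
  },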
  {
   "cell_type": "markdown",
   "id": "dcf2a82b",
   "metadata": {},
   "source": [
    "### 2.4 BGE M3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc5b5a5e",
   "metadata": {},
   "source": [
    "BGE-M3 is a newer BGE model distinguished by its versatility in:\n",
    "- Multi-Functionality: it can simultaneously perform the three common retrieval functionalities of embedding models: dense retrieval, multi-vector retrieval, and sparse retrieval.\n",
    "- Multi-Linguality: it supports more than 100 working languages.\n",
    "- Multi-Granularity: it can process inputs of different granularities, from short sentences to long documents of up to 8192 tokens.\n",
    "\n",
    "For more details, feel free to check out the [paper](https://arxiv.org/pdf/2402.03216)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41348e03",
   "metadata": {},
   "source": [
    "| Model | Language | Parameters | Model Size | Description | Base Model |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | Multilingual | 568M | 2.27 GB | Multi-Functionality (dense retrieval, sparse retrieval, multi-vector (ColBERT)), Multi-Linguality, and Multi-Granularity (8192 tokens) | XLM-RoBERTa |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d4647625",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 216946.76it/s]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import BGEM3FlagModel\n",
    "\n",
    "model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)\n",
    "\n",
    "sentences_1 = [\"What is BGE M3?\", \"Defination of BM25\"]\n",
    "sentences_2 = [\"BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.\", \n",
    "               \"BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f89f1a9",
   "metadata": {},
   "source": [
    "```\n",
    "BGEM3FlagModel.encode(\n",
    "    sentences,\n",
    "    batch_size=12,\n",
    "    max_length=8192,\n",
    "    return_dense=True,\n",
    "    return_sparse=False,\n",
    "    return_colbert_vecs=False\n",
    ")\n",
    "```\n",
    "It returns a dictionary like:\n",
    "```\n",
    "{\n",
    "    'dense_vecs': array of dense embeddings if return_dense=True, otherwise None,\n",
    "    'lexical_weights': list of dictionaries mapping token ids to their weights if return_sparse=True, otherwise None,\n",
    "    'colbert_vecs': list of per-token (multi-vector) embedding arrays, one per sentence, if return_colbert_vecs=True, otherwise None\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c06e113",
   "metadata": {},
   "source": [
    "#### 2.4.1 Dense Retrieval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7361f61",
   "metadata": {},
   "source": [
    "Using BGE-M3 for dense embeddings is almost the same as using the BGE or BGE 1.5 models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "76f15175",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.6259035 0.34749585]\n",
      " [0.349868 0.6782462 ]]\n"
     ]
    }
   ],
   "source": [
    "# If you don't need inputs as long as 8192 tokens, set max_length to a smaller value to speed up encoding.\n",
    "embeddings_1 = model.encode(sentences_1, max_length=10)['dense_vecs']\n",
    "embeddings_2 = model.encode(sentences_2, max_length=100)['dense_vecs']\n",
    "\n",
    "# compute the similarity scores\n",
    "similarity = embeddings_1 @ embeddings_2.T\n",
    "print(similarity)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8be70fe2",
   "metadata": {},
   "source": [
    "#### 2.4.2 Sparse Retrieval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d0d91a1",
   "metadata": {},
   "source": [
    "Set `return_sparse=True` to make the model return sparse lexical weights. If a token appears multiple times in a sentence, only its maximum weight is retained."
   ]
  },
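  {
   "cell_type": "markdown",
   "id": "max-pool-md",
   "metadata": {},
   "source": [
    "The max-weight rule above can be illustrated without loading the model. In the sketch below, the token ids and weights are made-up numbers and `pool_max_weights` is a hypothetical helper; it only mirrors the rule as stated and is not FlagEmbedding's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "max-pool-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pool_max_weights(token_ids, weights):\n",
    "    # keep only the largest weight seen for each token id\n",
    "    # (weights assumed non-negative; tokens whose weights are all zero are omitted)\n",
    "    pooled = {}\n",
    "    for tid, w in zip(token_ids, weights):\n",
    "        if w > pooled.get(tid, 0.0):\n",
    "            pooled[tid] = w\n",
    "    return pooled\n",
    "\n",
    "# made-up token ids and weights; token 42 occurs twice\n",
    "token_ids = [42, 7, 42, 13]\n",
    "weights = [0.10, 0.25, 0.30, 0.05]\n",
    "print(pool_max_weights(token_ids, weights))  # {42: 0.3, 7: 0.25, 13: 0.05}"
   ]
  },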
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5258d5cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'What': 0.08362077, 'is': 0.081469566, 'B': 0.12964639, 'GE': 0.25186998, 'M': 0.17001738, '3': 0.26957875, '?': 0.040755156}, {'De': 0.050144322, 'fin': 0.13689369, 'ation': 0.045134712, 'of': 0.06342201, 'BM': 0.25167602, '25': 0.33353207}]\n"
     ]
    }
   ],
   "source": [
    "output_1 = model.encode(sentences_1, return_sparse=True)\n",
    "output_2 = model.encode(sentences_2, return_sparse=True)\n",
    "\n",
    "# you can see the weight for each token:\n",
    "print(model.convert_id_to_token(output_1['lexical_weights']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bffc20c9",
   "metadata": {},
   "source": [
    "Based on the token weights of the query and the passage, the relevance score between them is computed from the joint importance of the terms that co-occur in both, where $w_{qt}$ and $w_{pt}$ denote the weights of term $t$ in the query $q$ and the passage $p$:\n",
    "\n",
    "$$s_{lex} = \\sum_{t\\in q\\cap p}(w_{qt} * w_{pt})$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "867e2148",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.19554448500275612\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# compute the scores via lexical matching\n",
    "score_1 = model.compute_lexical_matching_score(output_1['lexical_weights'][0], output_2['lexical_weights'][0])\n",
    "print(score_1)\n",
    "\n",
    "score_2 = model.compute_lexical_matching_score(output_1['lexical_weights'][0], output_1['lexical_weights'][1])\n",
    "print(score_2)"
   ]
  },
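  {
   "cell_type": "markdown",
   "id": "lexical-score-md",
   "metadata": {},
   "source": [
    "The formula for $s_{lex}$ is simple enough to write out directly. The sketch below transcribes it over token-to-weight dictionaries; `lexical_score` is a hypothetical name, and the weights are toy values keyed by token strings for readability (the model's `lexical_weights` are keyed by token ids). It illustrates the formula and is not FlagEmbedding's source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lexical-score-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lexical_score(q_weights, p_weights):\n",
    "    # s_lex: sum of w_qt * w_pt over tokens t present in both query and passage\n",
    "    return sum(w * p_weights[t] for t, w in q_weights.items() if t in p_weights)\n",
    "\n",
    "# toy weights keyed by token; only 'BM' and '25' are shared\n",
    "q = {'BM': 0.25, '25': 0.33, 'of': 0.06}\n",
    "p = {'BM': 0.20, '25': 0.30, 'ranks': 0.15}\n",
    "print(round(lexical_score(q, p), 3))  # 0.149\n",
    "\n",
    "# no shared tokens: the sum is empty and evaluates to 0\n",
    "print(lexical_score(q, {'cat': 0.5}))  # 0"
   ]
  },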
  {
   "cell_type": "markdown",
   "id": "cdd6bbe2",
   "metadata": {},
   "source": [
    "#### 2.4.3 Multi-Vector"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c243114",
   "metadata": {},
   "source": [
    "The multi-vector method uses the entire sequence of output token embeddings, rather than a single vector, to represent the query and the passage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "21b27182",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8, 1024)\n"
     ]
    }
   ],
   "source": [
    "output_1 = model.encode(sentences_1, return_dense=True, return_sparse=True, return_colbert_vecs=True)\n",
    "output_2 = model.encode(sentences_2, return_dense=True, return_sparse=True, return_colbert_vecs=True)\n",
    "\n",
    "print(f\"({len(output_1['colbert_vecs'][0])}, {len(output_1['colbert_vecs'][0][0])})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dec335e5",
   "metadata": {},
   "source": [
    "Following ColBERT, we use late interaction to compute the fine-grained relevance score, where $E_q$ and $E_p$ are the multi-vector embeddings of the query and the passage, and $N$ and $M$ are their numbers of tokens:\n",
    "$$s_{mul}=\\frac{1}{N}\\sum_{i=1}^N\\max_{j=1}^M E_q[i]\\cdot E_p^T[j]$$"
   ]
  },
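  {
   "cell_type": "markdown",
   "id": "maxsim-md",
   "metadata": {},
   "source": [
    "A direct NumPy transcription of $s_{mul}$, assuming $E_q$ and $E_p$ are arrays of per-token embeddings with shapes `(N, d)` and `(M, d)`, such as individual entries of `colbert_vecs`. This is our own sketch of the formula, not the code behind `model.colbert_score`, and it runs on random vectors instead of model outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "maxsim-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def maxsim_score(E_q, E_p):\n",
    "    # sim[i, j] = E_q[i] . E_p[j] for every query token i and passage token j\n",
    "    sim = E_q @ E_p.T\n",
    "    # best-matching passage token for each query token, averaged over the N query tokens\n",
    "    return sim.max(axis=1).mean()\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "E_q = rng.normal(size=(8, 1024))   # 8 query tokens, 1024-d like the colbert_vecs above\n",
    "E_p = rng.normal(size=(20, 1024))  # 20 passage tokens\n",
    "E_q /= np.linalg.norm(E_q, axis=1, keepdims=True)\n",
    "E_p /= np.linalg.norm(E_p, axis=1, keepdims=True)\n",
    "print(maxsim_score(E_q, E_p))\n",
    "\n",
    "# sanity check: with unit-norm rows, a passage scored against itself gives 1.0\n",
    "print(np.isclose(maxsim_score(E_p, E_p), 1.0))  # True"
   ]
  },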
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "575e38e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7796662449836731\n",
      "0.4621177911758423\n"
     ]
    }
   ],
   "source": [
    "print(model.colbert_score(output_1['colbert_vecs'][0], output_2['colbert_vecs'][0]).item())\n",
    "print(model.colbert_score(output_1['colbert_vecs'][0], output_2['colbert_vecs'][1]).item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}