2025-01-16 11:43:42 +00:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BGE Auto Embedder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"FlagEmbedding provides a high-level class, `FlagAutoModel`, that unifies inference across embedding models. Besides the BGE series, it also supports other popular open-source embedding models such as E5, GTE, and SFR. This tutorial shows how to use it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install FlagEmbedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Usage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, import `FlagAutoModel` from FlagEmbedding, and use the `from_finetuned()` function to initialize the model:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from FlagEmbedding import FlagAutoModel\n",
"\n",
"model = FlagAutoModel.from_finetuned(\n",
" 'BAAI/bge-base-en-v1.5',\n",
" query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages: \",\n",
" devices=\"cuda:0\", # if not specified, all available GPUs are used (or the CPU if no GPU is available)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then use the model exactly as you would use `FlagModel` (or `BGEM3FlagModel` for BGE M3, `FlagLLMModel` for BGE Multilingual Gemma2, and `FlagICLModel` for BGE ICL):"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.76 0.6714]\n",
" [0.6177 0.7603]]\n"
]
}
],
"source": [
"queries = [\"query 1\", \"query 2\"]\n",
"corpus = [\"passage 1\", \"passage 2\"]\n",
"\n",
"# encode the queries and corpus\n",
"q_embeddings = model.encode_queries(queries)\n",
"p_embeddings = model.encode_corpus(corpus)\n",
"\n",
"# compute the similarity scores\n",
"scores = q_embeddings @ p_embeddings.T\n",
"print(scores)"
]
},
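{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `@` product above is a plain inner product, so it equals cosine similarity only when the embeddings are L2-normalized (the embedder classes expose a `normalize_embeddings` option for this). The self-contained NumPy sketch below uses random stand-in vectors, not real model output, to check that equivalence and to rank the passages for each query."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# Stand-in embeddings: 2 queries and 3 passages of dimension 8 (not real model output)\n",
"q = rng.normal(size=(2, 8))\n",
"p = rng.normal(size=(3, 8))\n",
"\n",
"# L2-normalize each row so that the inner product becomes cosine similarity\n",
"q_norm = q / np.linalg.norm(q, axis=1, keepdims=True)\n",
"p_norm = p / np.linalg.norm(p, axis=1, keepdims=True)\n",
"scores = q_norm @ p_norm.T  # shape (num_queries, num_passages)\n",
"\n",
"# Explicit cosine similarity of the raw vectors, for comparison\n",
"cosine = (q @ p.T) / np.outer(np.linalg.norm(q, axis=1), np.linalg.norm(p, axis=1))\n",
"print(np.allclose(scores, cosine))\n",
"\n",
"# Rank passages for each query, best match first\n",
"ranking = np.argsort(-scores, axis=1)\n",
"print(ranking)"
]
},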
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Explanation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`FlagAutoModel` uses an OrderedDict, `AUTO_EMBEDDER_MAPPING`, to store the configurations of all supported models:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['bge-en-icl',\n",
" 'bge-multilingual-gemma2',\n",
" 'bge-m3',\n",
" 'bge-large-en-v1.5',\n",
" 'bge-base-en-v1.5',\n",
" 'bge-small-en-v1.5',\n",
" 'bge-large-zh-v1.5',\n",
" 'bge-base-zh-v1.5',\n",
" 'bge-small-zh-v1.5',\n",
" 'bge-large-en',\n",
" 'bge-base-en',\n",
" 'bge-small-en',\n",
" 'bge-large-zh',\n",
" 'bge-base-zh',\n",
" 'bge-small-zh',\n",
" 'e5-mistral-7b-instruct',\n",
" 'e5-large-v2',\n",
" 'e5-base-v2',\n",
" 'e5-small-v2',\n",
" 'multilingual-e5-large-instruct',\n",
" 'multilingual-e5-large',\n",
" 'multilingual-e5-base',\n",
" 'multilingual-e5-small',\n",
" 'e5-large',\n",
" 'e5-base',\n",
" 'e5-small',\n",
" 'gte-Qwen2-7B-instruct',\n",
" 'gte-Qwen2-1.5B-instruct',\n",
" 'gte-Qwen1.5-7B-instruct',\n",
" 'gte-multilingual-base',\n",
" 'gte-large-en-v1.5',\n",
" 'gte-base-en-v1.5',\n",
" 'gte-large',\n",
" 'gte-base',\n",
" 'gte-small',\n",
" 'gte-large-zh',\n",
" 'gte-base-zh',\n",
" 'gte-small-zh',\n",
" 'SFR-Embedding-2_R',\n",
" 'SFR-Embedding-Mistral',\n",
" 'Linq-Embed-Mistral']"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from FlagEmbedding.inference.embedder.model_mapping import AUTO_EMBEDDER_MAPPING\n",
"\n",
"list(AUTO_EMBEDDER_MAPPING.keys())"
]
},
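{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the usage example, `'BAAI/bge-base-en-v1.5'` lines up with the key `'bge-base-en-v1.5'`, which suggests that the lookup uses the last component of the model name or path. The sketch below illustrates that idea with a hypothetical helper, `resolve_key`, and a toy mapping; it is an assumption about the behavior, not FlagEmbedding's actual lookup code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from collections import OrderedDict\n",
"\n",
"# Toy stand-in for AUTO_EMBEDDER_MAPPING (values simplified to short descriptions)\n",
"toy_mapping = OrderedDict([\n",
"    ('bge-base-en-v1.5', 'encoder-only, CLS pooling'),\n",
"    ('bge-en-icl', 'decoder-only ICL, last-token pooling'),\n",
"])\n",
"\n",
"def resolve_key(model_name_or_path):\n",
"    # Hypothetical helper: take the last path component, ignoring a trailing slash\n",
"    return os.path.basename(model_name_or_path.rstrip('/'))\n",
"\n",
"for name in ['BAAI/bge-base-en-v1.5', '/models/bge-en-icl/', 'my-org/unknown-model']:\n",
"    key = resolve_key(name)\n",
"    print(name, '->', toy_mapping.get(key, 'not registered'))"
]
},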
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"EmbedderConfig(model_class=<class 'FlagEmbedding.inference.embedder.decoder_only.icl.ICLLLMEmbedder'>, pooling_method=<PoolingMethod.LAST_TOKEN: 'last_token'>, trust_remote_code=False, query_instruction_format='<instruct>{}\\n<query>{}')\n"
]
}
],
"source": [
"print(AUTO_EMBEDDER_MAPPING['bge-en-icl'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each value in the mapping is an `EmbedderConfig` object, which consists of four attributes:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"@dataclass\n",
"class EmbedderConfig:\n",
" model_class: Type[AbsEmbedder]\n",
" pooling_method: PoolingMethod\n",
" trust_remote_code: bool = False\n",
" query_instruction_format: str = \"{}{}\"\n",
"```"
]
},
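{
"cell_type": "markdown",
"metadata": {},
"source": [
"`query_instruction_format` controls how the retrieval instruction and the query are joined. Assuming it is applied as `query_instruction_format.format(instruction, query)` (our reading of the two `{}` placeholders, not confirmed here), the default `{}{}` simply concatenates them, while the `bge-en-icl` format printed above wraps them in tags:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"instruction = 'Represent this sentence for searching relevant passages: '\n",
"query = 'query 1'\n",
"\n",
"default_format = '{}{}'  # EmbedderConfig default\n",
"icl_format = '<instruct>{}\\n<query>{}'  # format printed for 'bge-en-icl'\n",
"\n",
"print(repr(default_format.format(instruction, query)))\n",
"# Hypothetical task instruction, for illustration only\n",
"print(repr(icl_format.format('Given a question, retrieve relevant passages.', query)))"
]
},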
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides the BGE series, the mapping covers other models such as E5 in the same way:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(AUTO_EMBEDDER_MAPPING['e5-large-v2'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Customization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to use your own models through `FlagAutoModel`, consider the following steps:\n",
"\n",
"1. Check the type of your embedding model and choose the appropriate model class: is it an encoder-only or a decoder-only model?\n",
"2. Which pooling method does it use: CLS token, mean pooling, or last token?\n",
"3. Does your model need `trust_remote_code=True` to run?\n",
"4. Is there a query instruction format for retrieval?\n",
"\n",
"Once these four attributes are determined, add your model name as the key and the corresponding `EmbedderConfig` as the value to `AUTO_EMBEDDER_MAPPING`. Then give it a try!"
]
}
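,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three pooling options from step 2 can be sketched in NumPy as follows. BGE encoders use the CLS token, while decoder-only models such as `bge-en-icl` use the last token, as its printed config shows. This is an illustrative re-implementation over toy hidden states, assuming right-padded batches for last-token pooling; it is not FlagEmbedding's own code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Toy batch: 2 sequences, max length 4, hidden size 3\n",
"hidden = np.arange(24, dtype=float).reshape(2, 4, 3)\n",
"# 1 = real token, 0 = padding (right-padded)\n",
"mask = np.array([[1, 1, 1, 1],\n",
"                 [1, 1, 0, 0]])\n",
"\n",
"def cls_pooling(hidden, mask):\n",
"    # Embedding of the first token ([CLS] for BERT-style encoders); mask unused\n",
"    return hidden[:, 0]\n",
"\n",
"def mean_pooling(hidden, mask):\n",
"    # Average of the non-padding token embeddings\n",
"    m = mask[:, :, None].astype(hidden.dtype)\n",
"    return (hidden * m).sum(axis=1) / m.sum(axis=1)\n",
"\n",
"def last_token_pooling(hidden, mask):\n",
"    # Embedding of the last non-padding token (assumes right padding)\n",
"    last = mask.sum(axis=1) - 1\n",
"    return hidden[np.arange(hidden.shape[0]), last]\n",
"\n",
"for fn in (cls_pooling, mean_pooling, last_token_pooling):\n",
"    print(fn.__name__, fn(hidden, mask).tolist())"
]
}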
],
"metadata": {
"kernelspec": {
"display_name": "dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}