{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BGE Auto Embedder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"FlagEmbedding provides a high-level class `FlagAutoModel` that unifies the inference of embedding models. Besides the BGE series, it also supports other popular open-source embedding models such as E5, GTE, and SFR. In this tutorial, we will walk through how to use it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install FlagEmbedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Usage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, import `FlagAutoModel` from FlagEmbedding, and use the `from_finetuned()` function to initialize the model:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from FlagEmbedding import FlagAutoModel\n",
"\n",
"model = FlagAutoModel.from_finetuned(\n",
"    'BAAI/bge-base-en-v1.5',\n",
"    query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages: \",\n",
"    devices=\"cuda:0\", # if not specified, will use all available gpus or cpu when no gpu available\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then use the model exactly the same way as `FlagModel` (`BGEM3FlagModel` if using BGE M3, `FlagLLMModel` if using BGE Multilingual Gemma2, `FlagICLModel` if using BGE ICL):"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.76 0.6714]\n",
" [0.6177 0.7603]]\n"
]
}
],
"source": [
"queries = [\"query 1\", \"query 2\"]\n",
"corpus = [\"passage 1\", \"passage 2\"]\n",
"\n",
"# encode the queries and corpus\n",
"q_embeddings = model.encode_queries(queries)\n",
"p_embeddings = model.encode_corpus(corpus)\n",
"\n",
"# compute the similarity scores\n",
"scores = q_embeddings @ p_embeddings.T\n",
"print(scores)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Explanation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`FlagAutoModel` uses an OrderedDict, `AUTO_EMBEDDER_MAPPING`, to store the configurations of all supported models:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['bge-en-icl',\n",
" 'bge-multilingual-gemma2',\n",
" 'bge-m3',\n",
" 'bge-large-en-v1.5',\n",
" 'bge-base-en-v1.5',\n",
" 'bge-small-en-v1.5',\n",
" 'bge-large-zh-v1.5',\n",
" 'bge-base-zh-v1.5',\n",
" 'bge-small-zh-v1.5',\n",
" 'bge-large-en',\n",
" 'bge-base-en',\n",
" 'bge-small-en',\n",
" 'bge-large-zh',\n",
" 'bge-base-zh',\n",
" 'bge-small-zh',\n",
" 'e5-mistral-7b-instruct',\n",
" 'e5-large-v2',\n",
" 'e5-base-v2',\n",
" 'e5-small-v2',\n",
" 'multilingual-e5-large-instruct',\n",
" 'multilingual-e5-large',\n",
" 'multilingual-e5-base',\n",
" 'multilingual-e5-small',\n",
" 'e5-large',\n",
" 'e5-base',\n",
" 'e5-small',\n",
" 'gte-Qwen2-7B-instruct',\n",
" 'gte-Qwen2-1.5B-instruct',\n",
" 'gte-Qwen1.5-7B-instruct',\n",
" 'gte-multilingual-base',\n",
" 'gte-large-en-v1.5',\n",
" 'gte-base-en-v1.5',\n",
" 'gte-large',\n",
" 'gte-base',\n",
" 'gte-small',\n",
" 'gte-large-zh',\n",
" 'gte-base-zh',\n",
" 'gte-small-zh',\n",
" 'SFR-Embedding-2_R',\n",
" 'SFR-Embedding-Mistral',\n",
" 'Linq-Embed-Mistral']"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from FlagEmbedding.inference.embedder.model_mapping import AUTO_EMBEDDER_MAPPING\n",
"\n",
"list(AUTO_EMBEDDER_MAPPING.keys())"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"EmbedderConfig(model_class=<class 'FlagEmbedding.inference.embedder.decoder_only.icl.ICLLLMEmbedder'>, pooling_method=<PoolingMethod.LAST_TOKEN: 'last_token'>, trust_remote_code=False, query_instruction_format='<instruct>{}\\n<query>{}')\n"
]
}
],
"source": [
"print(AUTO_EMBEDDER_MAPPING['bge-en-icl'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Take a look at the value of each key, which is an `EmbedderConfig` object. It consists of four attributes:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"@dataclass\n",
"class EmbedderConfig:\n",
"    model_class: Type[AbsEmbedder]\n",
"    pooling_method: PoolingMethod\n",
"    trust_remote_code: bool = False\n",
"    query_instruction_format: str = \"{}{}\"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond the BGE series, other models such as E5 are supported in the same way:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(AUTO_EMBEDDER_MAPPING['e5-mistral-7b-instruct'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Customization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to use your own models through `FlagAutoModel`, consider the following steps:\n",
"\n",
"1. Check the type of your embedding model and choose the appropriate model class: is it an encoder or a decoder?\n",
"2. What kind of pooling method does it use: CLS token, mean pooling, or last token?\n",
"3. Does your model need `trust_remote_code=True` to run?\n",
"4. Is there a query instruction format for retrieval?\n",
"\n",
"Once these four attributes are determined, add your model name as the key and the corresponding `EmbedderConfig` as the value to `AUTO_EMBEDDER_MAPPING`, as sketched in the cell below. Now have a try!"
]
},
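{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of such a registration. The model name `my-custom-encoder` is purely hypothetical, and we assume `EmbedderConfig` can be imported from the same `model_mapping` module as `AUTO_EMBEDDER_MAPPING`; to avoid guessing class paths, the sketch simply reuses the model class and pooling method of an existing encoder entry (`bge-base-en-v1.5`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from FlagEmbedding.inference.embedder.model_mapping import (\n",
"    AUTO_EMBEDDER_MAPPING,\n",
"    EmbedderConfig,  # assumed to be defined alongside AUTO_EMBEDDER_MAPPING\n",
")\n",
"\n",
"# Borrow the model class and pooling method from an existing encoder-style entry,\n",
"# since a custom BERT-like encoder is typically served the same way.\n",
"template = AUTO_EMBEDDER_MAPPING['bge-base-en-v1.5']\n",
"\n",
"# 'my-custom-encoder' is a hypothetical model name used only for illustration.\n",
"AUTO_EMBEDDER_MAPPING['my-custom-encoder'] = EmbedderConfig(\n",
"    model_class=template.model_class,        # encoder-only embedder class\n",
"    pooling_method=template.pooling_method,  # e.g. CLS pooling, as in BGE v1.5\n",
"    trust_remote_code=False,                 # set True only if your model requires it\n",
"    query_instruction_format='{}{}',         # plain concatenation of instruction and query\n",
")"
]
}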
],
"metadata": {
"kernelspec": {
"display_name": "dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}