mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-07-03 23:20:36 +00:00
724 lines
22 KiB
Plaintext
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Data preparation for fine-tuning"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"In this tutorial, we will show an example of the first step for fine-tuning: dataset preparation."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 0. Installation"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# %pip install -U datasets"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"\n",
|
|||
|
"os.environ[\"HF_ENDPOINT\"]=\"https://hf-mirror.com\""
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Suppose we want to fine-tune our model for financial tasks. We found an open-source dataset that could be useful: [financial-qa-10K](https://huggingface.co/datasets/virattt/financial-qa-10K). Let's see how to properly prepare it for fine-tuning."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The raw dataset has the following structure:\n",
|
|||
|
"- 5 columns: 'question', 'answer', 'context', 'ticker', and 'filing'.\n",
|
|||
|
"- 7000 rows."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|||
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"Dataset({\n",
|
|||
|
" features: ['question', 'answer', 'context', 'ticker', 'filing'],\n",
|
|||
|
" num_rows: 7000\n",
|
|||
|
"})"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from datasets import load_dataset\n",
|
|||
|
"\n",
|
|||
|
"ds = load_dataset(\"virattt/financial-qa-10K\", split=\"train\")\n",
|
|||
|
"ds"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 1. Data for Fine-tuning"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Convert the dataset into the following format:\n",
|
|||
|
"\n",
|
|||
|
"``` python\n",
|
|||
|
"{\"query\": str, \"pos\": List[str], \"neg\": List[str], \"pos_scores\": List[int], \"neg_scores\": List[int], \"prompt\": str, \"type\": str}\n",
|
|||
|
"```"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"- `query` is the query text.\n",
"- `pos` is a list of positive texts and `neg` is a list of negative texts. If you have no negative texts for a query, you can randomly sample some from the entire corpus.\n",
"- `pos_scores` and `neg_scores` are lists of scores for the (`query`, `pos`) and (`query`, `neg`) pairs. They are only needed for knowledge distillation and can be omitted otherwise.\n",
"- `prompt` is the prompt used for the query; it overrides `query_instruction_for_retrieval`.\n",
"- `type` is used for bge-en-icl; its values include `normal`, `symmetric_class`, `symmetric_clustering`, etc."
|
|||
|
]
|
|||
|
},
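To make the record format concrete, here is a minimal sketch of one training record together with a small sanity check. The `validate_record` helper and the sample texts are hypothetical illustrations, not part of FlagEmbedding; the check only encodes the field description given above.

```python
from typing import Any, Dict, List

REQUIRED_KEYS = ("query", "pos", "neg")  # keys every record is assumed to carry


def validate_record(record: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one training record (empty if none)."""
    problems = []
    for key in REQUIRED_KEYS:
        if key not in record:
            problems.append(f"missing key: {key}")
    if not isinstance(record.get("query", ""), str):
        problems.append("query must be a string")
    for key in ("pos", "neg"):
        value = record.get(key, [])
        if not isinstance(value, list) or not all(isinstance(t, str) for t in value):
            problems.append(f"{key} must be a list of strings")
    # Score lists are optional, but when present they must line up with the texts.
    for texts_key, scores_key in (("pos", "pos_scores"), ("neg", "neg_scores")):
        if scores_key in record and len(record[scores_key]) != len(record.get(texts_key, [])):
            problems.append(f"{scores_key} must have one score per {texts_key} text")
    return problems


# Made-up example record; the texts are placeholders.
example = {
    "query": "What did the company originally focus on?",
    "pos": ["The company originally focused on PC graphics."],
    "neg": ["The company had 3,000 full-time employees at year end."],
    "prompt": "Represent this sentence for searching relevant passages: ",
}
print(validate_record(example))
```

Running a check like this over a few rows before writing the file can catch the most common slip, which is leaving `pos` as a plain string instead of a list.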
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"We select the 'question' and 'context' columns as our query and positive text (`pos`) and rename them accordingly. Then we add an 'id' column for later use in evaluation."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'query': 'What area did NVIDIA initially focus on before expanding to other computationally intensive fields?',\n",
|
|||
|
" 'pos': 'Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields.',\n",
|
|||
|
" 'id': '0'}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"ds = ds.select_columns(column_names=[\"question\", \"context\"])\n",
|
|||
|
"ds = ds.rename_column(\"question\", \"query\")\n",
|
|||
|
"ds = ds.rename_column(\"context\", \"pos\")\n",
|
|||
|
"ds = ds.add_column(\"id\", [str(i) for i in range(len(ds))])\n",
|
|||
|
"ds[0]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Negative examples are important when training embedding models. Our initial dataset does not come with negative texts, so we randomly sample a few from the whole corpus."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Map: 100%|██████████| 7000/7000 [00:00<00:00, 22336.83 examples/s]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"\n",
|
|||
|
"np.random.seed(520)\n",
|
|||
|
"neg_num = 10\n",
|
|||
|
"\n",
|
|||
|
"def str_to_lst(data):\n",
"    data[\"pos\"] = [data[\"pos\"]]\n",
"    return data\n",
"\n",
"# sample negative texts\n",
"new_col = []\n",
"for i in range(len(ds)):\n",
"    ids = np.random.randint(0, len(ds), size=neg_num)\n",
"    # redraw until the query's own index is not among the sampled ids\n",
"    while i in ids:\n",
"        ids = np.random.randint(0, len(ds), size=neg_num)\n",
"    neg = [ds[j.item()][\"pos\"] for j in ids]\n",
"    new_col.append(neg)\n",
"ds = ds.add_column(\"neg\", new_col)\n",
"\n",
"# wrap the value of 'pos' in a list\n",
"ds = ds.map(str_to_lst)"
|
|||
|
]
|
|||
|
},
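The sampling loop above only guarantees that a query's own row index is not drawn; it can still pick the same index twice, or pick another row whose text equals the positive. As a rough sketch of a stricter variant, the hypothetical helper below (`sample_negatives` is not part of the notebook or of FlagEmbedding) deduplicates candidate texts first and uses only the standard library.

```python
import random
from typing import List


def sample_negatives(corpus: List[str], pos_index: int, neg_num: int,
                     rng: random.Random) -> List[str]:
    """Sample up to `neg_num` distinct passages whose text differs from the positive."""
    positive = corpus[pos_index]
    # Deduplicate candidate texts (order preserved) and drop the positive itself.
    candidates = [text for text in dict.fromkeys(corpus) if text != positive]
    return rng.sample(candidates, k=min(neg_num, len(candidates)))


toy_corpus = ["passage A", "passage B", "passage C", "passage A", "passage D"]
rng = random.Random(520)  # fixed seed so the draw is repeatable
negatives = sample_negatives(toy_corpus, pos_index=0, neg_num=3, rng=rng)
print(negatives)
```

Rebuilding the candidate list for every query costs time linear in the corpus size, which is acceptable for a few thousand rows; for a much larger corpus one would deduplicate once and reject collisions instead.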
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Lastly, we add the prompt used for queries. The same string serves as the `query_instruction_for_retrieval` during inference."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"instruction = \"Represent this sentence for searching relevant passages: \"\n",
|
|||
|
"ds = ds.add_column(\"prompt\", [instruction]*len(ds))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Now a single row of the dataset is:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'query': 'What area did NVIDIA initially focus on before expanding to other computationally intensive fields?',\n",
|
|||
|
" 'pos': ['Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields.'],\n",
|
|||
|
" 'id': '0',\n",
|
|||
|
" 'neg': ['Kroger expects that its value creation model will deliver total shareholder return within a target range of 8% to 11% over time.',\n",
|
|||
|
" 'CSB purchased First Mortgages of $2.9 billion during 2023.',\n",
|
|||
|
" 'See Note 13 to our Consolidated Financial Statements for information on certain legal proceedings for which there are contingencies.',\n",
|
|||
|
" 'Diluted earnings per share were $16.69 in fiscal 2022 compared to $15.53 in fiscal 2021.',\n",
|
|||
|
" 'In the year ended December 31, 2023, Total net sales and revenue increased primarily due to: (1) increased net wholesale volumes primarily due to increased sales of crossover vehicles and full-size pickup trucks, partially offset by decreased sales of mid-size pickup trucks; (2) favorable Price as a result of low dealer inventory levels and strong demand for our products; (3) favorable Mix associated with increased sales of full-size pickup trucks and full-size SUVs and decreased sales of vans, passenger cars and mid-size pickup trucks, partially offset by increased sales of crossover vehicles; and (4) favorable Other due to increased sales of parts and accessories.',\n",
|
|||
|
" 'As of December 31, 2023, we had 3,157 full-time employees.',\n",
|
|||
|
" 'Item 3. Legal Proceedings. The information contained in Note 18 ‘‘Commitments and Contingencies’’ included in Item 8 of this 10-K is incorporated herein by reference.',\n",
|
|||
|
" 'Under the amended 2019 Secured Facility, the maturity date is set to July 20, 2026.',\n",
|
|||
|
" 'Accounts receivable for Las Vegas Sands Corp. on December 31, 2023, totaled $685 million, with a provision for credit losses of $201 million, resulting in a net balance of $484 million.',\n",
|
|||
|
" 'Operating expenses as a percentage of segment net sales decreased 25 basis points for fiscal 2023 when compared to the previous fiscal year, primarily driven by strong sales growth and lower incremental COVID-19 related costs, partially offset by increased wage costs.'],\n",
|
|||
|
" 'prompt': 'Represent this sentence for searching relevant passages: '}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"ds[0]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Then we split the dataset into a training set and a test set."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"split = ds.train_test_split(test_size=0.1, shuffle=True, seed=520)\n",
|
|||
|
"train = split[\"train\"]\n",
|
|||
|
"test = split[\"test\"]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Now we are ready to store the data for later fine-tuning:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 39.73ba/s]\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"16583481"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"train.to_json(\"ft_data/training.json\")"
|
|||
|
]
|
|||
|
},
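To my understanding, `Dataset.to_json` writes JSON Lines by default, that is, one JSON object per line, even though the file here is named `training.json`. The sketch below uses only the standard library to show that layout and how to read it back; the records and the temporary path are placeholders, not the notebook's data.

```python
import json
import os
import tempfile

# Placeholder records in the fine-tuning format.
records = [
    {"query": "q1", "pos": ["p1"], "neg": ["n1", "n2"], "prompt": "Represent: "},
    {"query": "q2", "pos": ["p2"], "neg": ["n3", "n4"], "prompt": "Represent: "},
]

path = os.path.join(tempfile.mkdtemp(), "training.json")

# Write one JSON object per line (the JSON Lines layout).
with open(path, "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

# Read the file back line by line, skipping any blank lines.
with open(path, encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f if line.strip()]

print(len(loaded), loaded[0]["query"])
```

Reading line by line like this is also a cheap way to inspect the first few records of the stored file without loading all of it.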
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 2. Test Data for Evaluation"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The last step is to construct the testing dataset following the [format](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/evaluation#8-custom-dataset) for evaluation."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"Dataset({\n",
|
|||
|
" features: ['query', 'pos', 'id', 'neg', 'prompt'],\n",
|
|||
|
" num_rows: 700\n",
|
|||
|
"})"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"test"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"First select the columns for queries:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'id': '1289',\n",
|
|||
|
" 'text': 'How does Starbucks recognize the interest and penalties related to income tax matters on their financial statements?'}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"queries = test.select_columns(column_names=[\"id\", \"query\"])\n",
|
|||
|
"queries = queries.rename_column(\"query\", \"text\")\n",
|
|||
|
"queries[0]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Then select the columns for corpus:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"corpus = ds.select_columns(column_names=[\"id\", \"pos\"])\n",
|
|||
|
"corpus = corpus.rename_column(\"pos\", \"text\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Finally, build the qrels, which indicate the relevance relations between queries and their corresponding corpus entries:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Flattening the indices: 100%|██████████| 700/700 [00:00<00:00, 180956.10 examples/s]\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'qid': '1289', 'docid': '1289', 'relevance': 1}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"qrels = test.select_columns([\"id\"])\n",
|
|||
|
"qrels = qrels.rename_column(\"id\", \"qid\")\n",
|
|||
|
"qrels = qrels.add_column(\"docid\", list(test[\"id\"]))\n",
|
|||
|
"qrels = qrels.add_column(\"relevance\", [1]*len(test))\n",
|
|||
|
"qrels[0]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Store the test queries, corpus, and qrels:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 210.42ba/s]\n",
|
|||
|
"Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 261.19ba/s]\n",
|
|||
|
"Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 591.08ba/s]\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"30574"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"queries.to_json(\"ft_data/test_queries.jsonl\")\n",
|
|||
|
"corpus.to_json(\"ft_data/corpus.jsonl\")\n",
|
|||
|
"qrels.to_json(\"ft_data/test_qrels.jsonl\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Fine-tune"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from FlagEmbedding import FlagModel\n",
|
|||
|
"\n",
|
|||
|
"finetuned_path = \"test_encoder_only_base_bge-large-en-v1.5\"\n",
|
|||
|
"model_name = \"BAAI/bge-large-en-v1.5\"\n",
|
|||
|
"model = FlagModel(finetuned_path,\n",
"# model = FlagModel(model_name,\n",
"                  query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages: \",\n",
"                  devices=[0,1],\n",
"                  use_fp16=False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"initial target device: 100%|██████████| 2/2 [00:30<00:00, 15.31s/it]\n",
|
|||
|
"pre tokenize: 100%|██████████| 2/2 [00:00<00:00, 116.32it/s]\n",
|
|||
|
"You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
|||
|
"pre tokenize: 100%|██████████| 2/2 [00:00<00:00, 123.47it/s]\n",
|
|||
|
"You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
|||
|
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"Inference Embeddings: 100%|██████████| 2/2 [00:00<00:00, 13.06it/s]\n",
|
|||
|
"Inference Embeddings: 100%|██████████| 2/2 [00:00<00:00, 13.14it/s]\n",
|
|||
|
"Chunks: 100%|██████████| 2/2 [00:05<00:00, 2.56s/it]\n",
|
|||
|
"pre tokenize: 100%|██████████| 14/14 [00:00<00:00, 55.58it/s]\n",
|
|||
|
"pre tokenize: 100%|██████████| 14/14 [00:00<00:00, 27.82it/s]\n",
|
|||
|
"Inference Embeddings: 100%|██████████| 14/14 [00:02<00:00, 6.24it/s]\n",
|
|||
|
"Inference Embeddings: 100%|██████████| 14/14 [00:03<00:00, 4.07it/s]\n",
|
|||
|
"Chunks: 100%|██████████| 2/2 [00:04<00:00, 2.05s/it]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
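|
"# `queries` and `corpus` are Dataset objects, so read their columns directly\n",
"queries_ids = list(queries[\"id\"])\n",
"queries_text = list(queries[\"text\"])\n",
"corpus_ids = list(corpus[\"id\"])\n",
"# each corpus text is a one-element list (the wrapped 'pos'), so take its first item\n",
"corpus_text = [text[0] for text in corpus[\"text\"]]\n",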
|
|||
|
"\n",
|
|||
|
"queries_embeddings = model.encode_queries(queries_text)\n",
|
|||
|
"corpus_embeddings = model.encode_corpus(corpus_text)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"total number of vectors: 7000\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import faiss\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"\n",
|
|||
|
"# get the length of our embedding vectors, vectors by bge-large-en-v1.5 have length 1024\n",
|
|||
|
"dim = corpus_embeddings.shape[-1]\n",
|
|||
|
"\n",
|
|||
|
"# create the faiss index and store the corpus embeddings into the vector space\n",
|
|||
|
"index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)\n",
|
|||
|
"# corpus_embeddings = corpus_embeddings.astype(np.float32)\n",
|
|||
|
"# train and add the embeddings to the index\n",
|
|||
|
"index.train(corpus_embeddings)\n",
|
|||
|
"index.add(corpus_embeddings)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"total number of vectors: {index.ntotal}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Searching: 100%|██████████| 22/22 [00:00<00:00, 31.84it/s]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from tqdm import tqdm\n",
|
|||
|
"\n",
|
|||
|
"query_size = len(queries_embeddings)\n",
|
|||
|
"\n",
|
|||
|
"all_scores = []\n",
|
|||
|
"all_indices = []\n",
|
|||
|
"\n",
|
|||
|
"for i in tqdm(range(0, query_size, 32), desc=\"Searching\"):\n",
"    j = min(i + 32, query_size)\n",
"    query_embedding = queries_embeddings[i: j]\n",
"    score, indice = index.search(query_embedding.astype(np.float32), k=100)\n",
"    all_scores.append(score)\n",
"    all_indices.append(indice)\n",
|
|||
|
"\n",
|
|||
|
"all_scores = np.concatenate(all_scores, axis=0)\n",
|
|||
|
"all_indices = np.concatenate(all_indices, axis=0)"
|
|||
|
]
|
|||
|
},
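A `Flat` index with the inner-product metric performs an exact, exhaustive search: every corpus vector is scored against the query and the best `k` are returned. The following pure-Python sketch mirrors that behaviour on toy vectors; `top_k_inner_product` is a hypothetical helper for illustration and is far slower than faiss.

```python
from typing import List, Tuple


def top_k_inner_product(query: List[float], corpus: List[List[float]],
                        k: int) -> List[Tuple[int, float]]:
    """Score every corpus vector against the query and return the k best (index, score) pairs."""
    scores = [sum(q * c for q, c in zip(query, vec)) for vec in corpus]
    # Sort corpus positions by descending inner product.
    ranked = sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i]) for i in ranked[:k]]


# Toy unit-length vectors standing in for embeddings.
corpus_vectors = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
query_vector = [0.8, 0.6]
top_hits = top_k_inner_product(query_vector, corpus_vectors, k=2)
print(top_hits)
```

When the embeddings are normalized to unit length, as in this toy example, the inner product equals cosine similarity, so the ranking is a cosine ranking.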
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"results = {}\n",
|
|||
|
"for idx, (scores, indices) in enumerate(zip(all_scores, all_indices)):\n",
"    results[queries_ids[idx]] = {}\n",
"    # use doc_idx so the loop does not shadow the faiss `index` object\n",
"    for score, doc_idx in zip(scores, indices):\n",
"        if doc_idx != -1:\n",
"            results[queries_ids[idx]][corpus_ids[doc_idx]] = float(score)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"defaultdict(<class 'list'>, {'NDCG@10': 0.84061, 'NDCG@100': 0.85484})\n",
|
|||
|
"defaultdict(<class 'list'>, {'MAP@10': 0.81157, 'MAP@100': 0.81471})\n",
|
|||
|
"defaultdict(<class 'list'>, {'Recall@10': 0.93, 'Recall@100': 0.99429})\n",
|
|||
|
"defaultdict(<class 'list'>, {'P@10': 0.093, 'P@100': 0.00994})\n",
|
|||
|
"defaultdict(<class 'list'>, {'MRR@10': 0.81157, 'MRR@100': 0.81471})\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from FlagEmbedding.abc.evaluation.utils import evaluate_metrics, evaluate_mrr\n",
|
|||
|
"\n",
|
|||
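|
"k_values = [10,100]\n",
"# the metric helpers expect qrels as a nested dict: {qid: {docid: relevance}}\n",
"qrels_dict = {row[\"qid\"]: {row[\"docid\"]: row[\"relevance\"]} for row in qrels}\n",
"eval_res = evaluate_metrics(qrels_dict, results, k_values)\n",
"mrr = evaluate_mrr(qrels_dict, results, k_values)\n",
"\n",
"for res in eval_res:\n",
"    print(res)\n",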
|
|||
|
"print(mrr)"
|
|||
|
]
|
|||
|
},
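To show what two of the reported numbers measure, here is a small sketch that computes Recall@k and MRR@k from nested `qrels` and `results` dictionaries. It is an illustration written for this walkthrough, not the FlagEmbedding implementation, and the toy ids and scores are made up. Because each query in this notebook has exactly one relevant passage, average precision reduces to the reciprocal rank of that passage, which is consistent with the MAP and MRR values printed above being identical.

```python
from typing import Dict


def recall_and_mrr_at_k(qrels: Dict[str, Dict[str, int]],
                        results: Dict[str, Dict[str, float]],
                        k: int) -> Dict[str, float]:
    """Average Recall@k and MRR@k over all queries in `qrels`."""
    recall_sum, rr_sum = 0.0, 0.0
    for qid, relevant in qrels.items():
        relevant_ids = {doc for doc, rel in relevant.items() if rel > 0}
        scored = results.get(qid, {})
        # Rank retrieved documents by descending score and keep the top k.
        top_k = sorted(scored, key=scored.get, reverse=True)[:k]
        hits = [doc for doc in top_k if doc in relevant_ids]
        recall_sum += len(hits) / len(relevant_ids) if relevant_ids else 0.0
        # Reciprocal rank of the first relevant document; contributes 0 if none is in the top k.
        for rank, doc in enumerate(top_k, start=1):
            if doc in relevant_ids:
                rr_sum += 1.0 / rank
                break
    n = max(len(qrels), 1)  # avoid dividing by zero on empty input
    return {f"Recall@{k}": recall_sum / n, f"MRR@{k}": rr_sum / n}


toy_qrels = {"q1": {"d1": 1}, "q2": {"d5": 1}}
toy_results = {
    "q1": {"d1": 0.9, "d2": 0.5},
    "q2": {"d3": 0.8, "d5": 0.7, "d4": 0.1},
}
print(recall_and_mrr_at_k(toy_qrels, toy_results, k=2))
```

In the toy data the relevant passage for `q2` sits at rank 2, so it counts fully toward Recall@2 but contributes only 0.5 to MRR@2.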
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"defaultdict(<class 'list'>, {'NDCG@1': 0.58286, 'NDCG@5': 0.68588, 'NDCG@10': 0.70405})\n",
|
|||
|
"defaultdict(<class 'list'>, {'Recall@1': 0.58286, 'Recall@5': 0.76714, 'Recall@10': 0.82286})\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Original test result"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"defaultdict(<class 'list'>, {'NDCG@1': 0.75571, 'NDCG@5': 0.84706, 'NDCG@10': 0.85623})\n",
|
|||
|
"defaultdict(<class 'list'>, {'Recall@1': 0.75571, 'Recall@5': 0.92286, 'Recall@10': 0.95143})\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Fake test result"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"[6.453125]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from FlagEmbedding import FlagReranker\n",
|
|||
|
"\n",
|
|||
|
"reranker = FlagReranker(\n",
"    'BAAI/bge-reranker-base',\n",
"    query_max_length=256,\n",
"    use_fp16=True,\n",
"    devices=['cuda:1'],\n",
")\n",
|
|||
|
"\n",
|
|||
|
"score = reranker.compute_score(['I am happy to help', 'Assisting you is my pleasure'])\n",
|
|||
|
"print(score)"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "ft",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.11.10"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|