{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate the Fine-tuned Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the previous sections, we prepared the dataset and fine-tuned the model. In this tutorial, we will walk through how to evaluate the fine-tuned model on the test dataset we constructed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U datasets pytrec_eval FlagEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load the queries, corpus, and relevance judgments (qrels) from the files we processed in the previous sections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "queries = load_dataset(\"json\", data_files=\"ft_data/test_queries.jsonl\")[\"train\"]\n",
    "corpus = load_dataset(\"json\", data_files=\"ft_data/corpus.jsonl\")[\"train\"]\n",
    "qrels = load_dataset(\"json\", data_files=\"ft_data/test_qrels.jsonl\")[\"train\"]\n",
    "\n",
    "queries_text = queries[\"text\"]\n",
    "# flatten the nested text fields of the corpus into a single list of passages\n",
    "corpus_text = [text for sub in corpus[\"text\"] for text in sub]"
   ]
  },
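  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, we can print the sizes of the loaded splits (the exact counts depend on the dataset you constructed in the previous tutorial):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"number of test queries: {len(queries_text)}\")\n",
    "print(f\"number of corpus passages: {len(corpus_text)}\")"
   ]
  },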
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reshape the qrels into a nested dictionary: {query_id: {doc_id: relevance}}\n",
    "qrels_dict = {}\n",
    "for line in qrels:\n",
    "    if line['qid'] not in qrels_dict:\n",
    "        qrels_dict[line['qid']] = {}\n",
    "    qrels_dict[line['qid']][line['docid']] = line['relevance']"
   ]
  },
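  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This nested dictionary is the format that `pytrec_eval` expects for relevance judgments. To inspect a single entry (the exact ids depend on your generated data):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qid = next(iter(qrels_dict))\n",
    "print(qid, qrels_dict[qid])"
   ]
  },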
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we prepare a function that encodes the queries and corpus into embeddings, indexes the corpus with Faiss, and retrieves the top candidates for each query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def search(model, queries_text, corpus_text):\n",
    "\n",
    "    queries_embeddings = model.encode_queries(queries_text)\n",
    "    corpus_embeddings = model.encode_corpus(corpus_text)\n",
    "\n",
    "    # create and store the embeddings in a Faiss index\n",
    "    dim = corpus_embeddings.shape[-1]\n",
    "    index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)\n",
    "    corpus_embeddings = corpus_embeddings.astype(np.float32)\n",
    "    index.train(corpus_embeddings)\n",
    "    index.add(corpus_embeddings)\n",
    "\n",
    "    query_size = len(queries_embeddings)\n",
    "\n",
    "    all_scores = []\n",
    "    all_indices = []\n",
    "\n",
    "    # search the top 100 answers for all the queries, in batches of 32\n",
    "    for i in tqdm(range(0, query_size, 32), desc=\"Searching\"):\n",
    "        j = min(i + 32, query_size)\n",
    "        query_embedding = queries_embeddings[i: j]\n",
    "        batch_scores, batch_indices = index.search(query_embedding.astype(np.float32), k=100)\n",
    "        all_scores.append(batch_scores)\n",
    "        all_indices.append(batch_indices)\n",
    "\n",
    "    all_scores = np.concatenate(all_scores, axis=0)\n",
    "    all_indices = np.concatenate(all_indices, axis=0)\n",
    "\n",
    "    # store the results in the nested dictionary format for evaluation;\n",
    "    # note that the ids are looked up from the `queries` and `corpus` datasets loaded above\n",
    "    results = {}\n",
    "    for idx, (scores, indices) in enumerate(zip(all_scores, all_indices)):\n",
    "        results[queries[\"id\"][idx]] = {}\n",
    "        for score, doc_idx in zip(scores, indices):\n",
    "            if doc_idx != -1:\n",
    "                results[queries[\"id\"][idx]][corpus[\"id\"][doc_idx]] = float(score)\n",
    "\n",
    "    return results"
   ]
  },
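  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two small notes on the function above: BGE models normalize their embeddings by default, so the inner-product metric here amounts to cosine similarity; and since a `Flat` index performs exact brute-force search, the `index.train()` call is effectively a no-op and is kept only for consistency with trainable index types."
   ]
  },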
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from FlagEmbedding.abc.evaluation.utils import evaluate_metrics, evaluate_mrr\n",
    "from FlagEmbedding import FlagModel\n",
    "\n",
    "# cutoffs at which the metrics are computed\n",
    "k_values = [10, 100]\n",
    "\n",
    "raw_name = \"BAAI/bge-large-en-v1.5\"\n",
    "finetuned_path = \"test_encoder_only_base_bge-large-en-v1.5\""
   ]
  },
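  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here `finetuned_path` should point to the output directory saved by the fine-tuning step in the previous tutorial; adjust it if you used a different output directory."
   ]
  },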
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, the results for the original model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "pre tokenize: 100%|██████████| 3/3 [00:00<00:00, 129.75it/s]\n",
      "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "Inference Embeddings: 100%|██████████| 3/3 [00:00<00:00, 11.08it/s]\n",
      "pre tokenize: 100%|██████████| 28/28 [00:00<00:00, 164.29it/s]\n",
      "Inference Embeddings: 100%|██████████| 28/28 [00:04<00:00, 6.09it/s]\n",
      "Searching: 100%|██████████| 22/22 [00:08<00:00, 2.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(<class 'list'>, {'NDCG@10': 0.70405, 'NDCG@100': 0.73528})\n",
      "defaultdict(<class 'list'>, {'MAP@10': 0.666, 'MAP@100': 0.67213})\n",
      "defaultdict(<class 'list'>, {'Recall@10': 0.82286, 'Recall@100': 0.97286})\n",
      "defaultdict(<class 'list'>, {'P@10': 0.08229, 'P@100': 0.00973})\n",
      "defaultdict(<class 'list'>, {'MRR@10': 0.666, 'MRR@100': 0.67213})\n"
     ]
    }
   ],
   "source": [
    "raw_model = FlagModel(\n",
    "    raw_name,\n",
    "    query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "    devices=[0],\n",
    "    use_fp16=False\n",
    ")\n",
    "\n",
    "results = search(raw_model, queries_text, corpus_text)\n",
    "\n",
    "eval_res = evaluate_metrics(qrels_dict, results, k_values)\n",
    "mrr = evaluate_mrr(qrels_dict, results, k_values)\n",
    "\n",
    "for res in eval_res:\n",
    "    print(res)\n",
    "print(mrr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then the results for the fine-tuned model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "pre tokenize: 100%|██████████| 3/3 [00:00<00:00, 164.72it/s]\n",
      "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "Inference Embeddings: 100%|██████████| 3/3 [00:00<00:00, 9.45it/s]\n",
      "pre tokenize: 100%|██████████| 28/28 [00:00<00:00, 160.19it/s]\n",
      "Inference Embeddings: 100%|██████████| 28/28 [00:04<00:00, 6.06it/s]\n",
      "Searching: 100%|██████████| 22/22 [00:07<00:00, 2.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(<class 'list'>, {'NDCG@10': 0.84392, 'NDCG@100': 0.85792})\n",
      "defaultdict(<class 'list'>, {'MAP@10': 0.81562, 'MAP@100': 0.81875})\n",
      "defaultdict(<class 'list'>, {'Recall@10': 0.93143, 'Recall@100': 0.99429})\n",
      "defaultdict(<class 'list'>, {'P@10': 0.09314, 'P@100': 0.00994})\n",
      "defaultdict(<class 'list'>, {'MRR@10': 0.81562, 'MRR@100': 0.81875})\n"
     ]
    }
   ],
   "source": [
    "ft_model = FlagModel(\n",
    "    finetuned_path,\n",
    "    query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "    devices=[0],\n",
    "    use_fp16=False\n",
    ")\n",
    "\n",
    "results = search(ft_model, queries_text, corpus_text)\n",
    "\n",
    "eval_res = evaluate_metrics(qrels_dict, results, k_values)\n",
    "mrr = evaluate_mrr(qrels_dict, results, k_values)\n",
    "\n",
    "for res in eval_res:\n",
    "    print(res)\n",
    "print(mrr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see a clear improvement across all the metrics. The table below summarizes the results at a cutoff of 10:"
   ]
  },
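  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Metric | Original | Fine-tuned |\n",
    "|---|---|---|\n",
    "| NDCG@10 | 0.70405 | 0.84392 |\n",
    "| MAP@10 | 0.666 | 0.81562 |\n",
    "| Recall@10 | 0.82286 | 0.93143 |\n",
    "| MRR@10 | 0.666 | 0.81562 |"
   ]
  }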
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ft",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}