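"""Evaluate FlagEmbedding retrievers on CodeRAG benchmarks with BEIR.

Source: mirror of https://github.com/FlagOpen/FlagEmbedding.git

Supports repository-level benchmarks (SWE-bench Lite, RepoEval) as well as
standard datasets (HumanEval, MBPP, APPS, ODEX, DS-1000). Top retrieved
documents are written to ``args.results_file`` and retrieval metrics to
``args.output_file``.
"""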
import os
import json
import random
import logging
import pathlib
import argparse

import numpy as np
from time import time
from datasets import load_dataset
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from tqdm import tqdm
from transformers import HfArgumentParser

from arguments import CodeRAGEvalArgs, CodeRAGEvalModelArgs
from prompts import get_task_def_by_task_name
from FlagEmbedding import FlagLLMModel, FlagModel

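# get_model builds a FlagEmbedding embedder (FlagModel for encoder-only
# backbones, FlagLLMModel for decoder-only ones) and wraps it in a thin
# adapter exposing the encode_queries / encode_corpus / encode interface
# that BEIR's DenseRetrievalExactSearch expects.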
def get_model(model_args: CodeRAGEvalModelArgs):
    embedder_name_or_path = model_args.embedder_name_or_path

    if model_args.embedder_model_class == "encoder-only-base":
        embedder = FlagModel(
            model_name_or_path=embedder_name_or_path,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
    elif model_args.embedder_model_class == "decoder-only-base":
        embedder = FlagLLMModel(
            model_name_or_path=embedder_name_or_path,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            examples_for_task=model_args.examples_for_task,
            examples_instruction_format=model_args.examples_instruction_format,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
    else:
        raise ValueError(f"Invalid model class: {model_args.embedder_model_class}")
    embedder.model.config._name_or_path = model_args.embedder_name_or_path
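    # Adapter matching the call signature of BEIR's DenseRetrievalExactSearch:
    # show_progress_bar and convert_to_tensor are accepted for compatibility
    # but intentionally not forwarded; remaining kwargs (e.g. batch_size) are
    # passed through to the underlying FlagEmbedding model.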
    class CustomFlagModel:
        def __init__(self, model):
            self.model = model

        def encode_queries(self, queries, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(queries, str):
                queries = [queries]

            if isinstance(queries[0], dict):
                # default to '' so entries without a 'title' do not raise a TypeError
                queries = [(e.get('title', '') + ' ' + e['text']).strip() for e in queries]

            return self.model.encode_queries(queries, **kwargs)

        def encode_corpus(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(corpus, str):
                corpus = [corpus]

            if isinstance(corpus[0], dict):
                corpus = [(e.get('title', '') + ' ' + e['text']).strip() for e in corpus]

            return self.model.encode_corpus(corpus, **kwargs)

        def encode(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(corpus, str):
                corpus = [corpus]

            if isinstance(corpus[0], dict):
                corpus = [(e.get('title', '') + ' ' + e['text']).strip() for e in corpus]

            return self.model.encode(corpus, **kwargs)

    return CustomFlagModel(embedder)

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

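# Return the top-k corpus entries for a task id, ranked by retrieval score.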
def get_top_docs(results: dict, corpus: dict, task_id: str, topk: int = 10) -> list:
    if task_id not in results:
        return []
    doc_scores = results[task_id]
    doc_scores_sorted = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
    doc_scores_sorted = doc_scores_sorted[:topk]
    doc_code_snippets = [corpus[code_id] for code_id, score in doc_scores_sorted]
    return doc_code_snippets

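# main() follows one of two paths: repository-level benchmarks (dataset names
# starting with "swe-bench" or "repoeval") are retrieved and evaluated per
# repository instance and then averaged, while every other dataset is loaded
# as a single BEIR-style corpus/queries/qrels split.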
def main(
    eval_args: CodeRAGEvalArgs,
    model_args: CodeRAGEvalModelArgs
):
    args = eval_args

    embedder = get_model(model_args)
    model = DRES(
        embedder,
        batch_size=args.batch_size,
        corpus_chunk_size=512 * 9999
    )
    retriever = EvaluateRetrieval(model, score_function="dot")

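    # Repository-level benchmarks: each instance is a separate BEIR-formatted
    # dataset stored under datasets/<dataset>_<instance>; retrieve and evaluate
    # per instance, then average the metrics across instances.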
    if args.dataset.startswith("swe-bench") or args.dataset.startswith("repoeval"):
        all_eval_results = []

        if args.dataset.startswith("swe-bench"):
            swebench = load_dataset("princeton-nlp/SWE-bench_Lite")["test"]
            all_top_docs = [[] for _ in swebench]

        instance_list = [i for i in os.listdir("datasets") if i.startswith(f"{args.dataset}_")]
        instance_list_filtered = []

        for ins_dir in tqdm(instance_list):
            logging.info("Instance Repo: {}".format(ins_dir))
            # load data and perform retrieval
            corpus, queries, qrels = GenericDataLoader(
                data_folder=os.path.join("datasets", ins_dir)
            ).load(split="test")
            logging.info(f"Instance #{ins_dir}: #{len(corpus)} corpus, #{len(queries)} queries")

            start_time = time()
            if len(queries) == 1:
                # pad single-query instances with a dummy query; it is removed after retrieval
                queries.update({"dummy": "dummy"})
            results = retriever.retrieve(corpus, queries)
            if "dummy" in queries:
                queries.pop("dummy")
                results.pop("dummy")
            end_time = time()
            logging.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))

            # get topk retrieved docs
            if args.dataset.startswith("swe-bench"):
                indices = [i for i, ex in enumerate(swebench) if ex["instance_id"] in queries]
                for index in indices:
                    instance_id = swebench[index]["instance_id"]
                    all_top_docs[index] = get_top_docs(results, corpus, instance_id)
            elif args.dataset.startswith("repoeval"):
                args.dataset_path = "output/repoeval/datasets/function_level_completion_2k_context_codex.test.clean.jsonl"
                with open(args.dataset_path, 'r') as fin:
                    tasks = [json.loads(line.strip()) for line in fin]
                prompts, references, docs, metadatas = [], [], [], []
                for task in tasks:
                    if task["metadata"]["task_id"] not in queries:
                        continue
                    prompts.append(task["prompt"])  # save full prompt
                    references.append(task["metadata"]["ground_truth"])
                    docs.append(get_top_docs(
                        results=results, corpus=corpus, task_id=task["metadata"]["task_id"],
                    ))
                    metadatas.append(task["metadata"])
                assert len(prompts) == len(references) == len(docs)
                dataset = [
                    {"prompt": p, "reference": r, "docs": d, "metadata": m}
                    for p, r, d, m in zip(prompts, references, docs, metadatas)
                ]
                with open(args.results_file, "a") as fout:
                    for curr in dataset:
                        fout.write(json.dumps(curr) + "\n")
            else:
                raise ValueError("`dataset` should start with either 'swe-bench' or 'repoeval'.")

            # evaluate retrieval results
            if len(qrels) == 0:
                logging.info("No qrels found for this dataset.")
                return
            logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
            ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
            mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
            eval_results = {
                "ndcg": ndcg, "mrr": mrr,
                "recall": recall, "precision": precision,
                "time": end_time - start_time
            }
            logging.info(f"Instance #{ins_dir}: {eval_results}")
            all_eval_results.append(eval_results)

        with open(args.output_file + "_all", "w") as f:
            json.dump(all_eval_results, f)

        if args.dataset.startswith("swe-bench"):
            swebench = swebench.add_column("docs", all_top_docs)
            swebench.to_json(args.results_file)

        # average the per-instance metrics
        avg_eval_results = {}
        for k, v_dict in all_eval_results[0].items():
            if isinstance(v_dict, dict):
                avg_v_dict = {}
                for vk, vv in v_dict.items():
                    avg_vv = sum([e[k][vk] for e in all_eval_results]) / len(all_eval_results)
                    avg_v_dict[vk] = avg_vv
                avg_eval_results.update(avg_v_dict)
            elif isinstance(v_dict, float):
                avg_v = sum([e[k] for e in all_eval_results]) / len(all_eval_results)
                avg_eval_results[k] = avg_v
            else:
                raise ValueError(f"Unexpected metric type for '{k}': {type(v_dict)}")
        print("Average Eval Results: ", avg_eval_results)
        with open(args.output_file, "w") as f:
            json.dump(avg_eval_results, f)
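    # Standard datasets: a single BEIR-formatted corpus/queries/qrels split
    # loaded from datasets/<dataset>.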
    else:
        dataset = args.dataset
        corpus, queries, qrels = GenericDataLoader(data_folder=os.path.join("datasets", args.dataset)).load(
            split="test")
        #### Retrieve dense results (format of results is identical to qrels)
        start_time = time()
        results = retriever.retrieve(corpus, queries)
        end_time = time()
        print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))

        if args.dataset in ["humaneval", "mbpp", "apps"]:
            if args.dataset == "humaneval":
                ds = load_dataset("openai_humaneval")
                id_key = "task_id"
            elif args.dataset == "mbpp":
                ds = load_dataset("mbpp")
                id_key = "task_id"
            elif args.dataset == "apps":
                ds = load_dataset("codeparrot/apps")
                id_key = "problem_id"
            all_top_docs = []
            for task_id in ds["test"][id_key]:
                all_top_docs.append(get_top_docs(results, corpus, f"{task_id}_doc"))
            ds["test"] = ds["test"].add_column("docs", all_top_docs)
            ds["test"].to_json(args.results_file)  # writes the split as JSON Lines (.jsonl)
        elif args.dataset.startswith("odex"):
            lang = args.dataset.split("_")[-1]
            ds = load_dataset("neulab/odex", lang, trust_remote_code=True)
            all_top_docs = []
            for idx, task_id in enumerate(ds["test"]["task_id"]):
                all_top_docs.append(get_top_docs(results, corpus, f"{idx}_{task_id}"))
            ds["test"] = ds["test"].add_column("docs", all_top_docs)
            ds["test"].to_json(args.results_file)  # writes the split as JSON Lines (.jsonl)
        elif args.dataset.startswith("ds1000"):
            _, key, mode = args.dataset.split("_")
            key = key.capitalize()
            mode = mode.capitalize()
            from create.ds1000 import get_dataset
            source_dir = pathlib.Path(__file__).parent / "ds"
            data = get_dataset(source_dir, mode=mode, key=key)
            all_docs = []
            example_ids = []
            for item in data:
                example = item.data
                example_id = f"{example['lib']}_{example['perturbation_origin_id']}"
                all_docs.append(get_top_docs(results, corpus, example_id))
                example_ids.append(example_id)
            assert len(all_docs) == len(example_ids), \
                f"length of all_docs should be {len(example_ids)}, now is {len(all_docs)}"
            with open(args.results_file, "w+") as fout:
                for idx, all_doc in enumerate(all_docs):
                    # write each example's own id, not the loop-leftover example_id
                    fout.write(json.dumps({"example_id": example_ids[idx],
                                           "docs": all_doc}) + "\n")
        else:
            with open(args.results_file, 'w+') as fw:
                for curr in results:
                    fw.write(json.dumps({curr: results[curr]}) + "\n")

        #### Evaluate your retrieval using NDCG@k, MAP@K ...
        if len(qrels) == 0:
            logging.info("No qrels found for this dataset.")
            return
        logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
        ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)

        mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
        recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
        hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")

        all_results = {"ndcg": ndcg, "mrr": mrr, "recall": recall, "precision": precision,
                       "time": end_time - start_time}
        with open(args.output_file, "w") as f:
            json.dump(all_results, f)

        #### Print top-k documents retrieved ####
        top_k = 3

        query_id, ranking_scores = random.choice(list(results.items()))
        scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
        logging.info("Query : %s\n" % queries[query_id])

        for rank in range(top_k):
            doc_id = scores_sorted[rank][0]
            # Format: Rank x: ID [Title] Body
            logging.info(
                "Rank %d: %s [%s] - %s\n" % (rank + 1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))

if __name__ == "__main__":
    parser = HfArgumentParser((
        CodeRAGEvalArgs,
        CodeRAGEvalModelArgs
    ))
    eval_args, model_args = parser.parse_args_into_dataclasses()
    main(eval_args, model_args)
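# Example invocation (a sketch only: the actual flag names are defined by
# CodeRAGEvalArgs / CodeRAGEvalModelArgs in arguments.py, which is not shown
# here, so the script name, model id, and options below are assumptions):
#
#   python eval_coderag.py \
#       --embedder_name_or_path BAAI/bge-base-en-v1.5 \
#       --embedder_model_class encoder-only-base \
#       --dataset humaneval \
#       --batch_size 64 \
#       --results_file results/humaneval_docs.jsonl \
#       --output_file results/humaneval_metrics.json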