import os
import json
import random
import logging
import pathlib
import argparse
import numpy as np
from time import time
from datasets import load_dataset
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from tqdm import tqdm
from transformers import HfArgumentParser
from arguments import CodeRAGEvalArgs, CodeRAGEvalModelArgs
from prompts import get_task_def_by_task_name
from FlagEmbedding import FlagLLMModel, FlagModel
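
# NOTE: `arguments` and `prompts` are assumed to be sibling modules in this repo
# (providing the CodeRAGEval* dataclasses and task-instruction helpers);
# FlagEmbedding supplies the FlagModel / FlagLLMModel embedder classes.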


def get_model(model_args: CodeRAGEvalModelArgs):
    embedder_name_or_path = model_args.embedder_name_or_path
    if model_args.embedder_model_class == "encoder-only-base":
        embedder = FlagModel(
            model_name_or_path=embedder_name_or_path,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
    elif model_args.embedder_model_class == "decoder-only-base":
        embedder = FlagLLMModel(
            model_name_or_path=embedder_name_or_path,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            examples_for_task=model_args.examples_for_task,
            examples_instruction_format=model_args.examples_instruction_format,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
    else:
        raise ValueError(f"Invalid model class: {model_args.embedder_model_class}")
    embedder.model.config._name_or_path = model_args.embedder_name_or_path

    class CustomFlagModel:
        """Wrapper exposing the encode interface expected by BEIR's DenseRetrievalExactSearch.

        `show_progress_bar` and `convert_to_tensor` are accepted for BEIR compatibility
        but are handled internally by FlagEmbedding, so they are ignored here.
        """

        def __init__(self, model):
            self.model = model

        def encode_queries(self, queries, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(queries, str):
                queries = [queries]
            if isinstance(queries[0], dict):
                queries = [(e.get('title', '') + ' ' + e['text']).strip() for e in queries]
            return self.model.encode_queries(queries, **kwargs)

        def encode_corpus(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(corpus, str):
                corpus = [corpus]
            if isinstance(corpus[0], dict):
                corpus = [(e.get('title', '') + ' ' + e['text']).strip() for e in corpus]
            return self.model.encode_corpus(corpus, **kwargs)

        def encode(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(corpus, str):
                corpus = [corpus]
            if isinstance(corpus[0], dict):
                corpus = [(e.get('title', '') + ' ' + e['text']).strip() for e in corpus]
            return self.model.encode(corpus, **kwargs)

    return CustomFlagModel(embedder)


#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])


def get_top_docs(results: dict, corpus: dict, task_id: str, topk: int = 10) -> list[dict]:
    """Return the top-k corpus entries for `task_id`, ranked by retrieval score."""
    if task_id not in results:
        return []
    doc_scores = results[task_id]
    doc_scores_sorted = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
    doc_scores_sorted = doc_scores_sorted[:topk]
    doc_code_snippets = [corpus[code_id] for code_id, score in doc_scores_sorted]
    return doc_code_snippets


def main(
    eval_args: CodeRAGEvalArgs,
    model_args: CodeRAGEvalModelArgs
):
    args = eval_args
    embedder = get_model(model_args)
    model = DRES(
        embedder,
        batch_size=args.batch_size,
        # a very large chunk size makes BEIR encode/score the whole corpus in one pass
        corpus_chunk_size=512 * 9999
    )
    retriever = EvaluateRetrieval(model, score_function="dot")

    if args.dataset.startswith("swe-bench") or args.dataset.startswith("repoeval"):
        all_eval_results = []
        if args.dataset.startswith("swe-bench"):
            swebench = load_dataset("princeton-nlp/SWE-bench_Lite")["test"]
            all_top_docs = [[] for _ in swebench]
        instance_list = [i for i in os.listdir("datasets") if i.startswith(f"{args.dataset}_")]
        instance_list_filtered = []
        for ins_dir in tqdm(instance_list):
            logging.info("Instance Repo: {}".format(ins_dir))
            # load data and perform retrieval
            corpus, queries, qrels = GenericDataLoader(
                data_folder=os.path.join("datasets", ins_dir)
            ).load(split="test")
            logging.info(f"Instance #{ins_dir}: #{len(corpus)} corpus, #{len(queries)} queries")
            start_time = time()
            # pad single-query instances with a dummy query (removed again right after retrieval)
            if len(queries) == 1:
                queries.update({"dummy": "dummy"})
            results = retriever.retrieve(corpus, queries)
            if "dummy" in queries:
                queries.pop("dummy")
                results.pop("dummy")
            end_time = time()
            logging.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
            # get topk retrieved docs
            if args.dataset.startswith("swe-bench"):
                indices = [i for i, ex in enumerate(swebench) if ex["instance_id"] in queries]
                for index in indices:
                    instance_id = swebench[index]["instance_id"]
                    all_top_docs[index] = get_top_docs(results, corpus, instance_id)
            elif args.dataset.startswith("repoeval"):
                args.dataset_path = "output/repoeval/datasets/function_level_completion_2k_context_codex.test.clean.jsonl"
                tasks = [json.loads(line.strip()) for line in open(args.dataset_path, 'r')]
                prompts, references, docs, metadatas = [], [], [], []
                for task in tasks:
                    if task["metadata"]["task_id"] not in queries:
                        continue
                    prompts.append(task["prompt"])  # save full prompt
                    references.append(task["metadata"]["ground_truth"])
                    docs.append(get_top_docs(
                        results=results, corpus=corpus, task_id=task["metadata"]["task_id"],
                    ))
                    metadatas.append(task["metadata"])
                assert len(prompts) == len(references) == len(docs)
                dataset = [
                    {"prompt": p, "reference": r, "docs": d, "metadata": m}
                    for p, r, d, m in zip(prompts, references, docs, metadatas)
                ]
                with open(args.results_file, "a") as fout:
                    for curr in dataset:
                        fout.write(json.dumps(curr) + "\n")
            else:
                raise ValueError("`dataset` should start with either 'swe-bench' or 'repoeval'.")
            # evaluate retrieval results
            if len(qrels) == 0:
                logging.info("No qrels found for this dataset.")
                return
            logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
            ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
            mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
            eval_results = {
                "ndcg": ndcg, "mrr": mrr,
                "recall": recall, "precision": precision,
                "time": end_time - start_time
            }
            logging.info(f"Instance #{ins_dir}: {eval_results}")
            all_eval_results.append(eval_results)
            with open(args.output_file + "_all", "w") as f:
                json.dump(all_eval_results, f)

        if args.dataset.startswith("swe-bench"):
            swebench = swebench.add_column("docs", all_top_docs)
            swebench.to_json(args.results_file)

        # average per-instance metrics across all instances
        avg_eval_results = {}
        for k, v_dict in all_eval_results[0].items():
            if isinstance(v_dict, dict):
                avg_v_dict = {}
                for vk, vv in v_dict.items():
                    avg_vv = sum([e[k][vk] for e in all_eval_results]) / len(all_eval_results)
                    avg_v_dict[vk] = avg_vv
                avg_eval_results.update(avg_v_dict)
            elif isinstance(v_dict, float):
                avg_v = sum([e[k] for e in all_eval_results]) / len(all_eval_results)
                avg_eval_results[k] = avg_v
            else:
                raise ValueError(f"Unexpected metric value type for key '{k}': {type(v_dict)}")
        print("Average Eval Results: ", avg_eval_results)
        with open(args.output_file, "w") as f:
            json.dump(avg_eval_results, f)
    else:
        dataset = args.dataset
        corpus, queries, qrels = GenericDataLoader(data_folder=os.path.join("datasets", args.dataset)).load(
            split="test")

        #### Retrieve dense results (format of results is identical to qrels)
        start_time = time()
        results = retriever.retrieve(corpus, queries)
        end_time = time()
        print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
        if args.dataset in ["humaneval", "mbpp", "apps"]:
            if args.dataset == "humaneval":
                ds = load_dataset("openai_humaneval")
                id_key = "task_id"
            elif args.dataset == "mbpp":
                ds = load_dataset("mbpp")
                id_key = "task_id"
            elif args.dataset == "apps":
                ds = load_dataset("codeparrot/apps")
                id_key = "problem_id"
            all_top_docs = []
            for task_id in ds["test"][id_key]:
                all_top_docs.append(get_top_docs(results, corpus, f"{task_id}_doc"))
            ds["test"] = ds["test"].add_column("docs", all_top_docs)
            ds["test"].to_json(args.results_file)  # writes JSON lines (.jsonl)
        elif args.dataset.startswith("odex"):
            lang = args.dataset.split("_")[-1]
            ds = load_dataset("neulab/odex", lang, trust_remote_code=True)
            all_top_docs = []
            for idx, task_id in enumerate(ds["test"]["task_id"]):
                all_top_docs.append(get_top_docs(results, corpus, f"{idx}_{task_id}"))
            ds["test"] = ds["test"].add_column("docs", all_top_docs)
            ds["test"].to_json(args.results_file)  # writes JSON lines (.jsonl)
        elif args.dataset.startswith("ds1000"):
            _, key, mode = args.dataset.split("_")
            key = key.capitalize()
            mode = mode.capitalize()
            from create.ds1000 import get_dataset
            source_dir = pathlib.Path(__file__).parent / "ds"
            data = get_dataset(source_dir, mode=mode, key=key)
            all_docs = []
            example_ids = []
            for item in data:
                example = item.data
                example_id = f"{example['lib']}_{example['perturbation_origin_id']}"
                all_docs.append(get_top_docs(results, corpus, example_id))
                example_ids.append(example_id)
            assert len(all_docs) == len(example_ids), \
                f"length of all_docs should be {len(example_ids)}, now is {len(all_docs)}"
            with open(args.results_file, "w+") as fout:
                for idx, all_doc in enumerate(all_docs):
                    # write each example with its own id, not the last id from the loop above
                    fout.write(json.dumps({"example_id": example_ids[idx],
                                           "docs": all_doc}) + "\n")
        else:
            with open(args.results_file, 'w+') as fw:
                for curr in results:
                    fw.write(json.dumps({curr: results[curr]}) + "\n")

        #### Evaluate your retrieval using NDCG@k, MAP@K ...
        if len(qrels) == 0:
            logging.info("No qrels found for this dataset.")
            return
        logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
        ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
        mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
        recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
        hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
        all_results = {"ndcg": ndcg, "mrr": mrr, "recall": recall, "precision": precision,
                       "time": end_time - start_time}
        with open(args.output_file, "w") as f:
            json.dump(all_results, f)

        #### Print top-k documents retrieved ####
        top_k = 3
        query_id, ranking_scores = random.choice(list(results.items()))
        scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
        logging.info("Query : %s\n" % queries[query_id])
        for rank in range(top_k):
            doc_id = scores_sorted[rank][0]
            # Format: Rank x: ID [Title] Body
            logging.info(
                "Rank %d: %s [%s] - %s\n" % (rank + 1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))


if __name__ == "__main__":
    parser = HfArgumentParser((
        CodeRAGEvalArgs,
        CodeRAGEvalModelArgs
    ))
    eval_args, model_args = parser.parse_args_into_dataclasses()
    main(eval_args, model_args)
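
# Example invocation (illustrative only; flag names follow HfArgumentParser's default
# mapping of the CodeRAGEvalArgs / CodeRAGEvalModelArgs dataclass fields, and the
# script name, model name, and output paths below are placeholders):
#
#   python eval_coderag.py \
#       --embedder_name_or_path BAAI/bge-base-en-v1.5 \
#       --embedder_model_class encoder-only-base \
#       --dataset humaneval \
#       --batch_size 64 \
#       --output_file output/humaneval_eval.json \
#       --results_file output/humaneval_results.jsonl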