Merge pull request #1464 from 545999961/master

upload coder eval script
chaofan 2025-05-28 14:18:23 +08:00 committed by GitHub
commit 97ca07325d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 2408 additions and 19 deletions

View File

@ -11,14 +11,17 @@
</a>
</p>
This repo contains the data, training, and evaluation pipeline for CodeR / [BGE-Code-v1](https://huggingface.co/BAAI/bge-code-v1).
**[BGE-Code-v1](https://huggingface.co/BAAI/bge-code-v1)** is an LLM-based code embedding model that supports code retrieval, text retrieval, and multilingual retrieval. It primarily demonstrates the following capabilities:
- Superior Code Retrieval Performance: The model demonstrates exceptional code retrieval capabilities, supporting natural language queries in both English and Chinese, as well as 20 programming languages.
- Robust Text Retrieval Capabilities: The model maintains strong text retrieval capabilities comparable to text embedding models of similar scale.
- Extensive Multilingual Support: BGE-Code-v1 offers comprehensive multilingual retrieval capabilities, excelling in languages such as English, Chinese, Japanese, French, and more.
## :bell: News:
- 🥳 5/15/2025: We have released CodeR! :fire:
## Usage
@ -29,9 +32,6 @@ This repo contains the data, training, and evaluation pipeline for CodeR / [BGE-
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install -e .
```
```python
from FlagEmbedding import FlagLLMModel
queries = [
"Delete the record with ID 4 from the 'Staff' table.",
@ -149,29 +149,29 @@ print(scores.tolist())
- CoIR
| | CodeXEmbed-2B | CodeXEmbed-7B | Voyage-Code-002 | Voyage-Code-003 | BGE-Code-v1 |
| --------------------- | ------------- | ------------- | --------------- | --------------- | ----------- |
| **Apps** | 76.86 | 85.38 | 26.52 | 93.62 | 98.08 |
| **CosQA** | 40.47 | 42.47 | 29.79 | 34.45 | 46.72 |
| **Text2SQL** | 78.42 | 78.94 | 69.26 | 62.87 | 64.35 |
| **CSN** | 87.87 | 89.67 | 81.79 | 89.35 | 89.53 |
| **CSN-CCR** | 97.66 | 97.95 | 73.45 | 90.05 | 98.30 |
| **CodeTrans-Contest** | 90.30 | 94.45 | 72.77 | 94.96 | 94.38 |
| **CodeTrans-DL** | 38.57 | 40.46 | 27.48 | 38.57 | 46.13 |
| **StackOverFlow-QA** | 94.47 | 96.33 | 67.68 | 97.17 | 95.35 |
| **CodeFeedBack-ST** | 86.36 | 87.53 | 65.35 | 90.67 | 90.56 |
| **CodeFeedBack-MT** | 65.51 | 68.83 | 28.74 | 93.58 | 94.38 |
| **AVG** | **75.65** | **78.20** | **56.26** | **78.53** | **81.77** |
- CodeRAG
|                 | HumanEval  | MBPP | DS-1000 | ODEX | RepoEval | SWE-bench-Lite | AVG      |
| --------------- | ---------- | ---- | ------- | ---- | -------- | -------------- | -------- |
| SFR | 100.0 | 99.0 | 19.3 | 37.1 | 83.8 | 62.7 | **67.0** |
| Jina-v2-code | 100.0 | 97.7 | 26.2 | 19.9 | 90.5 | 58.3 | **65.4** |
| CodeXEmbed-2B | 100.0 | 97.4 | 25.4 | 23.9 | 88.7 | 52.4 | **64.6** |
| Voyage-Code-002 | 100.0 | 99.0 | 33.1 | 26.6 | 94.3 | 29.1 | **63.7** |
| BGE-Code-v1 | 100.0 | 99.2 | 40.9 | 36.1 | 93.1 | 67.4 | **72.8** |
### Instructions for Evaluation
@ -195,3 +195,39 @@ print(scores.tolist())
"SWE-bench-Lite": "Given a code snippet containing a bug and a natural language description of the bug or error, retrieve code snippets that demonstrate solutions or fixes for similar bugs or errors (the desired documents)."
}
```
### Evaluation script
#### CoIR
For CoIR, we use the [CoIR](https://github.com/CoIR-team/coir) evaluation script:
```shell
cd ./evaluation/coir_eval
### clone coir
mkdir test
cd ./test
git clone https://github.com/CoIR-team/coir.git
mv ./coir/coir ../
cd ..
rm -rf ./test
### evaluate
bash eval.sh
```
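If you want to drive the benchmark from Python rather than `eval.sh`, the `coir` package copied next to `eval.sh` exposes a small evaluation API. The sketch below follows the pattern in the CoIR README; the `YourCustomDEModel` wrapper and task names come from that README and may need to be swapped for the FlagEmbedding-based encoder that `eval.sh` actually uses.
```python
# Minimal sketch following the upstream CoIR README; purely illustrative —
# eval.sh in this folder is the supported entry point.
from coir.data_loader import get_tasks
from coir.evaluation import COIR
from coir.models import YourCustomDEModel

model = YourCustomDEModel(model_name="BAAI/bge-code-v1")   # dense-encoder wrapper from coir
tasks = get_tasks(tasks=["codetrans-dl"])                   # any subset of the CoIR tasks
evaluation = COIR(tasks=tasks, batch_size=32)
results = evaluation.run(model, output_folder="results/bge-code-v1")
print(results)
```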
#### CodeRAG
For CodeRAG, we use the [CodeRAG](https://github.com/code-rag-bench/code-rag-bench) evaluation script:
```shell
cd ./evaluation/coderag_eval
### clone coderag
git clone https://github.com/code-rag-bench/code-rag-bench.git
### You need to prepare the environment according to its README.md
rm -rf ./code-rag-bench/retrieval/create
cp -r ./test/* ./code-rag-bench/retrieval/
### prepare data
bash prepare_data.sh
### evaluate
bash eval.sh
```

View File

@ -0,0 +1,22 @@
cd ./code-rag-bench/retrieval/
output_dir='result'
for dataset_name in "humaneval" "mbpp" "repoeval" "ds1000_all_completion" "odex_en" "swe-bench-lite"
do
echo "dataset_name: ${dataset_name}"
python main.py \
--embedder_name_or_path BAAI/bge-code-v1 \
--embedder_model_class decoder-only-base \
--query_instruction_format_for_retrieval '<instruct>{}\n<query>{}' \
--embedder_query_max_length 2048 \
--embedder_passage_max_length 2048 \
--trust_remote_code True \
--pooling_method last_token \
--embedder_batch_size 64 \
--devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
--cache_dir ./cache \
--dataset $dataset_name \
--output_file ../../${output_dir}/${dataset_name}_output.json \
--results_file ../../${output_dir}/${dataset_name}_results.json
done
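For reference, the flags above correspond to loading the model directly through FlagEmbedding with the same instruction format and last-token pooling. The snippet below is a rough sketch only, not the evaluation pipeline itself; the instruction string is illustrative, and the real runs use the task-specific instructions listed in the README section above.
```python
# Sketch of what the eval.sh flags map to when loading the model by hand.
from FlagEmbedding import FlagLLMModel

model = FlagLLMModel(
    "BAAI/bge-code-v1",
    query_instruction_for_retrieval="Given a question in text, retrieve relevant code snippets.",
    query_instruction_format="<instruct>{}\n<query>{}",  # --query_instruction_format_for_retrieval
    trust_remote_code=True,                               # --trust_remote_code True
    use_fp16=True,
)
query_emb = model.encode_queries(["Delete the record with ID 4 from the 'Staff' table."])
doc_emb = model.encode_corpus(["DELETE FROM Staff WHERE id = 4;"])
print(query_emb @ doc_emb.T)
```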

View File

@ -0,0 +1,7 @@
cd ./code-rag-bench/retrieval/
for dataset_name in "humaneval" "mbpp" "live_code_bench" "ds1000" "odex" "repoeval_repo" "swebench_repo"
do
echo "dataset_name: ${dataset_name}"
PYTHONPATH=./ python create/${dataset_name}.py
done

View File

@ -0,0 +1,34 @@
from typing import List
from dataclasses import dataclass, field
from FlagEmbedding.abc.evaluation import (
AbsEvalModelArgs as CodeRAGEvalModelArgs,
)
@dataclass
class CodeRAGEvalArgs:
dataset: str = field(
default='humaneval',
metadata={
"help": "Task to evaluate. Default: humaneval. Available tasks: "
"['humaneval', 'mbpp', 'live_code_bench', 'ds1000', 'odex', 'repoeval_repo', 'swebench_repo', 'code_search_net']",
}
)
max_length: int = field(
default=2048, metadata={"help": "Max length to use for evaluation."}
)
batch_size: int = field(
default=64, metadata={"help": "Batch size for evaluation."}
)
output_file: str = field(
default="outputs.json",
metadata={
"help": "Specify the filepath if you want to save the retrieval (evaluation) results."
},
)
results_file: str = field(
default="results.json",
metadata={
"help": "Specify the filepath if you want to save the retrieval results."
},
)
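These dataclasses are presumably parsed together with the shared model arguments by the eval runner (`main.py`, not shown in this diff). A hypothetical driver using `transformers.HfArgumentParser` could look like:
```python
# Hypothetical driver sketch; the repo's actual entry point is main.py in this folder.
from transformers import HfArgumentParser

if __name__ == "__main__":
    parser = HfArgumentParser((CodeRAGEvalArgs, CodeRAGEvalModelArgs))
    eval_args, model_args = parser.parse_args_into_dataclasses()
    print(eval_args.dataset, eval_args.batch_size, eval_args.max_length)
```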

View File

@ -0,0 +1,52 @@
import argparse
import datasets
import os
from tqdm import tqdm
import random
from create.utils import save_tsv_dict, save_file_jsonl, load_jsonlines
def document2code(data, split="train"):
data = data[split]
code_search_net_data_queries = []
code_search_net_data_docs = []
code_search_net_data_qrels = []
for item in tqdm(data):
doc = item["func_documentation_string"]
code = item["func_code_string"]
doc_id = "{repository_name}_{func_path_in_repository}_{func_name}_doc".format_map(item)
code_id = "{repository_name}_{func_path_in_repository}_{func_name}_code".format_map(item)
code_search_net_data_queries.append({"_id": doc_id, "text": doc, "metadata": {}})
code_search_net_data_docs.append({"_id": code_id, "title": item["func_name"], "text": code, "metadata": {}})
code_search_net_data_qrels.append({"query-id": doc_id, "corpus-id": code_id, "score": 1})
return code_search_net_data_queries, code_search_net_data_docs, code_search_net_data_qrels
def main():
# parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--language", type=str, default="python", help="codesearch net language")
parser.add_argument("--output_dir", type=str, default="datasets")
args = parser.parse_args()
dataset = datasets.load_dataset("code_search_net", args.language)
path = os.path.join(args.output_dir, "code_search_net_{}".format(args.language))
os.makedirs(path)
os.makedirs(os.path.join(path, "qrels"))
docs = []
queries = []
for split in ["train", "validation", "test"]:
queries_split, docs_split, qrels_split = document2code(dataset, split)
docs += docs_split
queries += queries_split
save_tsv_dict(qrels_split, os.path.join(path, "qrels", "{}.tsv".format(split)), ["query-id", "corpus-id", "score"])
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
if __name__ == "__main__":
main()
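Like the other scripts in `create/`, this one writes a BEIR-style triplet of files under `--output_dir`. The record shapes below use the field names from the code above with illustrative values.
```python
# Illustrative records; queries.jsonl and corpus.jsonl contain one JSON object per
# line, and qrels/{train,validation,test}.tsv hold (query-id, corpus-id, score) rows.
example_query = {"_id": "myrepo_src/utils.py_parse_doc", "text": "Parse the config file.", "metadata": {}}
example_doc = {"_id": "myrepo_src/utils.py_parse_code", "title": "parse", "text": "def parse(path): ...", "metadata": {}}
example_qrel = {"query-id": "myrepo_src/utils.py_parse_doc", "corpus-id": "myrepo_src/utils.py_parse_code", "score": 1}
```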

View File

@ -0,0 +1,123 @@
import io
import os
import fcntl
import pathlib
import zipfile
import argparse
import requests
import warnings
import itertools
from tqdm import tqdm
from datasets import load_dataset
from create.utils import save_tsv_dict, save_file_jsonl
# Load dataset
def download_source(source_dir):
src = source_dir / "ds1000.py"
url = "https://github.com/HKUNLP/DS-1000/blob/49c1c543ada8b58138181333cdc62e613204efcf/ds1000.py?raw=true"
lock = src.with_suffix(".lock")
with open(lock, "w") as f_lock:
fcntl.flock(f_lock, fcntl.LOCK_EX)
if not src.exists():
warnings.warn(f"DS-1000 source is being saved to {src}.")
print("Downloading source code...")
r = requests.get(url, stream=True)
with open(src, "wb") as f_src:
f_src.write(r.content)
open(src.parent / "__init__.py", "w").close()
print("Done.")
fcntl.flock(f_lock, fcntl.LOCK_UN)
def download_dataset(source_dir):
path = source_dir / "ds1000_data"
url = "https://github.com/HKUNLP/DS-1000/blob/49c1c543ada8b58138181333cdc62e613204efcf/ds1000_data.zip?raw=true"
lock = path.with_suffix(".lock")
with open(lock, "w") as f_lock:
fcntl.flock(f_lock, fcntl.LOCK_EX)
if not path.exists():
warnings.warn(f"DS-1000 data is being saved to {path}.")
print("Downloading dataset...")
r = requests.get(url, stream=True)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall(source_dir)
print("Done.")
fcntl.flock(f_lock, fcntl.LOCK_UN)
def get_dataset(source_dir, mode: str = "Completion", key: str = "All"):
"""Returns dataset for the task or an iterable of any object, that get_prompt can handle"""
from ds.ds1000 import DS1000Dataset
data = DS1000Dataset(source_dir / "ds1000_data", mode=mode).data
if key == "All":
if mode == "Insertion":
warnings.warn(
"Insertion not supported for Matplotlib. Only running others."
)
data = {k: v for k, v in data.items() if k != "Matplotlib"}
dataset = list(itertools.chain(*data.values()))
else:
dataset = data[key]
return dataset
# Collect queries, docs, and relations
def document2code(data: list):
queries, docs, qrels = [], [], []
# collect doc corpus
code_docs = load_dataset("neulab/docprompting-conala", "docs")["train"]
for i in range(len(code_docs)):
docs.append({
"_id": str(i),
"title": code_docs[i]["doc_id"],
"text": code_docs[i]["doc_content"],
"metadata": {}
})
# load canonical docs
ds1000 = load_dataset("json", data_files={"test": args.canonical_file})["test"]
for idx,item in enumerate(tqdm(data)):
example = item.data
query = example["prompt"]
query_id = f"{example['lib']}_{example['perturbation_origin_id']}"
queries.append({"_id": query_id, "text": query, "metadata": {}})
doc_ids = [doc["title"] for doc in ds1000[idx]["docs"]]
for doc_id in doc_ids:
corpus_id = code_docs["doc_id"].index(doc_id)
corpus_id = str(corpus_id)
qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})
return queries, docs, qrels
def main():
args.source_dir = pathlib.Path(__file__).parent.parent / args.source_dir
os.makedirs(args.source_dir, exist_ok=True)
download_source(args.source_dir)
download_dataset(args.source_dir)
dataset = get_dataset(args.source_dir, mode=args.mode, key=args.key)
path = os.path.join(args.output_dir, f"ds1000_{args.key.lower()}_{args.mode.lower()}")
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
queries, docs, qrels = document2code(dataset)
save_tsv_dict(qrels, os.path.join(path, "qrels", "test.tsv"), ["query-id", "corpus-id", "score"])
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--source_dir", type=str, default="ds")
parser.add_argument("--output_dir", type=str, default="datasets")
parser.add_argument("--mode", type=str, default="Completion", choices=["Completion", "Insertion"])
parser.add_argument("--key", type=str, default="All",
choices=["All", "Numpy", "Pandas", "Scipy", "Matplotlib", "Sklearn", "Tensorflow", "Pytorch"])
parser.add_argument("--canonical_file", type=str, default="datasets/canonical/ds1000_docs.json")
args = parser.parse_args()
main()

View File

@ -0,0 +1,70 @@
"""Aggregate all code-generation datasets."""
import os
import json
import datasets
import argparse
from create.utils import save_tsv_dict
from create.humaneval import document2code as d2c_humaneval
from create.mbpp import document2code as d2c_mbpp
D2C_FUNC_DICT = {
"humaneval": d2c_humaneval,
"mbpp": d2c_mbpp,
}
SPLIT_DICT = {
"humaneval": ["test"],
"mbpp": ["train", "test", "validation", "prompt"],
}
HF_NAME_DICT = {
"humaneval": "openai_humaneval",
"mbpp": "mbpp",
}
def save_file_jsonl(data, path):
with open(path,'w') as fw:
for item in data:
fw.write(json.dumps(item) + '\n')
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_names", type=str, nargs='+', default=["humaneval", "mbpp"])
parser.add_argument("--output_dir", type=str, default="datasets")
parser.add_argument("--output_name", type=str, default="general-programming")
args = parser.parse_args()
path = os.path.join(args.output_dir, args.output_name)
os.makedirs(path)
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
split_dict = {}
for dataset_name in args.dataset_names:
for split in SPLIT_DICT[dataset_name]:
if split not in split_dict:
split_dict[split] = []
split_dict[split].append(dataset_name)
dataset_dict = {
dataset_name: datasets.load_dataset(HF_NAME_DICT[dataset_name])
for dataset_name in args.dataset_names
}
docs, queries = [], []
for split, ds_names in split_dict.items():
for ds in ds_names:
dataset = dataset_dict[ds]
queries_split, docs_split, qrels_split = D2C_FUNC_DICT[ds](dataset, split)
docs += docs_split
queries += queries_split
qrels_path = os.path.join(path, "qrels", f"{split}.tsv")
save_tsv_dict(qrels_split, qrels_path, ["query-id", "corpus-id", "score"])
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,58 @@
import os
import argparse
import datasets
from tqdm import tqdm
from create.utils import save_tsv_dict, save_file_jsonl
def document2code(data, split="test"):
data = data[split]
queries, docs, qrels = [], [], []
for item in tqdm(data):
doc = item["prompt"]
code = item["prompt"] + '\n' + item["canonical_solution"]
doc_id = "{task_id}_doc".format_map(item)
code_id = "{task_id}_code".format_map(item)
queries.append({"_id": doc_id, "text": doc, "metadata": {}})
docs.append({"_id": code_id, "title": item["entry_point"], "text": code, "metadata": {}})
qrels.append({"query-id": doc_id, "corpus-id": code_id, "score": 1})
return queries, docs, qrels
def main():
dataset = datasets.load_dataset(args.dataset_name)
path = os.path.join(args.output_dir, args.output_name)
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
queries, docs, qrels = document2code(dataset, split="test")
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
qrels_path = os.path.join(path, "qrels", "test.tsv")
save_tsv_dict(qrels, qrels_path, ["query-id", "corpus-id", "score"])
# create canonical file if not existent yet
if not os.path.exists(args.canonical_file):
canonical_solutions = []
for doc in docs:
canonical_solutions.append([{
"text": doc["text"], "title": doc["title"]
}])
canonical_dataset = dataset["test"].add_column("docs", canonical_solutions)
canonical_dataset.to_json(args.canonical_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="openai_humaneval")
parser.add_argument("--output_name", type=str, default="humaneval")
parser.add_argument("--canonical_file", type=str,
default="datasets/canonical/humaneval_solutions.json")
parser.add_argument("--output_dir", type=str, default="datasets")
args = parser.parse_args()
main()

View File

@ -0,0 +1,53 @@
import os
import argparse
import datasets
from tqdm import tqdm
from datasets import load_dataset
from create.utils import save_tsv_dict, save_file_jsonl
def get_queries(data, split="test") -> list[dict]:
queries = [{
"_id": item["question_id"] + '__' + item["contest_id"],
"text": item["question_content"],
"metadata": {}
} for item in data[split]]
return queries
def get_corpus(hf_name: str, cache_dir: str) -> list[dict]:
dataset = load_dataset(hf_name, cache_dir=cache_dir)["train"]
corpus = [
{"_id": i, "text": item["text"], "title": item["title"]}
for i,item in enumerate(dataset)
]
return corpus
def main():
dataset = datasets.load_dataset(args.dataset_name, cache_dir=args.cache_dir)
path = os.path.join(args.output_dir, args.output_name)
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
queries = get_queries(dataset, split="test")
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
docs = get_corpus(args.corpus_name, args.cache_dir)
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
qrels = [] # no ground-truth solutions
qrels_path = os.path.join(path, "qrels", "test.tsv")
save_tsv_dict(qrels, qrels_path, ["query-id", "corpus-id", "score"])
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="livecodebench/code_generation")
parser.add_argument("--corpus_name", type=str, default="code-rag-bench/programming-solutions")
parser.add_argument("--cache_dir", type=str, default="/scratch/zhiruow/data")
parser.add_argument("--output_name", type=str, default="livecodebench")
parser.add_argument("--output_dir", type=str, default="datasets")
args = parser.parse_args()
main()

View File

@ -0,0 +1,78 @@
import os
import argparse
import datasets
from tqdm import tqdm
from create.utils import save_tsv_dict, save_file_jsonl
def get_function_name(code: str) -> str:
"""Parse the function name for a code snippet string."""
lines = code.split('\n')
for line in lines:
if line.lstrip().startswith("def "):
break
func_name = line.lstrip()[4: ]
func_name = func_name.split('(')[0]
return func_name
def document2code(data, split="test"):
data = data[split]
queries, docs, qrels = [], [], []
for item in tqdm(data):
doc = item["text"]
code = "# " + item["text"] + '\n' + item["code"]
doc_id = "{task_id}_doc".format_map(item)
code_id = "{task_id}_code".format_map(item)
queries.append({"_id": doc_id, "text": doc, "metadata": {}})
docs.append({"_id": code_id, "title": get_function_name(item["code"]), "text": code, "metadata": {}})
qrels.append({"query-id": doc_id, "corpus-id": code_id, "score": 1})
return queries, docs, qrels
def main():
dataset = datasets.load_dataset(args.dataset_name)
path = os.path.join(args.output_dir, args.output_name)
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
docs, queries = [], []
for split in args.splits:
queries_split, docs_split, qrels_split = document2code(dataset, split)
docs += docs_split
queries += queries_split
qrels_path = os.path.join(path, "qrels", f"{split}.tsv")
save_tsv_dict(qrels_split, qrels_path, ["query-id", "corpus-id", "score"])
# create canonical file for test split if not existent yet
if split == "test" and (not os.path.exists(args.canonical_file)):
canonical_solutions = []
for doc in docs_split:
canonical_solutions.append([{
"text": doc["text"], "title": doc["title"]
}])
canonical_dataset = dataset["test"].add_column("docs", canonical_solutions)
canonical_dataset.to_json(args.canonical_file)
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="google-research-datasets/mbpp")
# parser.add_argument("--dataset_name", type=str, default="code-rag-bench/mbpp")
parser.add_argument("--splits", type=str, default=["train", "validation", "test"],
choices=["train", "validation", "test", "prompt"])
parser.add_argument("--output_name", type=str, default="mbpp")
parser.add_argument("--output_dir", type=str, default="datasets")
parser.add_argument("--canonical_file", type=str,
default="datasets/canonical/mbpp_solutions.json")
args = parser.parse_args()
main()

View File

@ -0,0 +1,76 @@
import os
import re
import random
import argparse
import datasets
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset
from create.utils import save_tsv_dict, save_file_jsonl
def document2code(data, split="test"):
data = data[split]
queries, docs, qrels = [], [], []
# build doc corpus
code_docs = load_dataset("neulab/docprompting-conala", "docs")["train"]
for i in range(len(code_docs)):
docs.append({
"_id": str(i),
"title": code_docs[i]["doc_id"],
"text": code_docs[i]["doc_content"],
"metadata": {}
})
# load canonical docs
odex = load_dataset("json", data_files={"test": args.canonical_file})["test"]
# collect queries and query-doc matching
for idx,item in enumerate(tqdm(data)):
query = item["intent"]
query_id = f"{idx}_{item['task_id']}"
queries.append({"_id": query_id, "text": query, "metadata": {}})
doc_ids = [doc["title"] for doc in odex[idx]["docs"]]
for doc_id in doc_ids:
corpus_id = code_docs["doc_id"].index(doc_id)
corpus_id = str(corpus_id)
qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})
return queries, docs, qrels
def main():
if '_' in args.dataset_name:
dataset_name = args.dataset_name.split('_')[0]
language = args.dataset_name.split('_')[1]
else:
dataset_name = args.dataset_name
language = 'en'
dataset = datasets.load_dataset(dataset_name, language) # english version by default
path = os.path.join(args.output_dir, args.output_name.replace('en', language))
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
docs, queries = [], []
for split in ["test"]:
queries_split, docs_split, qrels_split = document2code(dataset, split)
docs += docs_split
queries += queries_split
save_tsv_dict(qrels_split, os.path.join(path, "qrels", "{}.tsv".format(split)), ["query-id", "corpus-id", "score"])
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="neulab/odex")
parser.add_argument("--output_name", type=str, default="odex_en")
parser.add_argument("--canonical_file", type=str, default="datasets/canonical/odex_docs.json")
parser.add_argument("--output_dir", type=str, default="datasets")
args = parser.parse_args()
main()

View File

@ -0,0 +1,306 @@
import io
import os
import glob
import json
import argparse
import requests
import zipfile
from collections import defaultdict
from create.utils import save_tsv_dict, save_file_jsonl
REPOs_line_and_api = [
'huggingface_diffusers',
'nerfstudio-project_nerfstudio',
'awslabs_fortuna',
'huggingface_evaluate',
'google_vizier',
'alibaba_FederatedScope',
'pytorch_rl',
'opendilab_ACE',
]
REPOs_function = [
"amazon-science_patchcore-inspection",
"deepmind_tracr",
"facebookresearch_omnivore",
"google_lightweight_mmm",
"lucidrains_imagen-pytorch",
"maxhumber_redframes",
]
REPO_DIRs = {
"api": "repositories/line_and_api_level",
"line": "repositories/line_and_api_level",
"function": "repositories/function_level",
}
def iterate_repository(base_dir: str, repo: str) -> dict:
pattern = os.path.join(f'{base_dir}/{repo}', "**", "*.py")
files = glob.glob(pattern, recursive=True)
skipped_files = []
loaded_code_files = dict()
base_dir_list = os.path.normpath(base_dir).split(os.sep)
for fname in files:
try:
code = open(fname, 'r', encoding='utf8').read()
fpath_tuple = tuple(os.path.normpath(fname).split(os.sep)[len(base_dir_list):])
loaded_code_files[fpath_tuple]= code
except Exception as e:
skipped_files.append((fname, e))
continue
if len(skipped_files) > 0:
print(f"Skipped {len(skipped_files)} out of {len(files)} files due to I/O errors")
for fname, e in skipped_files:
print(f"{fname}: {e}")
return loaded_code_files
def window_overlap(span: tuple, target_span: tuple) -> bool:
if span[0] >= target_span[1] or span[1] <= target_span[0]:
return False
return True
class RepoWindowMaker:
def __init__(self, base_dir, repo, tasks, window_size, slice_size):
self.base_dir = base_dir
self.repo = repo
self.window_size = window_size
self.slice_size = slice_size
self.slice_step = 1 if window_size // slice_size == 0 else window_size // slice_size
self.tasks = tasks
self.source_code_files = iterate_repository(base_dir, repo)
def _buid_windows_for_a_file(self, fpath_tuple, code):
code_windows = []
code_lines = code.splitlines()
delta_size = self.window_size // 2
for line_no in range(0, len(code_lines), self.slice_step): # line_no starts from 0
start_line_no = max(0, line_no - delta_size)
end_line_no = min(len(code_lines), line_no + self.window_size - delta_size)
window_lines = [i for i in code_lines[start_line_no:end_line_no]]
if not window_lines: # all empty lines
continue
window_text = '\n'.join(window_lines)
code_windows.append({
'context': window_text,
'metadata': {
'fpath_tuple': fpath_tuple,
'line_no': line_no,
'start_line_no': start_line_no,
'end_line_no': end_line_no,
'window_size': self.window_size,
'repo': self.repo,
'slice_size': self.slice_size,
}
})
return code_windows
def _merge_windows_with_same_context(self, code_windows):
merged_code_windows = defaultdict(list)
for code_window in code_windows:
context = code_window['context']
metadata = code_window['metadata']
merged_code_windows[context].append(metadata)
json_lines = []
for context, metadata_list in merged_code_windows.items():
json_lines.append({
'context': context,
'metadata': metadata_list
})
return json_lines
def build_windows(self):
all_code_windows = []
for fpath_tuple, code in self.source_code_files.items():
all_code_windows += self._buid_windows_for_a_file(fpath_tuple, code)
merged_code_windows = self._merge_windows_with_same_context(all_code_windows)
print(f'build {len(merged_code_windows)} windows for {self.repo} with window size {self.window_size} and slice {self.slice_size}')
ground_truth_indices = {}
for task in self.tasks:
fpath_tuple = tuple(task['metadata']['fpath_tuple'])
line_no = task['metadata']['line_no']
start_line_no = task['metadata']['context_start_lineno']
for i, window in enumerate(merged_code_windows):
if window["metadata"][0]["fpath_tuple"] != fpath_tuple:
continue
if any([
window_overlap(
(sub_window["start_line_no"], sub_window["end_line_no"]),
(start_line_no, line_no + 1)
)
for sub_window in window["metadata"]
]):
if i not in ground_truth_indices:
ground_truth_indices[i] = []
ground_truth_indices[i].append(task["metadata"]["task_id"])
return merged_code_windows, ground_truth_indices
def download_data(directory: str = "repoeval"):
os.makedirs(directory, exist_ok=True)
datasets_dir = os.path.join(directory, "datasets")
repos_lineapi_dir = os.path.join(directory, "repositories", "line_and_api_level")
repos_function_dir = os.path.join(directory, "repositories", "function_level")
print(f"Start downloading the necessary `datasets` and `repositories` files.")
if not os.path.exists(datasets_dir):
print(f"Start downloading the `datasets`.")
datasets_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/datasets/datasets.zip"
r = requests.get(datasets_url, stream=True)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall(datasets_dir)
print("Finished downloading the `datasets` files.")
if not os.path.exists(repos_lineapi_dir):
print(f"Start downloading the `repositories` (line_and_api).")
repos_lineapi_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/repositories/line_and_api_level.zip"
r = requests.get(repos_lineapi_url, stream=True)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall(repos_lineapi_dir)
if not os.path.exists(repos_function_dir):
print(f"Start downloading the `repositories` (function).")
# repos_function_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/repositories/function_level.zip"
repos_function_url = "https://github.com/Veronicium/repoeval_debug/raw/main/function_level.zip"
r = requests.get(repos_function_url, stream=True)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall(repos_function_dir)
print("Finished downloading the `repositories` files.")
def repo2code(
repo: str, data_cache_dir: str,
split: str, context_length: str,
window_size: int, slice_size: int
):
# load test examples
file_name = f"{split}_level_completion_{context_length}_context_codex.test.jsonl"
if split == 'function':
file_name = file_name.replace('.test.jsonl', '.test.clean.jsonl')
task_path = os.path.join(data_cache_dir, "datasets", file_name)
tasks = [json.loads(l.rstrip()) for l in open(task_path, 'r')]
tasks = [task for task in tasks if repo == task['metadata']['task_id'].split('/')[0]]
# collect queries
queries = []
for task in tasks:
query_id = task["metadata"]["task_id"]
# text = '\n'.join(task["prompt"].split('\n')[-2:])
text = task["prompt"]
metadata = task["metadata"]
queries.append({"_id": query_id, "text": text, "metadata": metadata})
base_dir = os.path.join(data_cache_dir, REPO_DIRs[split])
repo_window_maker = RepoWindowMaker(base_dir, repo, tasks, window_size, slice_size)
windows, ground_truth_indices = repo_window_maker.build_windows()
corpus, qrels = [], []
query_id2gt = {task['metadata']['task_id']:[] for task in tasks}
for i, window in enumerate(windows):
path = '-'.join(window["metadata"][0]["fpath_tuple"])
line = f"{window['metadata'][0]['start_line_no']}-{window['metadata'][-1]['end_line_no']}"
corpus_id = f"{repo}_{path}_{line}"
corpus.append({
"_id": corpus_id, "title": path,
"text": window["context"], "metadata": window["metadata"]
})
if i in ground_truth_indices:
for query_id in ground_truth_indices[i]:
qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})
query_id2gt[query_id].append({"title": corpus_id.replace('_', '/'), "text": window["context"]})
return queries, corpus, qrels, query_id2gt
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str, default="datasets")
parser.add_argument("--results_dir", type=str, default="results")
parser.add_argument("--split", type=str, required=True, choices=["api", "line", "function"])
parser.add_argument("--context_length", type=str, default="1k", choices=["1k", "2k", "4k"])
parser.add_argument("--data_cache_dir", type=str, default="output/repoeval")
parser.add_argument("--window_size", type=int, default=20)
parser.add_argument("--slice_size", type=int, default=2)
args = parser.parse_args()
download_data(args.data_cache_dir)
path = os.path.join(args.output_dir, "repoeval", args.split)
os.makedirs(path, exist_ok=True)
REPOs = REPOs_function if args.split == "function" else REPOs_line_and_api
file_name = f"{args.split}_level_completion_{args.context_length}_context_codex.test.jsonl"
data_path = os.path.join(args.data_cache_dir, "datasets", file_name)
data = [json.loads(l.rstrip()) for l in open(data_path, 'r')]
# preprocess function completion data (the data in the RepoCoder repo isn't correctly formatted)
if args.split == 'function':
repo2idx = {}
for task in data:
repo = task['metadata']['task_id'].replace('--', '_').split('/')[0]
if repo not in repo2idx:
repo2idx[repo] = 0
task['metadata']['task_id'] = task['metadata']['task_id'].replace('--', '_').replace('idx', str(repo2idx[repo]))
task['metadata']['line_no'] = task['metadata']['lineno']
repo2idx[repo] += 1
new_data_path = data_path.replace('.test.jsonl', '.test.clean.jsonl')
with open(new_data_path, 'w') as f:
for task in data:
repo = task['metadata']['task_id'].split('/')[0]
if repo not in REPOs:
continue
f.write(json.dumps(task) + '\n')
data = [json.loads(l.rstrip()) for l in open(new_data_path, 'r')]
# build query, docs, and qrels for each repository
queries, corpus, qrels = [], [], []
query_id2gt = {}
for repo in REPOs:
repo_queries, repo_corpus, repo_qrels, repo_query_id2gt = repo2code(
repo, args.data_cache_dir,
args.split, args.context_length,
args.window_size, args.slice_size
)
queries += repo_queries
corpus += repo_corpus
qrels += repo_qrels
query_id2gt.update(repo_query_id2gt)
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(corpus, os.path.join(path, "corpus.jsonl"))
save_tsv_dict(qrels, os.path.join(path, "qrels", "test.tsv"), ["query-id", "corpus-id", "score"])
gt_data = []
for example in data:
query_id = example['metadata']['task_id']
gt = query_id2gt[query_id]
new_example = {
"prompt": example["prompt"],
"reference": example["metadata"]["ground_truth"],
"docs": gt[:10],
"metadata": {k:v for k,v in example["metadata"].items() if k != "ground_truth"},
}
gt_data.append(new_example)
results_file = os.path.join(args.results_dir, f"repoeval-{args.split}-{args.context_length}-gt.jsonl")
with open(results_file, "w") as fw:
for ex in gt_data:
fw.write(json.dumps(ex) + "\n")
results_file = os.path.join(args.results_dir, f"repoeval-{args.split}-{args.context_length}-infile.jsonl")
with open(results_file, "w") as fw:
for ex in gt_data:
ex = {k:v for k,v in ex.items() if k != "docs"}
ex["docs"] = []
fw.write(json.dumps(ex) + "\n")
if __name__ == "__main__":
main()
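The sliding-window construction in `RepoWindowMaker` is easiest to see with the defaults: `--window_size 20 --slice_size 2` gives `slice_step = 10` and `delta_size = 10`, so a window is emitted every 10 lines and spans up to 20 lines around the anchor line. A standalone sketch of the same arithmetic:
```python
# Standalone sketch mirroring _buid_windows_for_a_file above; purely illustrative.
def window_bounds(num_lines: int, window_size: int = 20, slice_size: int = 2):
    slice_step = 1 if window_size // slice_size == 0 else window_size // slice_size
    delta = window_size // 2
    for line_no in range(0, num_lines, slice_step):
        start = max(0, line_no - delta)
        end = min(num_lines, line_no + window_size - delta)
        yield line_no, start, end

for line_no, start, end in window_bounds(35):
    print(f"anchor line {line_no}: window [{start}, {end})")
# anchor line 0:  window [0, 10)
# anchor line 10: window [0, 20)
# anchor line 20: window [10, 30)
# anchor line 30: window [20, 35)
```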

View File

@ -0,0 +1,306 @@
import io
import os
import glob
import json
import argparse
import requests
import zipfile
from collections import defaultdict
from create.utils import save_tsv_dict, save_file_jsonl
REPOs_line_and_api = [
'huggingface_diffusers',
'nerfstudio-project_nerfstudio',
'awslabs_fortuna',
'huggingface_evaluate',
'google_vizier',
'alibaba_FederatedScope',
'pytorch_rl',
'opendilab_ACE',
]
REPOs_function = [
"amazon-science_patchcore-inspection",
"deepmind_tracr",
"facebookresearch_omnivore",
"google_lightweight_mmm",
"lucidrains_imagen-pytorch",
"maxhumber_redframes",
]
REPO_DIRs = {
"api": "repositories/line_and_api_level",
"line": "repositories/line_and_api_level",
"function": "repositories/function_level",
}
def iterate_repository(base_dir: str, repo: str) -> dict:
pattern = os.path.join(f'{base_dir}/{repo}', "**", "*.py")
files = glob.glob(pattern, recursive=True)
skipped_files = []
loaded_code_files = dict()
base_dir_list = os.path.normpath(base_dir).split(os.sep)
for fname in files:
try:
code = open(fname, 'r', encoding='utf8').read()
fpath_tuple = tuple(os.path.normpath(fname).split(os.sep)[len(base_dir_list):])
loaded_code_files[fpath_tuple]= code
except Exception as e:
skipped_files.append((fname, e))
continue
if len(skipped_files) > 0:
print(f"Skipped {len(skipped_files)} out of {len(files)} files due to I/O errors")
for fname, e in skipped_files:
print(f"{fname}: {e}")
return loaded_code_files
def window_overlap(span: tuple, target_span: tuple) -> bool:
if span[0] >= target_span[1] or span[1] <= target_span[0]:
return False
return True
class RepoWindowMaker:
def __init__(self, base_dir, repo, tasks, window_size, slice_size):
self.base_dir = base_dir
self.repo = repo
self.window_size = window_size
self.slice_size = slice_size
self.slice_step = 1 if window_size // slice_size == 0 else window_size // slice_size
self.tasks = tasks
self.source_code_files = iterate_repository(base_dir, repo)
def _buid_windows_for_a_file(self, fpath_tuple, code):
code_windows = []
code_lines = code.splitlines()
delta_size = self.window_size // 2
for line_no in range(0, len(code_lines), self.slice_step): # line_no starts from 0
start_line_no = max(0, line_no - delta_size)
end_line_no = min(len(code_lines), line_no + self.window_size - delta_size)
window_lines = [i for i in code_lines[start_line_no:end_line_no]]
if not window_lines: # all empty lines
continue
window_text = '\n'.join(window_lines)
code_windows.append({
'context': window_text,
'metadata': {
'fpath_tuple': fpath_tuple,
'line_no': line_no,
'start_line_no': start_line_no,
'end_line_no': end_line_no,
'window_size': self.window_size,
'repo': self.repo,
'slice_size': self.slice_size,
}
})
return code_windows
def _merge_windows_with_same_context(self, code_windows):
merged_code_windows = defaultdict(list)
for code_window in code_windows:
context = code_window['context']
metadata = code_window['metadata']
merged_code_windows[context].append(metadata)
json_lines = []
for context, metadata_list in merged_code_windows.items():
json_lines.append({
'context': context,
'metadata': metadata_list
})
return json_lines
def build_windows(self):
all_code_windows = []
for fpath_tuple, code in self.source_code_files.items():
all_code_windows += self._buid_windows_for_a_file(fpath_tuple, code)
merged_code_windows = self._merge_windows_with_same_context(all_code_windows)
print(f'build {len(merged_code_windows)} windows for {self.repo} with window size {self.window_size} and slice {self.slice_size}')
ground_truth_indices = {}
for task in self.tasks:
fpath_tuple = tuple(task['metadata']['fpath_tuple'])
line_no = task['metadata']['line_no']
start_line_no = task['metadata']['context_start_lineno']
for i, window in enumerate(merged_code_windows):
# print(window["metadata"][0]["fpath_tuple"], fpath_tuple)
if window["metadata"][0]["fpath_tuple"] != fpath_tuple and ' '.join(list(window["metadata"][0]["fpath_tuple"])) != ' '.join(list(fpath_tuple)):
continue
# print(1)
if any([
window_overlap(
(sub_window["start_line_no"], sub_window["end_line_no"]),
(start_line_no, line_no + 1)
)
for sub_window in window["metadata"]
]):
# print('test')
if i not in ground_truth_indices:
ground_truth_indices[i] = []
ground_truth_indices[i].append(task["metadata"]["task_id"])
# sys.exit()
return merged_code_windows, ground_truth_indices
def download_data(directory: str = "repoeval"):
os.makedirs(directory, exist_ok=True)
datasets_dir = os.path.join(directory, "datasets")
repos_function_dir = os.path.join(directory, "repositories", "function_level")
print(f"Start downloading the necessary `datasets` and `repositories` files.")
if not os.path.exists(datasets_dir):
print(f"Start downloading the `datasets`.")
datasets_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/datasets/datasets.zip"
r = requests.get(datasets_url, stream=True)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall(datasets_dir)
print("Finished downloading the `datasets` files.")
import shutil
shutil.rmtree(repos_function_dir, ignore_errors=True)  # tolerate a missing directory on first run
if not os.path.exists(repos_function_dir):
print(f"Start downloading the `repositories` (function).")
repos_function_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/repositories/function_level.zip"
# repos_function_url = "https://github.com/Veronicium/repoeval_debug/raw/main/function_level.zip"
r = requests.get(repos_function_url, stream=True)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall(repos_function_dir)
print("Finished downloading the `repositories` files.")
def repo2code(
repo: str, tasks: list[dict], data_cache_dir: str,
split: str, context_length: str,
window_size: int, slice_size: int
):
# collect queries
queries = []
for task in tasks:
query_id = task["metadata"]["task_id"]
# text = '\n'.join(task["prompt"].split('\n')[-2:])
text = task["prompt"]
metadata = task["metadata"]
queries.append({"_id": query_id, "text": text, "metadata": metadata})
base_dir = os.path.join(data_cache_dir, REPO_DIRs[split])
repo_window_maker = RepoWindowMaker(base_dir, repo, tasks, window_size, slice_size)
windows, ground_truth_indices = repo_window_maker.build_windows()
corpus, qrels = [], []
query_id2gt = {task['metadata']['task_id']:[] for task in tasks}
for i, window in enumerate(windows):
path = '-'.join(window["metadata"][0]["fpath_tuple"])
line = f"{window['metadata'][0]['start_line_no']}-{window['metadata'][-1]['end_line_no']}"
corpus_id = f"{repo}_{path}_{line}"
corpus.append({
"_id": corpus_id, "title": path,
"text": window["context"], "metadata": window["metadata"]
})
# print(windows, ground_truth_indices)
if i in ground_truth_indices:
for query_id in ground_truth_indices[i]:
qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})
query_id2gt[query_id].append({"title": corpus_id.replace('_', '/'), "text": window["context"]})
return queries, corpus, qrels, query_id2gt
def main():
download_data(args.data_cache_dir)
REPOs = REPOs_function if args.split == "function" else REPOs_line_and_api
file_name = f"{args.split}_level_completion_{args.context_length}_context_codex.test.jsonl"
data_path = os.path.join(args.data_cache_dir, "datasets", file_name)
data = [json.loads(l.rstrip()) for l in open(data_path, 'r')]
# preprocess function completion data (the data in the RepoCoder repo isn't correctly formatted)
if args.split == 'function':
repo2idx = {}
for task in data:
repo = task['metadata']['task_id'].replace('--', '_').split('/')[0]
if repo not in repo2idx:
repo2idx[repo] = 0
task['metadata']['task_id'] = task['metadata']['task_id'].replace('--', '_').replace('idx', str(repo2idx[repo]))
task['metadata']['line_no'] = task['metadata']['lineno']
repo2idx[repo] += 1
new_data_path = data_path.replace('.test.jsonl', '.test.clean.jsonl')
with open(new_data_path, 'w') as f:
for task in data:
repo = task['metadata']['task_id'].split('/')[0]
if repo not in REPOs:
continue
f.write(json.dumps(task) + '\n')
data = [json.loads(l.rstrip()) for l in open(new_data_path, 'r')]
# group data instances by repository
data_dict = {}
for ex in data:
repo_name = ex["metadata"]["task_id"]
repo_name = repo_name.split('/')[0]
if repo_name not in data_dict:
data_dict[repo_name] = []
data_dict[repo_name].append(ex)
# build query, docs, and qrels for each repository
for repo in REPOs:
queries, corpus, qrels, query_id2gt = repo2code(
repo, data_dict[repo], args.data_cache_dir,
args.split, args.context_length,
args.window_size, args.slice_size
)
print(len(queries))
if len(qrels) == 0:
print(repo)
# sys.exit()
continue
# sys.exit()
path = os.path.join(args.output_dir, f"repoeval__{repo}")
os.makedirs(path, exist_ok=True)
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(corpus, os.path.join(path, "corpus.jsonl"))
save_tsv_dict(qrels, os.path.join(path, "qrels", "test.tsv"), ["query-id", "corpus-id", "score"])
gt_data = []
for example in data_dict[repo]:
query_id = example['metadata']['task_id']
gt = query_id2gt[query_id]
new_example = {
"prompt": example["prompt"],
"reference": example["metadata"]["ground_truth"],
"docs": gt[:10],
"metadata": {k:v for k,v in example["metadata"].items() if k != "ground_truth"},
}
gt_data.append(new_example)
os.makedirs(args.results_dir, exist_ok=True)
results_file = os.path.join(args.results_dir, f"repoeval-{args.split}-{repo}-{args.context_length}-gt.jsonl")
with open(results_file, "w") as fw:
for ex in gt_data:
fw.write(json.dumps(ex) + "\n")
results_file = os.path.join(args.results_dir, f"repoeval-{args.split}-{repo}-{args.context_length}-infile.jsonl")
with open(results_file, "w") as fw:
for ex in gt_data:
ex = {k:v for k,v in ex.items() if k != "docs"}
ex["docs"] = []
fw.write(json.dumps(ex) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str, default="datasets")
parser.add_argument("--results_dir", type=str, default="results")
parser.add_argument("--split", type=str, default="function", choices=["function"])
parser.add_argument("--context_length", type=str, default="2k", choices=["1k", "2k", "4k"])
parser.add_argument("--data_cache_dir", type=str, default="output/repoeval")
parser.add_argument("--window_size", type=int, default=50)
parser.add_argument("--slice_size", type=int, default=5)
args = parser.parse_args()
main()

View File

@ -0,0 +1,247 @@
import os
import re
import chardet
import unidiff
import argparse
import datasets
import traceback
import subprocess
from git import Repo
from tqdm import tqdm
from pathlib import Path
from tempfile import TemporaryDirectory
from create.utils import save_tsv_dict, save_file_jsonl
# %% Get oracle file contents
# get oracle file contents from the repo
class ContextManager:
def __init__(self, repo_path, base_commit, verbose=False):
self.repo_path = Path(repo_path).resolve().as_posix()
self.old_dir = os.getcwd()
self.base_commit = base_commit
self.verbose = verbose
def __enter__(self):
os.chdir(self.repo_path)
cmd = f"git reset --hard {self.base_commit} && git clean -fdxq"
if self.verbose:
subprocess.run(cmd, shell=True, check=True)
else:
subprocess.run(
cmd,
shell=True,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return self
def get_environment(self):
raise NotImplementedError() # TODO: activate conda environment and return the environment file
def get_readme_files(self):
files = os.listdir(self.repo_path)
files = list(filter(lambda x: os.path.isfile(x), files))
files = list(filter(lambda x: x.lower().startswith("readme"), files))
return files
def __exit__(self, exc_type, exc_val, exc_tb):
os.chdir(self.old_dir)
class AutoContextManager(ContextManager):
"""Automatically clones the repo if it doesn't exist"""
def __init__(self, instance, root_dir=None, verbose=False, token=None):
if token is None:
token = os.environ.get("GITHUB_TOKEN", "git")
self.tempdir = None
if root_dir is None:
self.tempdir = TemporaryDirectory()
root_dir = self.tempdir.name
self.root_dir = root_dir
repo_dir = os.path.join(self.root_dir, instance["repo"].replace("/", "__"))
if not os.path.exists(repo_dir):
repo_url = (
f"https://{token}@github.com/swe-bench/"
+ instance["repo"].replace("/", "__")
+ ".git"
)
if verbose:
print(f"Cloning {instance['repo']} to {root_dir}")
Repo.clone_from(repo_url, repo_dir)
super().__init__(repo_dir, instance["base_commit"], verbose=verbose)
self.instance = instance
def __exit__(self, exc_type, exc_val, exc_tb):
if self.tempdir is not None:
self.tempdir.cleanup()
return super().__exit__(exc_type, exc_val, exc_tb)
def ingest_files(filenames):
files_dict = dict()
for filename in filenames:
with open(filename) as f:
content = f.read()
files_dict[filename] = content
return files_dict
def get_oracle_filenames(instance):
"""
Returns the filenames that are changed in the patch
"""
source_files = {
patch_file.source_file.split("a/", 1)[-1]
for patch_file in unidiff.PatchSet(instance["patch"])
}
gold_docs = set()
for source_file in source_files:
gold_docs.add(source_file)
return gold_docs
# get all file contents from the repo
def is_test(name, test_phrases=None):
if test_phrases is None:
test_phrases = ["test", "tests", "testing"]
words = set(re.split(r" |_|\/|\.", name.lower()))
return any(word in words for word in test_phrases)
def list_files(root_dir, include_tests=False):
files = []
for filename in Path(root_dir).rglob("*.py"):
if not include_tests and is_test(filename.as_posix()):
continue
files.append(filename.relative_to(root_dir).as_posix())
return files
def detect_encoding(filename):
"""
Detect the encoding of a file
"""
with open(filename, "rb") as file:
rawdata = file.read()
return chardet.detect(rawdata)["encoding"]
def ingest_directory_contents(root_dir, include_tests=False):
files_content = {}
for relative_path in list_files(root_dir, include_tests=include_tests):
filename = os.path.join(root_dir, relative_path)
encoding = detect_encoding(filename)
if encoding is None:
content = "[BINARY DATA FILE]"
else:
try:
with open(filename, encoding=encoding) as file:
content = file.read()
except (UnicodeDecodeError, LookupError):
content = "[BINARY DATA FILE]"
files_content[relative_path] = content
return files_content
def get_file_contents(input_instances, verbose: bool = False, tmp_dir: str = "/scratch"):
orig_dir = os.getcwd()
with TemporaryDirectory(dir=tmp_dir if os.path.exists(tmp_dir) else "/tmp") as root_dir:
for instance_id, instance in tqdm(
input_instances.items(),
total=len(input_instances),
desc="Getting file contents",
):
try:
with AutoContextManager(instance, root_dir, verbose=verbose) as cm:
readmes = cm.get_readme_files()
instance["readmes"] = ingest_files(readmes)
instance["oracle_file_contents"] = ingest_files(get_oracle_filenames(instance))
instance["file_contents"] = ingest_directory_contents(cm.repo_path)
assert all([
okey in instance["file_contents"]
for okey in instance["oracle_file_contents"].keys()
])
except Exception as e:
print(f"Failed on instance {instance_id}", e)
traceback.print_exc()
finally:
# if AutoContextManager fails to exit properly future exits will return the wrong directory
os.chdir(orig_dir)
os.chdir(orig_dir)
# %% Get queries, docs, and qrels
def document2code(data, split: str = "test"):
subset = data[split]
if args.num_examples is not None:
import random
indices = random.sample([i for i in range(len(subset))], args.num_examples)
subset = subset.select(indices)
print(subset)
# get queries for each example
queries = [
{
"_id": item["instance_id"],
"text": item["problem_statement"],
"metadata": {}
}
for item in subset
]
subset_dict = {x["instance_id"]: x for x in subset}
get_file_contents(subset_dict, tmp_dir=args.tmp_dir)
# collect all docs, i.e., code chunks from the repo
docs = []
for instance_id, instance in subset_dict.items():
print(f"Instance #{instance_id}: {len(instance['oracle_file_contents'])} oracle / {len(instance['file_contents'])} files")
for filename, content in instance["file_contents"].items():
docs.append({
"_id": f"{instance_id}_{filename}",
"title": filename,
"text": content,
"metadata": {},
})
# find ground-truth docs for each example
qrels = []
for instance_id, instance in subset_dict.items():
for filename, content in instance["oracle_file_contents"].items():
qrels.append({
"query-id": instance_id,
"corpus-id": f"{instance_id}_{filename}",
"score": 1
})
return queries, docs, qrels
def main():
dataset = datasets.load_dataset(args.dataset_name, cache_dir=args.cache_dir)
name = "swe-bench"
if "lite" in args.dataset_name.lower():
name += "-lite"
path = os.path.join(args.output_dir, name)
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
queries, docs, qrels = document2code(dataset, split="test")
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
qrels_path = os.path.join(path, "qrels", "test.tsv")
save_tsv_dict(qrels, qrels_path, ["query-id", "corpus-id", "score"])
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="princeton-nlp/SWE-bench_Lite",
choices=["princeton-nlp/SWE-bench", "princeton-nlp/SWE-bench_Lite"])
parser.add_argument("--cache_dir", type=str, default="/scratch/zhiruow/data")
parser.add_argument("--tmp_dir", type=str, default="/scratch/zhiruow/tmp")
parser.add_argument("--output_dir", type=str, default="datasets")
parser.add_argument("--num_examples", type=int, default=None)
args = parser.parse_args()
main()

View File

@ -0,0 +1,263 @@
import os
import re
import chardet
import unidiff
import argparse
import datasets
import traceback
import subprocess
from git import Repo
from tqdm import tqdm
from pathlib import Path
from tempfile import TemporaryDirectory
from create.utils import save_tsv_dict, save_file_jsonl
# %% Get oracle file contents
# get oracle file contents from the repo
class ContextManager:
def __init__(self, repo_path, base_commit, verbose=False):
self.repo_path = Path(repo_path).resolve().as_posix()
self.old_dir = os.getcwd()
self.base_commit = base_commit
self.verbose = verbose
def __enter__(self):
os.chdir(self.repo_path)
cmd = f"git reset --hard {self.base_commit} && git clean -fdxq"
if self.verbose:
subprocess.run(cmd, shell=True, check=True)
else:
subprocess.run(
cmd,
shell=True,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return self
def get_environment(self):
raise NotImplementedError() # TODO: activate conda environment and return the environment file
def get_readme_files(self):
files = os.listdir(self.repo_path)
files = list(filter(lambda x: os.path.isfile(x), files))
files = list(filter(lambda x: x.lower().startswith("readme"), files))
return files
def __exit__(self, exc_type, exc_val, exc_tb):
os.chdir(self.old_dir)
class AutoContextManager(ContextManager):
"""Automatically clones the repo if it doesn't exist"""
def __init__(self, instance, root_dir=None, verbose=False, token=None):
if token is None:
token = os.environ.get("GITHUB_TOKEN", "git")
self.tempdir = None
if root_dir is None:
self.tempdir = TemporaryDirectory()
root_dir = self.tempdir.name
self.root_dir = root_dir
repo_dir = os.path.join(self.root_dir, instance["repo"].replace("/", "__"))
if not os.path.exists(repo_dir):
repo_url = (
f"https://{token}@github.com/swe-bench/"
+ instance["repo"].replace("/", "__")
+ ".git"
)
if verbose:
print(f"Cloning {instance['repo']} to {root_dir}")
Repo.clone_from(repo_url, repo_dir)
super().__init__(repo_dir, instance["base_commit"], verbose=verbose)
self.instance = instance
def __exit__(self, exc_type, exc_val, exc_tb):
if self.tempdir is not None:
self.tempdir.cleanup()
return super().__exit__(exc_type, exc_val, exc_tb)
def ingest_files(filenames):
files_dict = dict()
for filename in filenames:
with open(filename) as f:
content = f.read()
files_dict[filename] = content
return files_dict
def get_oracle_filenames(instance):
"""
Returns the filenames that are changed in the patch
"""
source_files = {
patch_file.source_file.split("a/", 1)[-1]
for patch_file in unidiff.PatchSet(instance["patch"])
}
gold_docs = set()
for source_file in source_files:
gold_docs.add(source_file)
return gold_docs
# get all file contents from the repo
def is_test(name, test_phrases=None):
if test_phrases is None:
test_phrases = ["test", "tests", "testing"]
words = set(re.split(r" |_|\/|\.", name.lower()))
return any(word in words for word in test_phrases)
def list_files(root_dir, include_tests=False):
files = []
for filename in Path(root_dir).rglob("*.py"):
if not include_tests and is_test(filename.as_posix()):
continue
files.append(filename.relative_to(root_dir).as_posix())
return files
def detect_encoding(filename):
"""
Detect the encoding of a file
"""
with open(filename, "rb") as file:
rawdata = file.read()
return chardet.detect(rawdata)["encoding"]
def ingest_directory_contents(root_dir, include_tests=False):
files_content = {}
for relative_path in list_files(root_dir, include_tests=include_tests):
filename = os.path.join(root_dir, relative_path)
encoding = detect_encoding(filename)
if encoding is None:
content = "[BINARY DATA FILE]"
else:
try:
with open(filename, encoding=encoding) as file:
content = file.read()
except (UnicodeDecodeError, LookupError):
content = "[BINARY DATA FILE]"
files_content[relative_path] = content
return files_content
def get_file_contents(input_instances, verbose: bool = False, tmp_dir: str = "/scratch"):
orig_dir = os.getcwd()
with TemporaryDirectory(dir=tmp_dir if os.path.exists(tmp_dir) else "/tmp") as root_dir:
for instance_id, instance in tqdm(
input_instances.items(),
total=len(input_instances),
desc="Getting file contents",
):
try:
with AutoContextManager(instance, root_dir, verbose=verbose) as cm:
readmes = cm.get_readme_files()
instance["readmes"] = ingest_files(readmes)
instance["oracle_file_contents"] = ingest_files(get_oracle_filenames(instance))
instance["file_contents"] = ingest_directory_contents(cm.repo_path)
assert all([
okey in instance["file_contents"]
for okey in instance["oracle_file_contents"].keys()
])
except Exception as e:
print(f"Failed on instance {instance_id}", e)
traceback.print_exc()
finally:
# if AutoContextManager fails to exit properly future exits will return the wrong directory
os.chdir(orig_dir)
os.chdir(orig_dir)
import multiprocessing as mp
from functools import partial
def process_single_item(item, args):
"""处理单个数据项的函数"""
name = "swe-bench"
if "lite" in args.dataset_name.lower():
name += "-lite"
queries = [{
"_id": item["instance_id"],
"text": item["problem_statement"],
"metadata": {}
}]
item_dict = {item["instance_id"]: item}
output_path = os.path.join(args.output_dir, f"{name}_{item['instance_id']}", "qrels", "test.tsv")
if os.path.exists(output_path):
return
try:
get_file_contents(item_dict, tmp_dir=args.tmp_dir)
docs = []
for instance_id, instance in item_dict.items():
print(f"Instance #{instance_id}: {len(instance['oracle_file_contents'])} oracle / {len(instance['file_contents'])} files")
for filename, content in instance["file_contents"].items():
docs.append({
"_id": f"{instance_id}_{filename}",
"title": filename,
"text": content,
"metadata": {},
})
qrels = []
for instance_id, instance in item_dict.items():
for filename, content in instance["oracle_file_contents"].items():
qrels.append({
"query-id": instance_id,
"corpus-id": f"{instance_id}_{filename}",
"score": 1
})
path = os.path.join(args.output_dir, f"{name}_{instance_id}")
os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
qrels_path = os.path.join(path, "qrels", "test.tsv")
save_tsv_dict(qrels, qrels_path, ["query-id", "corpus-id", "score"])
except Exception as e:
print(f"Error processing item {item['instance_id']}: {str(e)}")
def main():
dataset = datasets.load_dataset(args.dataset_name, cache_dir=args.cache_dir)["test"]
if args.num_examples is not None:
import random
indices = random.sample(range(len(dataset)), args.num_examples)
dataset = dataset.select(indices)
print(dataset)
# Create the process pool
num_processes = mp.cpu_count() - 1  # leave one CPU core free
pool = mp.Pool(processes=num_processes)
# Bind the parsed args to the worker with functools.partial
process_func = partial(process_single_item, args=args)
# Process the dataset in parallel with the pool
list(tqdm(
pool.imap_unordered(process_func, dataset),
total=len(dataset),
desc="Processing items"
))
# Close the process pool
pool.close()
pool.join()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="princeton-nlp/SWE-bench_Lite",
choices=["princeton-nlp/SWE-bench", "princeton-nlp/SWE-bench_Lite"])
parser.add_argument("--cache_dir", type=str, default="/scratch/zhiruow/data")
parser.add_argument("--tmp_dir", type=str, default="/scratch/zhiruow/tmp")
parser.add_argument("--output_dir", type=str, default="datasets")
parser.add_argument("--num_examples", type=int, default=None)
args = parser.parse_args()
main()

View File

@ -0,0 +1,37 @@
import jsonlines
import csv
import os
from tqdm import tqdm
def load_jsonlines(file):
with jsonlines.open(file, 'r') as jsonl_f:
lst = [obj for obj in jsonl_f]
return lst
def save_file_jsonl(data, fp):
with jsonlines.open(fp, mode='w') as writer:
writer.write_all(data)
def save_tsv_dict(data, fp, fields):
# build dir
dir_path = os.path.dirname(fp)
os.makedirs(dir_path, exist_ok=True)
# writing to csv file
with open(fp, 'w') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fields, delimiter='\t',)
writer.writeheader()
writer.writerows(data)
def cost_estimate(path):
corpus = load_jsonlines(os.path.join(path, "corpus.jsonl"))
queries = load_jsonlines(os.path.join(path, "queries.jsonl"))
num_corpus_words = 0
num_queries_words = 0
for item in tqdm(corpus):
num_corpus_words += len(item["text"].split(" "))
for item in tqdm(queries):
num_queries_words += len(item["text"].split(" "))
print(f"corpus size: {len(corpus)}")
print(f"queries size: {len(queries)}")
print(f"corpus words: {num_corpus_words}")
print(f"query words: {num_queries_words}")

View File

@ -0,0 +1,313 @@
import os
import json
import random
import logging
import pathlib
import argparse
import numpy as np
from time import time
from datasets import load_dataset
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from tqdm import tqdm
from transformers import HfArgumentParser
from arguments import CodeRAGEvalArgs, CodeRAGEvalModelArgs
from prompts import get_task_def_by_task_name
from FlagEmbedding import FlagLLMModel, FlagModel
def get_model(model_args: CodeRAGEvalModelArgs):
embedder_name_or_path = model_args.embedder_name_or_path
if model_args.embedder_model_class == "encoder-only-base":
embedder = FlagModel(
model_name_or_path=embedder_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
pooling_method=model_args.pooling_method,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir,
batch_size=model_args.embedder_batch_size,
query_max_length=model_args.embedder_query_max_length,
passage_max_length=model_args.embedder_passage_max_length,
)
elif model_args.embedder_model_class == "decoder-only-base":
embedder = FlagLLMModel(
model_name_or_path=embedder_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
pooling_method=model_args.pooling_method,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir,
batch_size=model_args.embedder_batch_size,
query_max_length=model_args.embedder_query_max_length,
passage_max_length=model_args.embedder_passage_max_length,
)
else:
raise ValueError(f"Invalid model class: {model_args.embedder_model_class}")
embedder.model.config._name_or_path = model_args.embedder_name_or_path
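# Thin adapter so BEIR-style retrievers can call the FlagEmbedding encoders: the
# show_progress_bar / convert_to_tensor kwargs they pass are accepted and ignored,
# and dict inputs are flattened to "title text" strings before encoding.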
class CustomFlagModel:
def __init__(self, model):
self.model = model
def encode_queries(self, queries, show_progress_bar, convert_to_tensor, **kwargs):
if isinstance(queries, str):
queries = [queries]
if isinstance(queries[0], dict):
queries = [(e.get('title') + ' ' + e['text']).strip() for e in queries]
return self.model.encode_queries(queries, **kwargs)
def encode_corpus(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
if isinstance(corpus, str):
corpus = [corpus]
if isinstance(corpus[0], dict):
corpus = [(e.get('title') + ' ' + e['text']).strip() for e in corpus]
return self.model.encode_corpus(corpus, **kwargs)
def encode(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
if isinstance(corpus, str):
corpus = [corpus]
if isinstance(corpus[0], dict):
corpus = [(e.get('title') + ' ' + e['text']).strip() for e in corpus]
return self.model.encode(corpus, **kwargs)
return CustomFlagModel(embedder)
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
def get_top_docs(results: dict, corpus: dict, task_id: str, topk: int = 10) -> list[str]:
if task_id not in results: return []
doc_scores = results[task_id]
doc_scores_sorted = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
doc_scores_sorted = doc_scores_sorted[:topk]
doc_code_snippets = [corpus[code_id] for code_id, score in doc_scores_sorted]
return doc_code_snippets
def main(
eval_args: CodeRAGEvalArgs,
model_args: CodeRAGEvalModelArgs
):
args = eval_args
embedder = get_model(model_args)
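# Exact (brute-force) dense retrieval; the oversized corpus_chunk_size keeps the whole corpus in a
# single encoding chunk, and dot-product scoring matches cosine similarity when embeddings are normalized.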
model = DRES(
embedder,
batch_size=args.batch_size,
corpus_chunk_size=512 * 9999
)
retriever = EvaluateRetrieval(model, score_function="dot")
if args.dataset.startswith("swe-bench") or args.dataset.startswith("repoeval"):
all_eval_results = []
if args.dataset.startswith("swe-bench"):
swebench = load_dataset("princeton-nlp/SWE-bench_Lite")["test"]
all_top_docs = [[] for _ in swebench]
instance_list = [i for i in os.listdir("datasets") if i.startswith(f"{args.dataset}_")]
instance_list_filtered = []
for ins_dir in tqdm(instance_list):
logging.info("Instance Repo: {}".format(ins_dir))
# load data and perform retrieval
corpus, queries, qrels = GenericDataLoader(
data_folder=os.path.join("datasets", ins_dir)
).load(split="test")
logging.info(f"Instance #{ins_dir}: #{len(corpus)} corpus, #{len(queries)} queries")
start_time = time()
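# Pad single-query instances with a placeholder "dummy" query (removed right after retrieval),
# presumably to avoid single-item edge cases in the retriever.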
if len(queries) == 1:
queries.update({"dummy": "dummy"})
results = retriever.retrieve(corpus, queries)
if "dummy" in queries:
queries.pop("dummy")
results.pop("dummy")
end_time = time()
logging.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
# get topk retrieved docs
if args.dataset.startswith("swe-bench"):
indices = [i for i, ex in enumerate(swebench) if ex["instance_id"] in queries]
for index in indices:
instance_id = swebench[index]["instance_id"]
all_top_docs[index] = get_top_docs(results, corpus, instance_id)
elif args.dataset.startswith("repoeval"):
args.dataset_path = "output/repoeval/datasets/function_level_completion_2k_context_codex.test.clean.jsonl"
tasks = [json.loads(line.strip()) for line in open(args.dataset_path, 'r')]
prompts, references, docs, metadatas = [], [], [], []
for task in tasks:
if task["metadata"]["task_id"] not in queries: continue
prompts.append(task["prompt"]) # save full prompt
references.append(task["metadata"]["ground_truth"])
docs.append(get_top_docs(
results=results, corpus=corpus, task_id=task["metadata"]["task_id"],
))
metadatas.append(task["metadata"])
assert len(prompts) == len(references) == len(docs)
dataset = [
{"prompt": p, "reference": r, "docs": d, "metadata": m}
for p, r, d, m in zip(prompts, references, docs, metadatas)
]
with open(args.results_file, "a") as fout:
for curr in dataset:
fout.write(json.dumps(curr) + "\n")
else:
raise ValueError(f"`dataset` should starts with either 'swe-bench' or 'repoeval'.")
# evaluate retrieval results
if len(qrels) == 0:
logging.info("No qrels found for this dataset.")
return
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
eval_results = {
"ndcg": ndcg, "mrr": mrr,
"recall": recall, "precision": precision,
"time": end_time - start_time
}
logging.info(f"Instance #{ins_dir}: {eval_results}")
all_eval_results.append(eval_results)
with open(args.output_file + "_all", "w") as f:
json.dump(all_eval_results, f)
if args.dataset.startswith("swe-bench"):
swebench = swebench.add_column("docs", all_top_docs)
swebench.to_json(args.results_file)
avg_eval_results = {}
for k, v_dict in all_eval_results[0].items():
if isinstance(v_dict, dict):
avg_v_dict = {}
for vk, vv in v_dict.items():
avg_vv = sum([e[k][vk] for e in all_eval_results]) / len(all_eval_results)
avg_v_dict[vk] = avg_vv
avg_eval_results.update(avg_v_dict)
elif isinstance(v_dict, float):
avg_v = sum([e[k] for e in all_eval_results]) / len(all_eval_results)
avg_eval_results[k] = avg_v
else:
raise ValueError
print("Average Eval Results: ", avg_eval_results)
with open(args.output_file, "w") as f:
json.dump(avg_eval_results, f)
else:
dataset = args.dataset
corpus, queries, qrels = GenericDataLoader(data_folder=os.path.join("datasets", args.dataset)).load(
split="test")
#### Retrieve dense results (format of results is identical to qrels)
start_time = time()
results = retriever.retrieve(corpus, queries)
end_time = time()
print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
if args.dataset in ["humaneval", "mbpp", "apps"]:
if args.dataset == "humaneval":
ds = load_dataset("openai_humaneval")
id_key = "task_id"
elif args.dataset == "mbpp":
ds = load_dataset("mbpp")
id_key = "task_id"
elif args.dataset == "apps":
ds = load_dataset("codeparrot/apps")
id_key = "problem_id"
all_top_docs = []
for task_id in ds["test"][id_key]:
all_top_docs.append(get_top_docs(results, corpus, f"{task_id}_doc"))
ds["test"] = ds["test"].add_column("docs", all_top_docs)
ds["test"].to_json(args.results_file) # this outputs to arrow format and read as .jsonl
elif args.dataset.startswith("odex"):
lang = args.dataset.split("_")[-1]
ds = load_dataset("neulab/odex", lang, trust_remote_code=True)
all_top_docs = []
for idx, task_id in enumerate(ds["test"]["task_id"]):
all_top_docs.append(get_top_docs(results, corpus, f"{idx}_{task_id}"))
ds["test"] = ds["test"].add_column("docs", all_top_docs)
ds["test"].to_json(args.results_file) # this outputs to arrow format and read as .jsonl
elif args.dataset.startswith("ds1000"):
_, key, mode = args.dataset.split("_")
key = key.capitalize()
mode = mode.capitalize()
from create.ds1000 import get_dataset
source_dir = pathlib.Path(__file__).parent / "ds"
data = get_dataset(source_dir, mode=mode, key=key)
all_docs = []
example_ids = []
for item in data:
example = item.data
example_id = f"{example['lib']}_{example['perturbation_origin_id']}"
all_docs.append(get_top_docs(results, corpus, example_id))
example_ids.append(example_id)
assert len(all_docs) == len(
example_ids), f"length of all_docs should be {len(example_ids)}, now is {len(all_docs)}"
with open(args.results_file, "w+") as fout:
for idx, all_doc in enumerate(all_docs):
fout.write(json.dumps({"example_id": example_ids[idx],
"docs": all_doc}) + "\n")
else:
with open(args.results_file, 'w+') as fw:
for curr in results:
fw.write(json.dumps({curr: results[curr]}) + "\n")
#### Evaluate your retrieval using NDCG@k, MAP@K ...
if len(qrels) == 0:
logging.info("No qrels found for this dataset.")
return
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
all_results = {"ndcg": ndcg, "mrr": mrr, "recall": recall, "precision": precision,
"time": end_time - start_time}
with open(args.output_file, "w") as f:
json.dump(all_results, f)
#### Print top-k documents retrieved ####
top_k = 3
query_id, ranking_scores = random.choice(list(results.items()))
scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
logging.info("Query : %s\n" % queries[query_id])
for rank in range(top_k):
doc_id = scores_sorted[rank][0]
# Format: Rank x: ID [Title] Body
logging.info(
"Rank %d: %s [%s] - %s\n" % (rank + 1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
if __name__ == "__main__":
parser = HfArgumentParser((
CodeRAGEvalArgs,
CodeRAGEvalModelArgs
))
eval_args, model_args = parser.parse_args_into_dataclasses()
main(eval_args, model_args)

View File

@ -0,0 +1,18 @@
from typing import Dict
def get_task_def_by_task_name(task_name: str) -> str:
task_name_to_instruct: Dict[str, str] = {
'humaneval': 'Given a question that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
'mbpp': 'Given a textual explanation of code functionality, retrieve the corresponding code implementation.',
'ds1000_all_completion': 'Given a question that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
'odex_en': 'Given a question, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
'odex_es': 'Given a question, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
'odex_ja': 'Given a question, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
'odex_ru': 'Given a question, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
'repoeval': 'Given a code snippet and a new function name, retrieve the implementation of the function.',
# 'repoeval': 'Given a piece of code segment, retrieve the code segment that is the latter part of the code.',
'swe-bench-lite': 'Given a code snippet containing a bug and a natural language description of the bug or error, retrieve code snippets that demonstrate solutions or fixes for similar bugs or errors (the desired documents).'
}
return task_name_to_instruct[task_name]

View File

@ -0,0 +1,69 @@
from typing import List
from dataclasses import dataclass, field
from FlagEmbedding.abc.evaluation import (
AbsEvalModelArgs as COIREvalModelArgs,
)
def coir_tasks():
return [
"apps",
"codefeedback-mt",
"codefeedback-st",
"CodeSearchNet-ccr-go",
"CodeSearchNet-ccr-java",
"CodeSearchNet-ccr-javascript",
"CodeSearchNet-ccr-php",
"CodeSearchNet-ccr-python",
"CodeSearchNet-ccr-ruby",
"CodeSearchNet-go",
"CodeSearchNet-java",
"CodeSearchNet-javascript",
"CodeSearchNet-php",
"CodeSearchNet-python",
"CodeSearchNet-ruby",
"codetrans-contest",
"codetrans-dl",
"cosqa",
"stackoverflow-qa",
"synthetic-text2sql"
]
@dataclass
class COIREvalArgs:
output_dir: str = field(
default="./results", metadata={"help": "Path to save results."}
)
tasks: List[str] = field(
default_factory=coir_tasks,
metadata={
"help": "Tasks to evaluate. Default: None. Available tasks: ['apps', 'codefeedback-mt', 'codefeedback-st', 'CodeSearchNet-ccr-go', 'CodeSearchNet-ccr-java', 'CodeSearchNet-ccr-javascript', 'CodeSearchNet-ccr-php', 'CodeSearchNet-ccr-python', 'CodeSearchNet-ccr-ruby', 'CodeSearchNet-go', 'CodeSearchNet-java', 'CodeSearchNet-javascript', 'CodeSearchNet-php', 'CodeSearchNet-python', 'CodeSearchNet-ruby', 'codetrans-contest', 'codetrans-dl', 'cosqa', 'stackoverflow-qa', 'synthetic-text2sql']",
"choices": [
"apps",
"codefeedback-mt",
"codefeedback-st",
"CodeSearchNet-ccr-go",
"CodeSearchNet-ccr-java",
"CodeSearchNet-ccr-javascript",
"CodeSearchNet-ccr-php",
"CodeSearchNet-ccr-python",
"CodeSearchNet-ccr-ruby",
"CodeSearchNet-go",
"CodeSearchNet-java",
"CodeSearchNet-javascript",
"CodeSearchNet-php",
"CodeSearchNet-python",
"CodeSearchNet-ruby",
"codetrans-contest",
"codetrans-dl",
"cosqa",
"stackoverflow-qa",
"synthetic-text2sql"
]
}
)
use_special_instructions: bool = field(
default=False, metadata={"help": "Whether to use specific instructions in `prompts.py` for evaluation. Default: False"}
)

View File

@ -0,0 +1,16 @@
output_dir=result
python main.py \
--output_dir ${output_dir} \
--use_special_instructions True \
--embedder_name_or_path BAAI/bge-code-v1 \
--embedder_model_class decoder-only-base \
--query_instruction_format_for_retrieval '<instruct>{}\n<query>{}' \
--embedder_query_max_length 2048 \
--embedder_passage_max_length 2048 \
--trust_remote_code True \
--pooling_method last_token \
--embedder_batch_size 64 \
--devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
--tasks apps codetrans-contest codetrans-dl cosqa synthetic-text2sql stackoverflow-qa codefeedback-mt codefeedback-st CodeSearchNet-ccr-go CodeSearchNet-ccr-java CodeSearchNet-ccr-javascript CodeSearchNet-ccr-php CodeSearchNet-ccr-python CodeSearchNet-ccr-ruby CodeSearchNet-go CodeSearchNet-java CodeSearchNet-javascript CodeSearchNet-php CodeSearchNet-python CodeSearchNet-ruby \
--cache_dir ./cache

View File

@ -0,0 +1,167 @@
import os
import json
import coir
from transformers import HfArgumentParser
from arguments import COIREvalArgs, COIREvalModelArgs
from prompts import get_task_def_by_task_name
from FlagEmbedding import FlagLLMModel, FlagModel
def get_model(model_args: COIREvalModelArgs):
embedder_name_or_path = model_args.embedder_name_or_path
if model_args.embedder_model_class == "encoder-only-base":
embedder = FlagModel(
model_name_or_path=embedder_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
pooling_method=model_args.pooling_method,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir,
batch_size=model_args.embedder_batch_size,
query_max_length=model_args.embedder_query_max_length,
passage_max_length=model_args.embedder_passage_max_length,
)
elif model_args.embedder_model_class == "decoder-only-base":
embedder = FlagLLMModel(
model_name_or_path=embedder_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
pooling_method=model_args.pooling_method,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir,
batch_size=model_args.embedder_batch_size,
query_max_length=model_args.embedder_query_max_length,
passage_max_length=model_args.embedder_passage_max_length,
)
else:
raise ValueError(f"Invalid model class: {model_args.embedder_model_class}")
embedder.model.config._name_or_path = model_args.embedder_name_or_path
class CustomFlagModel:
def __init__(self, model):
self.model = model
def encode_queries(self, queries, show_progress_bar, convert_to_tensor, **kwargs):
if isinstance(queries, str):
queries = [queries]
if isinstance(queries[0], dict):
queries = [(e.get('title') + ' ' + e['text']).strip() for e in queries]
return self.model.encode_queries(queries, **kwargs)
def encode_corpus(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
if isinstance(corpus, str):
corpus = [corpus]
if isinstance(corpus[0], dict):
corpus = [(e.get('title') + ' ' + e['text']).strip() for e in corpus]
return self.model.encode_corpus(corpus, **kwargs)
def encode(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
if isinstance(corpus, str):
corpus = [corpus]
if isinstance(corpus[0], dict):
corpus = [(e.get('title') + ' ' + e['text']).strip() for e in corpus]
return self.model.encode(corpus, **kwargs)
return CustomFlagModel(embedder)
def main(
eval_args: COIREvalArgs,
model_args: COIREvalModelArgs
):
model = get_model(model_args)
output_folder = os.path.join(eval_args.output_dir, os.path.basename(model.model.model.config._name_or_path))
all_task = eval_args.tasks
if not isinstance(all_task, list):
all_task = [all_task]
all_results = {}
for task_name in all_task:
save_path = os.path.join(output_folder, f"{task_name}.json")
if os.path.exists(save_path):
with open(save_path, "r", encoding="utf-8") as f:
results = json.load(f)
all_results[task_name] = results['metrics']
continue
tmp_task = coir.get_tasks(tasks=[task_name])
evaluation = coir.COIR(tasks=tmp_task,
batch_size=model_args.embedder_batch_size)
model.model.stop_self_pool()
if eval_args.use_special_instructions:
model.model.query_instruction_for_retrieval = get_task_def_by_task_name(task_name)
results = evaluation.run(model, output_folder=output_folder)
all_results[task_name] = results[task_name]
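# Aggregate NDCG@10: average the per-language CodeSearchNet and CodeSearchNet-ccr scores into one
# entry each before computing the overall mean across tasks.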
csn_result = 0
csn_num = 0
csn_ccr_result = 0
csn_ccr_num = 0
pop_keys = []
all_result = 0
all_num = 0
for k in all_results.keys():
if 'CodeSearchNet-ccr' in k:
csn_ccr_result += all_results[k]['NDCG']['NDCG@10']
csn_ccr_num += 1
pop_keys.append(k)
elif 'CodeSearchNet' in k:
csn_result += all_results[k]['NDCG']['NDCG@10']
csn_num += 1
pop_keys.append(k)
else:
all_result += all_results[k]['NDCG']['NDCG@10']
all_num += 1
if csn_num > 0:
print('Using CodeSearchNet')
all_result += csn_result / csn_num
all_num += 1
if csn_ccr_num > 0:
print('Using CodeSearchNet-ccr')
all_result += csn_ccr_result / csn_ccr_num
all_num += 1
new_results = {}
for k in all_results:
if k in pop_keys:
continue
new_results[k] = all_results[k]['NDCG']['NDCG@10']
if csn_num > 0:
new_results['CodeSearchNet'] = csn_result / csn_num
if csn_ccr_num > 0:
new_results['CodeSearchNet_CCR'] = csn_ccr_result / csn_ccr_num
new_results['all'] = all_result / all_num
print(new_results)
with open(os.path.join(output_folder, 'OVERALL-results.json'), 'w') as f:
json.dump(new_results, f)
if __name__ == "__main__":
parser = HfArgumentParser((
COIREvalArgs,
COIREvalModelArgs
))
eval_args, model_args = parser.parse_args_into_dataclasses()
main(eval_args, model_args)

View File

@ -0,0 +1,38 @@
from typing import Dict
def get_task_def_by_task_name(task_name: str) -> str:
task_name_to_instruct: Dict[str, str] = {
# Text-to-Code Retrieval
## Code Contest Retrieval
'apps': 'Given a code contest problem description, retrieve relevant code that can help solve the problem.',
## Web Query to Code Retrieval
'cosqa': 'Given a web search query, retrieve relevant code that can help answer the query.',
## Text-to-SQL Retrieval
'synthetic-text2sql': 'Given a question in text, retrieve SQL queries that are appropriate responses to the question.',
# Code-to-Text Retrieval
## Code Summary Retrieval
'CodeSearchNet-': 'Given a piece of code, retrieve the document string that summarizes the code.',
# Code-to-Code Retrieval
## Code Context Retrieval
'CodeSearchNet-ccr-': 'Given a piece of code segment, retrieve the code segment that is the latter part of the code.',
## Similar Code Retrieval
'codetrans-dl': 'Given a piece of code, retrieve code that is semantically equivalent to the input code.',
'codetrans-contest': 'Given a piece of Python code, retrieve C++ code that is semantically equivalent to the input code.',
# Hybrid Code Retrieval
## Single-turn Code QA
'stackoverflow-qa': 'Given a question that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
'codefeedback-st': 'Given a question that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
## Multi-turn Code QA
'codefeedback-mt': 'Given a multi-turn conversation history that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
}
special_task_names = ['CodeSearchNet-ccr-', 'CodeSearchNet-']
for special_task_name in special_task_names:
if special_task_name in task_name:
return task_name_to_instruct[special_task_name]
return task_name_to_instruct[task_name]