Mirror of https://github.com/FlagOpen/FlagEmbedding.git, synced 2025-06-27 02:39:58 +00:00

Merge pull request #1464 from 545999961/master: upload coder eval script
Commit 97ca07325d

@@ -11,14 +11,17 @@

This repo contains the data, training, and evaluation pipeline for CodeR / [BGE-Code-v1](https://huggingface.co/BAAI/bge-code-v1).

**[BGE-Code-v1](https://huggingface.co/BAAI/bge-code-v1)** is an LLM-based code embedding model that supports code retrieval, text retrieval, and multilingual retrieval. It primarily demonstrates the following capabilities:

- Superior Code Retrieval Performance: The model demonstrates exceptional code retrieval capabilities, supporting natural language queries in both English and Chinese, as well as 20 programming languages.
- Robust Text Retrieval Capabilities: The model maintains strong text retrieval capabilities comparable to text embedding models of similar scale.
- Extensive Multilingual Support: BGE-Code-v1 offers comprehensive multilingual retrieval capabilities, excelling in languages such as English, Chinese, Japanese, French, and more.

## :bell: News:

- 🥳 5/15/2025: We have released CodeR / BGE-Code-v1! :fire:

## Usage
@@ -29,9 +32,6 @@ This repo contains the data, training, and evaluation pipeline for CodeR / [BGE-

git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install -e .
```

```python
from FlagEmbedding import FlagLLMModel

queries = [
    "Delete the record with ID 4 from the 'Staff' table.",
@@ -149,29 +149,29 @@ print(scores.tolist())

- CoIR

| | CodeXEmbed-2B | CodeXEmbed-7B | Voyage-Code-002 | Voyage-Code-003 | BGE-Code-v1 |
| --------------------- | ------------- | ------------- | --------------- | --------------- | ----------- |
| **Apps** | 76.86 | 85.38 | 26.52 | 93.62 | 98.08 |
| **CosQA** | 40.47 | 42.47 | 29.79 | 34.45 | 46.72 |
| **Text2SQL** | 78.42 | 78.94 | 69.26 | 62.87 | 64.35 |
| **CSN** | 87.87 | 89.67 | 81.79 | 89.35 | 89.53 |
| **CSN-CCR** | 97.66 | 97.95 | 73.45 | 90.05 | 98.30 |
| **CodeTrans-Contest** | 90.30 | 94.45 | 72.77 | 94.96 | 94.38 |
| **CodeTrans-DL** | 38.57 | 40.46 | 27.48 | 38.57 | 46.13 |
| **StackOverFlow-QA** | 94.47 | 96.33 | 67.68 | 97.17 | 95.35 |
| **CodeFeedBack-ST** | 86.36 | 87.53 | 65.35 | 90.67 | 90.56 |
| **CodeFeedBack-MT** | 65.51 | 68.83 | 28.74 | 93.58 | 94.38 |
| **AVG** | **75.65** | **78.20** | **56.26** | **78.53** | **81.77** |

- CodeRAG

| | HumanEval | MBPP | DS-1000 | ODEX | RepoEval | SWE-bench-Lite | AVG |
| --------------- | --------- | ---- | ------- | ---- | -------- | -------------- | -------- |
| SFR | 100.0 | 99.0 | 19.3 | 37.1 | 83.8 | 62.7 | **67.0** |
| Jina-v2-code | 100.0 | 97.7 | 26.2 | 19.9 | 90.5 | 58.3 | **65.4** |
| CodeXEmbed-2B | 100.0 | 97.4 | 25.4 | 23.9 | 88.7 | 52.4 | **64.6** |
| Voyage-Code-002 | 100.0 | 99.0 | 33.1 | 26.6 | 94.3 | 29.1 | **63.7** |
| BGE-Code-v1 | 100.0 | 99.2 | 40.9 | 36.1 | 93.1 | 67.4 | **72.8** |

### Instructions for Evaluation
@@ -195,3 +195,39 @@ print(scores.tolist())

    "SWE-bench-Lite": "Given a code snippet containing a bug and a natural language description of the bug or error, retrieve code snippets that demonstrate solutions or fixes for similar bugs or errors (the desired documents)."
}
```
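
At evaluation time, the task-specific instruction is combined with each query before encoding. Below is a minimal illustrative sketch of that prompt assembly, assuming the `<instruct>{}\n<query>{}` template passed to the evaluation script via `--query_instruction_format_for_retrieval` (the helper function is hypothetical and not part of this repo):

```python
# Hypothetical helper: fill the retrieval prompt template used in eval.sh
# ('<instruct>{}\n<query>{}') with a task instruction and a raw query.
task2instruction = {
    "SWE-bench-Lite": (
        "Given a code snippet containing a bug and a natural language description of the bug "
        "or error, retrieve code snippets that demonstrate solutions or fixes for similar bugs "
        "or errors (the desired documents)."
    ),
}

def build_query_prompt(task: str, query: str) -> str:
    return "<instruct>{}\n<query>{}".format(task2instruction[task], query)

print(build_query_prompt("SWE-bench-Lite", "TypeError raised when calling Model.save() twice"))
```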

### Evaluation script

#### CoIR

For CoIR, we use the [CoIR](https://github.com/CoIR-team/coir) evaluation script:

```shell
cd ./evaluation/coir_eval
### clone coir
mkdir test
cd ./test
git clone https://github.com/CoIR-team/coir.git
mv ./coir/coir ../
cd ..
rm -rf ./test
### evaluate
bash eval.sh
```

#### CodeRAG

For CodeRAG, we use the [CodeRAG](https://github.com/code-rag-bench/code-rag-bench) evaluation script:

```shell
cd ./evaluation/coderag_eval
### clone coderag
git clone https://github.com/code-rag-bench/code-rag-bench.git
### you need to prepare the environment according to its README.md
rm -rf ./code-rag-bench/retrieval/create
cp -r ./test/* ./code-rag-bench/retrieval/
### prepare data
bash prepare_data.sh
### evaluate
bash eval.sh
```
research/BGE_Coder/evaluation/coderag_eval/eval.sh (new file, 22 lines)

@@ -0,0 +1,22 @@
cd ./code-rag-bench/retrieval/

output_dir='result'

for dataset_name in "humaneval" "mbpp" "repoeval" "ds1000_all_completion" "odex_en" "swe-bench-lite"
do
    echo "dataset_name: ${dataset_name}"
    python main.py \
        --embedder_name_or_path BAAI/bge-code-v1 \
        --embedder_model_class decoder-only-base \
        --query_instruction_format_for_retrieval '<instruct>{}\n<query>{}' \
        --embedder_query_max_length 2048 \
        --embedder_passage_max_length 2048 \
        --trust_remote_code True \
        --pooling_method last_token \
        --embedder_batch_size 64 \
        --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
        --cache_dir ./cache \
        --dataset $dataset_name \
        --output_file ../../${output_dir}/${dataset_name}_output.json \
        --results_file ../../${output_dir}/${dataset_name}_results.json
done
@@ -0,0 +1,7 @@

cd ./code-rag-bench/retrieval/

for dataset_name in "humaneval" "mbpp" "live_code_bench" "ds1000" "odex" "repoeval_repo" "swebench_repo"
do
    echo "dataset_name: ${dataset_name}"
    PYTHONPATH=./ python create/${dataset_name}.py
done
research/BGE_Coder/evaluation/coderag_eval/test/arguments.py (new file, 34 lines)

@@ -0,0 +1,34 @@
from typing import List
from dataclasses import dataclass, field

from FlagEmbedding.abc.evaluation import (
    AbsEvalModelArgs as CodeRAGEvalModelArgs,
)


@dataclass
class CodeRAGEvalArgs:
    dataset: str = field(
        default='humaneval',
        metadata={
            "help": "Task to evaluate. Default: humaneval. Available tasks: "
                    "['humaneval', 'mbpp', 'live_code_bench', 'ds1000', 'odex', 'repoeval_repo', 'swebench_repo', 'code_search_net']",
        }
    )
    max_length: int = field(
        default=2048, metadata={"help": "Max length to use for evaluation."}
    )
    batch_size: int = field(
        default=64, metadata={"help": "Batch size for evaluation."}
    )
    output_file: str = field(
        default="outputs.json",
        metadata={
            "help": "Specify the filepath if you want to save the retrieval (evaluation) results."
        },
    )
    results_file: str = field(
        default="results.json",
        metadata={
            "help": "Specify the filepath if you want to save the retrieval results."
        },
    )
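
The dataclass above follows the HuggingFace-style argument pattern, so it can be combined with the model arguments and parsed from the command line. A minimal sketch, assuming `transformers.HfArgumentParser` is the parsing mechanism and that the actual wiring lives in the benchmark's `main.py` (both are assumptions, not shown in this commit):

```python
# Illustrative sketch only: parse the evaluation and model arguments together.
# Assumes transformers.HfArgumentParser and that AbsEvalModelArgs exposes the
# --embedder_* flags seen in eval.sh.
from transformers import HfArgumentParser

from arguments import CodeRAGEvalArgs, CodeRAGEvalModelArgs

if __name__ == "__main__":
    parser = HfArgumentParser((CodeRAGEvalArgs, CodeRAGEvalModelArgs))
    eval_args, model_args = parser.parse_args_into_dataclasses()
    print(eval_args.dataset, eval_args.output_file, model_args.embedder_name_or_path)
```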
@@ -0,0 +1,52 @@

import argparse
import datasets
import os
from tqdm import tqdm
import random
from create.utils import save_tsv_dict, save_file_jsonl, load_jsonlines


def document2code(data, split="train"):
    data = data[split]
    code_search_net_data_queries = []
    code_search_net_data_docs = []
    code_search_net_data_qrels = []

    for item in tqdm(data):
        doc = item["func_documentation_string"]
        code = item["func_code_string"]
        doc_id = "{repository_name}_{func_path_in_repository}_{func_name}_doc".format_map(item)
        code_id = "{repository_name}_{func_path_in_repository}_{func_name}_code".format_map(item)
        code_search_net_data_queries.append({"_id": doc_id, "text": doc, "metadata": {}})
        code_search_net_data_docs.append({"_id": code_id, "title": item["func_name"], "text": code, "metadata": {}})
        code_search_net_data_qrels.append({"query-id": doc_id, "corpus-id": code_id, "score": 1})

    return code_search_net_data_queries, code_search_net_data_docs, code_search_net_data_qrels


def main():
    #### /print debug information to stdout

    parser = argparse.ArgumentParser()
    parser.add_argument("--language", type=str, default="python", help="codesearch net language")
    parser.add_argument("--output_dir", type=str, default="datasets")

    args = parser.parse_args()
    dataset = datasets.load_dataset("code_search_net", args.language)

    path = os.path.join(args.output_dir, "code_search_net_{}".format(args.language))
    os.makedirs(path)
    os.makedirs(os.path.join(path, "qrels"))

    docs = []
    queries = []
    for split in ["train", "validation", "test"]:
        queries_split, docs_split, qrels_split = document2code(dataset, split)
        docs += docs_split
        queries += queries_split

        save_tsv_dict(qrels_split, os.path.join(path, "qrels", "{}.tsv".format(split)), ["query-id", "corpus-id", "score"])

    save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
    save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))


if __name__ == "__main__":
    main()
research/BGE_Coder/evaluation/coderag_eval/test/create/ds1000.py (new file, 123 lines)

@@ -0,0 +1,123 @@
import io
import os
import fcntl
import pathlib
import zipfile
import argparse
import requests
import warnings
import itertools
from tqdm import tqdm
from datasets import load_dataset
from create.utils import save_tsv_dict, save_file_jsonl


# Load dataset
def download_source(source_dir):
    src = source_dir / "ds1000.py"
    url = "https://github.com/HKUNLP/DS-1000/blob/49c1c543ada8b58138181333cdc62e613204efcf/ds1000.py?raw=true"
    lock = src.with_suffix(".lock")
    with open(lock, "w") as f_lock:
        fcntl.flock(f_lock, fcntl.LOCK_EX)
        if not src.exists():
            warnings.warn(f"DS-1000 source is being saved to {src}.")
            print("Downloading source code...")
            r = requests.get(url, stream=True)
            with open(src, "wb") as f_src:
                f_src.write(r.content)
            open(src.parent / "__init__.py", "w").close()
            print("Done.")
        fcntl.flock(f_lock, fcntl.LOCK_UN)


def download_dataset(source_dir):
    path = source_dir / "ds1000_data"
    url = "https://github.com/HKUNLP/DS-1000/blob/49c1c543ada8b58138181333cdc62e613204efcf/ds1000_data.zip?raw=true"
    lock = path.with_suffix(".lock")
    with open(lock, "w") as f_lock:
        fcntl.flock(f_lock, fcntl.LOCK_EX)
        if not path.exists():
            warnings.warn(f"DS-1000 data is being saved to {path}.")
            print("Downloading dataset...")
            r = requests.get(url, stream=True)
            z = zipfile.ZipFile(io.BytesIO(r.content))
            z.extractall(source_dir)
            print("Done.")
        fcntl.flock(f_lock, fcntl.LOCK_UN)


def get_dataset(source_dir, mode: str = "Completion", key: str = "All"):
    """Returns dataset for the task or an iterable of any object, that get_prompt can handle"""
    from ds.ds1000 import DS1000Dataset

    data = DS1000Dataset(source_dir / "ds1000_data", mode=mode).data
    if key == "All":
        if mode == "Insertion":
            warnings.warn(
                "Insertion not supported for Matplotlib. Only running others."
            )
            data = {k: v for k, v in data.items() if k != "Matplotlib"}
        dataset = list(itertools.chain(*data.values()))
    else:
        dataset = data[key]
    return dataset


# Collect queries, docs, and relations
def document2code(data: list):
    queries, docs, qrels = [], [], []

    # collect doc corpus
    code_docs = load_dataset("neulab/docprompting-conala", "docs")["train"]
    for i in range(len(code_docs)):
        docs.append({
            "_id": str(i),
            "title": code_docs[i]["doc_id"],
            "text": code_docs[i]["doc_content"],
            "metadata": {}
        })

    # load canonical docs
    ds1000 = load_dataset("json", data_files={"test": args.canonical_file})["test"]
    for idx, item in enumerate(tqdm(data)):
        example = item.data
        query = example["prompt"]
        query_id = f"{example['lib']}_{example['perturbation_origin_id']}"
        queries.append({"_id": query_id, "text": query, "metadata": {}})

        doc_ids = [doc["title"] for doc in ds1000[idx]["docs"]]
        for doc_id in doc_ids:
            corpus_id = code_docs["doc_id"].index(doc_id)
            corpus_id = str(corpus_id)
            qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})

    return queries, docs, qrels


def main():
    args.source_dir = pathlib.Path(__file__).parent.parent / args.source_dir
    os.makedirs(args.source_dir, exist_ok=True)
    download_source(args.source_dir)
    download_dataset(args.source_dir)
    dataset = get_dataset(args.source_dir, mode=args.mode, key=args.key)

    path = os.path.join(args.output_dir, f"ds1000_{args.key.lower()}_{args.mode.lower()}")
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, "qrels"), exist_ok=True)

    queries, docs, qrels = document2code(dataset)
    save_tsv_dict(qrels, os.path.join(path, "qrels", "test.tsv"), ["query-id", "corpus-id", "score"])

    save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
    save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--source_dir", type=str, default="ds")
    parser.add_argument("--output_dir", type=str, default="datasets")
    parser.add_argument("--mode", type=str, default="Completion", choices=["Completion", "Insertion"])
    parser.add_argument("--key", type=str, default="All",
                        choices=["All", "Numpy", "Pandas", "Scipy", "Matplotlib", "Sklearn", "Tensorflow", "Pytorch"])
    parser.add_argument("--canonical_file", type=str, default="datasets/canonical/ds1000_docs.json")
    args = parser.parse_args()

    main()
@@ -0,0 +1,70 @@

"""Aggregate all code-generation datasets."""

import os
import json
import datasets
import argparse
from create.utils import save_tsv_dict
from create.humaneval import document2code as d2c_humaneval
from create.mbpp import document2code as d2c_mbpp

D2C_FUNC_DICT = {
    "humaneval": d2c_humaneval,
    "mbpp": d2c_mbpp,
}
SPLIT_DICT = {
    "humaneval": ["test"],
    "mbpp": ["train", "test", "validation", "prompt"],
}
HF_NAME_DICT = {
    "humaneval": "openai_humaneval",
    "mbpp": "mbpp",
}


def save_file_jsonl(data, path):
    with open(path, 'w') as fw:
        for item in data:
            fw.write(json.dumps(item) + '\n')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_names", type=str, nargs='+', default=["humaneval", "mbpp"])
    parser.add_argument("--output_dir", type=str, default="datasets")
    parser.add_argument("--output_name", type=str, default="general-programming")
    args = parser.parse_args()

    path = os.path.join(args.output_dir, args.output_name)
    os.makedirs(path)
    os.makedirs(os.path.join(path, "qrels"), exist_ok=True)

    split_dict = {}
    for dataset_name in args.dataset_names:
        for split in SPLIT_DICT[dataset_name]:
            if split not in split_dict:
                split_dict[split] = []
            split_dict[split].append(dataset_name)

    dataset_dict = {
        dataset_name: datasets.load_dataset(HF_NAME_DICT[dataset_name])
        for dataset_name in args.dataset_names
    }
    docs, queries = [], []
    for split, ds_names in split_dict.items():
        for ds in ds_names:
            dataset = dataset_dict[ds]

            queries_split, docs_split, qrels_split = D2C_FUNC_DICT[ds](dataset, split)
            docs += docs_split
            queries += queries_split

            qrels_path = os.path.join(path, "qrels", f"{split}.tsv")
            save_tsv_dict(qrels_split, qrels_path, ["query-id", "corpus-id", "score"])

    save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
    save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))


if __name__ == "__main__":
    main()
@@ -0,0 +1,58 @@

import os
import argparse
import datasets
from tqdm import tqdm
from create.utils import save_tsv_dict, save_file_jsonl


def document2code(data, split="test"):
    data = data[split]
    queries, docs, qrels = [], [], []

    for item in tqdm(data):
        doc = item["prompt"]
        code = item["prompt"] + '\n' + item["canonical_solution"]
        doc_id = "{task_id}_doc".format_map(item)
        code_id = "{task_id}_code".format_map(item)

        queries.append({"_id": doc_id, "text": doc, "metadata": {}})
        docs.append({"_id": code_id, "title": item["entry_point"], "text": code, "metadata": {}})
        qrels.append({"query-id": doc_id, "corpus-id": code_id, "score": 1})

    return queries, docs, qrels


def main():
    dataset = datasets.load_dataset(args.dataset_name)

    path = os.path.join(args.output_dir, args.output_name)
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, "qrels"), exist_ok=True)

    queries, docs, qrels = document2code(dataset, split="test")
    save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
    save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
    qrels_path = os.path.join(path, "qrels", "test.tsv")
    save_tsv_dict(qrels, qrels_path, ["query-id", "corpus-id", "score"])

    # create canonical file if not existent yet
    if not os.path.exists(args.canonical_file):
        canonical_solutions = []
        for doc in docs:
            canonical_solutions.append([{
                "text": doc["text"], "title": doc["title"]
            }])
        canonical_dataset = dataset["test"].add_column("docs", canonical_solutions)
        canonical_dataset.to_json(args.canonical_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="openai_humaneval")
    parser.add_argument("--output_name", type=str, default="humaneval")
    parser.add_argument("--canonical_file", type=str,
                        default="datasets/canonical/humaneval_solutions.json")
    parser.add_argument("--output_dir", type=str, default="datasets")
    args = parser.parse_args()

    main()
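
The creation scripts in `create/` all emit the same BEIR-style layout: `queries.jsonl`, `corpus.jsonl`, and `qrels/<split>.tsv`. A minimal sketch of what one entry of each file looks like for HumanEval (field values are illustrative only, not taken from the real dataset):

```python
# Illustrative entries only.
example_query = {"_id": "HumanEval/0_doc", "text": "def has_close_elements(numbers, threshold): ...", "metadata": {}}
example_doc = {"_id": "HumanEval/0_code", "title": "has_close_elements", "text": "def has_close_elements(numbers, threshold): ...", "metadata": {}}
# qrels/test.tsv is tab-separated with columns: query-id, corpus-id, score
example_qrel = {"query-id": "HumanEval/0_doc", "corpus-id": "HumanEval/0_code", "score": 1}
```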
@@ -0,0 +1,53 @@

import os
import argparse
import datasets
from tqdm import tqdm
from datasets import load_dataset
from create.utils import save_tsv_dict, save_file_jsonl


def get_queries(data, split="test") -> list[dict]:
    queries = [{
        "_id": item["question_id"] + '__' + item["contest_id"],
        "text": item["question_content"],
        "metadata": {}
    } for item in data[split]]
    return queries


def get_corpus(hf_name: str, cache_dir: str) -> list[dict]:
    dataset = load_dataset(hf_name, cache_dir=cache_dir)["train"]
    corpus = [
        {"_id": i, "text": item["text"], "title": item["title"]}
        for i, item in enumerate(dataset)
    ]
    return corpus


def main():
    dataset = datasets.load_dataset(args.dataset_name, cache_dir=args.cache_dir)

    path = os.path.join(args.output_dir, args.output_name)
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, "qrels"), exist_ok=True)

    queries = get_queries(dataset, split="test")
    save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))

    docs = get_corpus(args.corpus_name, args.cache_dir)
    save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))

    qrels = []  # no ground-truth solutions
    qrels_path = os.path.join(path, "qrels", "test.tsv")
    save_tsv_dict(qrels, qrels_path, ["query-id", "corpus-id", "score"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="livecodebench/code_generation")
    parser.add_argument("--corpus_name", type=str, default="code-rag-bench/programming-solutions")
    parser.add_argument("--cache_dir", type=str, default="/scratch/zhiruow/data")
    parser.add_argument("--output_name", type=str, default="livecodebench")
    parser.add_argument("--output_dir", type=str, default="datasets")
    args = parser.parse_args()

    main()
@@ -0,0 +1,78 @@

import os
import argparse
import datasets
from tqdm import tqdm
from create.utils import save_tsv_dict, save_file_jsonl


def get_function_name(code: str) -> str:
    """Parse the function name for a code snippet string."""
    lines = code.split('\n')
    for line in lines:
        if line.lstrip().startswith("def "):
            break
    func_name = line.lstrip()[4:]
    func_name = func_name.split('(')[0]
    return func_name


def document2code(data, split="test"):
    data = data[split]
    queries, docs, qrels = [], [], []

    for item in tqdm(data):
        doc = item["text"]
        code = "# " + item["text"] + '\n' + item["code"]
        doc_id = "{task_id}_doc".format_map(item)
        code_id = "{task_id}_code".format_map(item)

        queries.append({"_id": doc_id, "text": doc, "metadata": {}})
        docs.append({"_id": code_id, "title": get_function_name(item["code"]), "text": code, "metadata": {}})
        qrels.append({"query-id": doc_id, "corpus-id": code_id, "score": 1})

    return queries, docs, qrels


def main():
    dataset = datasets.load_dataset(args.dataset_name)

    path = os.path.join(args.output_dir, args.output_name)
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, "qrels"), exist_ok=True)

    docs, queries = [], []
    for split in args.splits:
        queries_split, docs_split, qrels_split = document2code(dataset, split)
        docs += docs_split
        queries += queries_split

        qrels_path = os.path.join(path, "qrels", f"{split}.tsv")
        save_tsv_dict(qrels_split, qrels_path, ["query-id", "corpus-id", "score"])

        # create canonical file for test split if not existent yet
        if split == "test" and (not os.path.exists(args.canonical_file)):
            canonical_solutions = []
            for doc in docs_split:
                canonical_solutions.append([{
                    "text": doc["text"], "title": doc["title"]
                }])
            canonical_dataset = dataset["test"].add_column("docs", canonical_solutions)
            canonical_dataset.to_json(args.canonical_file)

    save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
    save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="google-research-datasets/mbpp")
    # parser.add_argument("--dataset_name", type=str, default="code-rag-bench/mbpp")
    parser.add_argument("--splits", type=str, default=["train", "validation", "test"],
                        choices=["train", "validation", "test", "prompt"])
    parser.add_argument("--output_name", type=str, default="mbpp")
    parser.add_argument("--output_dir", type=str, default="datasets")
    parser.add_argument("--canonical_file", type=str,
                        default="datasets/canonical/mbpp_solutions.json")

    args = parser.parse_args()
    main()
@@ -0,0 +1,76 @@

import os
import re
import random
import argparse
import datasets
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset
from create.utils import save_tsv_dict, save_file_jsonl


def document2code(data, split="test"):
    data = data[split]
    queries, docs, qrels = [], [], []

    # build doc corpus
    code_docs = load_dataset("neulab/docprompting-conala", "docs")["train"]
    for i in range(len(code_docs)):
        docs.append({
            "_id": str(i),
            "title": code_docs[i]["doc_id"],
            "text": code_docs[i]["doc_content"],
            "metadata": {}
        })

    # load canonical docs
    odex = load_dataset("json", data_files={"test": args.canonical_file})["test"]
    # collect queries and query-doc matching
    for idx, item in enumerate(tqdm(data)):
        query = item["intent"]
        query_id = f"{idx}_{item['task_id']}"
        queries.append({"_id": query_id, "text": query, "metadata": {}})

        doc_ids = [doc["title"] for doc in odex[idx]["docs"]]
        for doc_id in doc_ids:
            corpus_id = code_docs["doc_id"].index(doc_id)
            corpus_id = str(corpus_id)
            qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})

    return queries, docs, qrels


def main():
    if '_' in args.dataset_name:
        dataset_name = args.dataset_name.split('_')[0]
        language = args.dataset_name.split('_')[1]
    else:
        dataset_name = args.dataset_name
        language = 'en'
    dataset = datasets.load_dataset(dataset_name, language)  # english version by default

    path = os.path.join(args.output_dir, args.output_name.replace('en', language))
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, "qrels"), exist_ok=True)

    docs, queries = [], []
    for split in ["test"]:
        queries_split, docs_split, qrels_split = document2code(dataset, split)
        docs += docs_split
        queries += queries_split

        save_tsv_dict(qrels_split, os.path.join(path, "qrels", "{}.tsv".format(split)), ["query-id", "corpus-id", "score"])

    save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
    save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="neulab/odex")
    parser.add_argument("--output_name", type=str, default="odex_en")
    parser.add_argument("--canonical_file", type=str, default="datasets/canonical/odex_docs.json")
    parser.add_argument("--output_dir", type=str, default="datasets")
    args = parser.parse_args()

    main()
@@ -0,0 +1,306 @@

import io
import os
import glob
import json
import argparse
import requests
import zipfile
from collections import defaultdict
from create.utils import save_tsv_dict, save_file_jsonl

REPOs_line_and_api = [
    'huggingface_diffusers',
    'nerfstudio-project_nerfstudio',
    'awslabs_fortuna',
    'huggingface_evaluate',
    'google_vizier',
    'alibaba_FederatedScope',
    'pytorch_rl',
    'opendilab_ACE',
]

REPOs_function = [
    "amazon-science_patchcore-inspection",
    "deepmind_tracr",
    "facebookresearch_omnivore",
    "google_lightweight_mmm",
    "lucidrains_imagen-pytorch",
    "maxhumber_redframes",
]

REPO_DIRs = {
    "api": "repositories/line_and_api_level",
    "line": "repositories/line_and_api_level",
    "function": "repositories/function_level",
}


def iterate_repository(base_dir: str, repo: str) -> dict:
    pattern = os.path.join(f'{base_dir}/{repo}', "**", "*.py")
    files = glob.glob(pattern, recursive=True)

    skipped_files = []
    loaded_code_files = dict()
    base_dir_list = os.path.normpath(base_dir).split(os.sep)
    for fname in files:
        try:
            code = open(fname, 'r', encoding='utf8').read()
            fpath_tuple = tuple(os.path.normpath(fname).split(os.sep)[len(base_dir_list):])
            loaded_code_files[fpath_tuple] = code
        except Exception as e:
            skipped_files.append((fname, e))
            continue

    if len(skipped_files) > 0:
        print(f"Skipped {len(skipped_files)} out of {len(files)} files due to I/O errors")
        for fname, e in skipped_files:
            print(f"{fname}: {e}")
    return loaded_code_files


def window_overlap(span: tuple, target_span: tuple) -> bool:
    if span[0] >= target_span[1] or span[1] <= target_span[0]:
        return False
    return True


class RepoWindowMaker:
    def __init__(self, base_dir, repo, tasks, window_size, slice_size):
        self.base_dir = base_dir
        self.repo = repo
        self.window_size = window_size
        self.slice_size = slice_size
        self.slice_step = 1 if window_size // slice_size == 0 else window_size // slice_size
        self.tasks = tasks
        self.source_code_files = iterate_repository(base_dir, repo)

    def _buid_windows_for_a_file(self, fpath_tuple, code):
        code_windows = []
        code_lines = code.splitlines()
        delta_size = self.window_size // 2
        for line_no in range(0, len(code_lines), self.slice_step):  # line_no starts from 0
            start_line_no = max(0, line_no - delta_size)
            end_line_no = min(len(code_lines), line_no + self.window_size - delta_size)
            window_lines = [i for i in code_lines[start_line_no:end_line_no]]
            if not window_lines:  # all empty lines
                continue
            window_text = '\n'.join(window_lines)
            code_windows.append({
                'context': window_text,
                'metadata': {
                    'fpath_tuple': fpath_tuple,
                    'line_no': line_no,
                    'start_line_no': start_line_no,
                    'end_line_no': end_line_no,
                    'window_size': self.window_size,
                    'repo': self.repo,
                    'slice_size': self.slice_size,
                }
            })
        return code_windows

    def _merge_windows_with_same_context(self, code_windows):
        merged_code_windows = defaultdict(list)
        for code_window in code_windows:
            context = code_window['context']
            metadata = code_window['metadata']
            merged_code_windows[context].append(metadata)
        json_lines = []
        for context, metadata_list in merged_code_windows.items():
            json_lines.append({
                'context': context,
                'metadata': metadata_list
            })
        return json_lines

    def build_windows(self):
        all_code_windows = []
        for fpath_tuple, code in self.source_code_files.items():
            all_code_windows += self._buid_windows_for_a_file(fpath_tuple, code)
        merged_code_windows = self._merge_windows_with_same_context(all_code_windows)
        print(f'build {len(merged_code_windows)} windows for {self.repo} with window size {self.window_size} and slice {self.slice_size}')
        ground_truth_indices = {}
        for task in self.tasks:
            fpath_tuple = tuple(task['metadata']['fpath_tuple'])
            line_no = task['metadata']['line_no']
            start_line_no = task['metadata']['context_start_lineno']
            for i, window in enumerate(merged_code_windows):
                if window["metadata"][0]["fpath_tuple"] != fpath_tuple:
                    continue
                if any([
                    window_overlap(
                        (sub_window["start_line_no"], sub_window["end_line_no"]),
                        (start_line_no, line_no + 1)
                    )
                    for sub_window in window["metadata"]
                ]):
                    if i not in ground_truth_indices:
                        ground_truth_indices[i] = []
                    ground_truth_indices[i].append(task["metadata"]["task_id"])

        return merged_code_windows, ground_truth_indices


def download_data(directory: str = "repoeval"):
    os.makedirs(directory, exist_ok=True)

    datasets_dir = os.path.join(directory, "datasets")
    repos_lineapi_dir = os.path.join(directory, "repositories", "line_and_api_level")
    repos_function_dir = os.path.join(directory, "repositories", "function_level")

    print(f"Start downloading the necessary `datasets` and `repositories` files.")
    if not os.path.exists(datasets_dir):
        print(f"Start downloading the `datasets`.")
        datasets_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/datasets/datasets.zip"
        r = requests.get(datasets_url, stream=True)
        z = zipfile.ZipFile(io.BytesIO(r.content))
        z.extractall(datasets_dir)
        print("Finished downloading the `datasets` files.")

    if not os.path.exists(repos_lineapi_dir):
        print(f"Start downloading the `repositories` (line_and_api).")
        repos_lineapi_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/repositories/line_and_api_level.zip"
        r = requests.get(repos_lineapi_url, stream=True)
        z = zipfile.ZipFile(io.BytesIO(r.content))
        z.extractall(repos_lineapi_dir)

    if not os.path.exists(repos_function_dir):
        print(f"Start downloading the `repositories` (function).")
        # repos_function_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/repositories/function_level.zip"
        repos_function_url = "https://github.com/Veronicium/repoeval_debug/raw/main/function_level.zip"
        r = requests.get(repos_function_url, stream=True)
        z = zipfile.ZipFile(io.BytesIO(r.content))
        z.extractall(repos_function_dir)
        print("Finished downloading the `repositories` files.")


def repo2code(
    repo: str, data_cache_dir: str,
    split: str, context_length: str,
    window_size: int, slice_size: int
):
    # load test examples
    file_name = f"{split}_level_completion_{context_length}_context_codex.test.jsonl"
    if split == 'function':
        file_name = file_name.replace('.test.jsonl', '.test.clean.jsonl')

    task_path = os.path.join(data_cache_dir, "datasets", file_name)
    tasks = [json.loads(l.rstrip()) for l in open(task_path, 'r')]
    tasks = [task for task in tasks if repo == task['metadata']['task_id'].split('/')[0]]

    # collect queries
    queries = []
    for task in tasks:
        query_id = task["metadata"]["task_id"]
        # text = '\n'.join(task["prompt"].split('\n')[-2:])
        text = task["prompt"]
        metadata = task["metadata"]
        queries.append({"_id": query_id, "text": text, "metadata": metadata})

    base_dir = os.path.join(data_cache_dir, REPO_DIRs[split])
    repo_window_maker = RepoWindowMaker(base_dir, repo, tasks, window_size, slice_size)
    windows, ground_truth_indices = repo_window_maker.build_windows()
    corpus, qrels = [], []
    query_id2gt = {task['metadata']['task_id']: [] for task in tasks}
    for i, window in enumerate(windows):
        path = '-'.join(window["metadata"][0]["fpath_tuple"])
        line = f"{window['metadata'][0]['start_line_no']}-{window['metadata'][-1]['end_line_no']}"
        corpus_id = f"{repo}_{path}_{line}"
        corpus.append({
            "_id": corpus_id, "title": path,
            "text": window["context"], "metadata": window["metadata"]
        })
        if i in ground_truth_indices:
            for query_id in ground_truth_indices[i]:
                qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})
                query_id2gt[query_id].append({"title": corpus_id.replace('_', '/'), "text": window["context"]})

    return queries, corpus, qrels, query_id2gt


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="datasets")
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument("--split", type=str, required=True, choices=["api", "line", "function"])
    parser.add_argument("--context_length", type=str, default="1k", choices=["1k", "2k", "4k"])
    parser.add_argument("--data_cache_dir", type=str, default="output/repoeval")
    parser.add_argument("--window_size", type=int, default=20)
    parser.add_argument("--slice_size", type=int, default=2)
    args = parser.parse_args()

    download_data(args.data_cache_dir)

    path = os.path.join(args.output_dir, "repoeval", args.split)
    os.makedirs(path, exist_ok=True)
    REPOs = REPOs_function if args.split == "function" else REPOs_line_and_api

    file_name = f"{args.split}_level_completion_{args.context_length}_context_codex.test.jsonl"
    data_path = os.path.join(args.data_cache_dir, "datasets", file_name)
    data = [json.loads(l.rstrip()) for l in open(data_path, 'r')]

    # preprocess function completion data (the data in the RepoCoder repo isn't correctly formatted)
    if args.split == 'function':
        repo2idx = {}
        for task in data:
            repo = task['metadata']['task_id'].replace('--', '_').split('/')[0]
            if repo not in repo2idx:
                repo2idx[repo] = 0
            task['metadata']['task_id'] = task['metadata']['task_id'].replace('--', '_').replace('idx', str(repo2idx[repo]))
            task['metadata']['line_no'] = task['metadata']['lineno']
            repo2idx[repo] += 1

        new_data_path = data_path.replace('.test.jsonl', '.test.clean.jsonl')
        with open(new_data_path, 'w') as f:
            for task in data:
                repo = task['metadata']['task_id'].split('/')[0]
                if repo not in REPOs:
                    continue
                f.write(json.dumps(task) + '\n')

        data = [json.loads(l.rstrip()) for l in open(new_data_path, 'r')]

    # build query, docs, and qrels for each repository
    queries, corpus, qrels = [], [], []
    query_id2gt = {}
    for repo in REPOs:
        repo_queries, repo_corpus, repo_qrels, repo_query_id2gt = repo2code(
            repo, args.data_cache_dir,
            args.split, args.context_length,
            args.window_size, args.slice_size
        )
        queries += repo_queries
        corpus += repo_corpus
        qrels += repo_qrels
        query_id2gt.update(repo_query_id2gt)

    save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
    save_file_jsonl(corpus, os.path.join(path, "corpus.jsonl"))
    save_tsv_dict(qrels, os.path.join(path, "qrels", "test.tsv"), ["query-id", "corpus-id", "score"])

    gt_data = []
    for example in data:
        query_id = example['metadata']['task_id']
        gt = query_id2gt[query_id]
        new_example = {
            "prompt": example["prompt"],
            "reference": example["metadata"]["ground_truth"],
            "docs": gt[:10],
            "metadata": {k: v for k, v in example["metadata"].items() if k != "ground_truth"},
        }
        gt_data.append(new_example)

    results_file = os.path.join(args.results_dir, f"repoeval-{args.split}-{args.context_length}-gt.jsonl")
    with open(results_file, "w") as fw:
        for ex in gt_data:
            fw.write(json.dumps(ex) + "\n")

    results_file = os.path.join(args.results_dir, f"repoeval-{args.split}-{args.context_length}-infile.jsonl")
    with open(results_file, "w") as fw:
        for ex in gt_data:
            ex = {k: v for k, v in ex.items() if k != "docs"}
            ex["docs"] = []
            fw.write(json.dumps(ex) + "\n")


if __name__ == "__main__":
    main()
@@ -0,0 +1,306 @@

import io
import os
import glob
import json
import argparse
import requests
import zipfile
from collections import defaultdict
from create.utils import save_tsv_dict, save_file_jsonl

REPOs_line_and_api = [
    'huggingface_diffusers',
    'nerfstudio-project_nerfstudio',
    'awslabs_fortuna',
    'huggingface_evaluate',
    'google_vizier',
    'alibaba_FederatedScope',
    'pytorch_rl',
    'opendilab_ACE',
]

REPOs_function = [
    "amazon-science_patchcore-inspection",
    "deepmind_tracr",
    "facebookresearch_omnivore",
    "google_lightweight_mmm",
    "lucidrains_imagen-pytorch",
    "maxhumber_redframes",
]

REPO_DIRs = {
    "api": "repositories/line_and_api_level",
    "line": "repositories/line_and_api_level",
    "function": "repositories/function_level",
}


def iterate_repository(base_dir: str, repo: str) -> dict:
    pattern = os.path.join(f'{base_dir}/{repo}', "**", "*.py")
    files = glob.glob(pattern, recursive=True)

    skipped_files = []
    loaded_code_files = dict()
    base_dir_list = os.path.normpath(base_dir).split(os.sep)
    for fname in files:
        try:
            code = open(fname, 'r', encoding='utf8').read()
            fpath_tuple = tuple(os.path.normpath(fname).split(os.sep)[len(base_dir_list):])
            loaded_code_files[fpath_tuple] = code
        except Exception as e:
            skipped_files.append((fname, e))
            continue

    if len(skipped_files) > 0:
        print(f"Skipped {len(skipped_files)} out of {len(files)} files due to I/O errors")
        for fname, e in skipped_files:
            print(f"{fname}: {e}")
    return loaded_code_files


def window_overlap(span: tuple, target_span: tuple) -> bool:
    if span[0] >= target_span[1] or span[1] <= target_span[0]:
        return False
    return True


class RepoWindowMaker:
    def __init__(self, base_dir, repo, tasks, window_size, slice_size):
        self.base_dir = base_dir
        self.repo = repo
        self.window_size = window_size
        self.slice_size = slice_size
        self.slice_step = 1 if window_size // slice_size == 0 else window_size // slice_size
        self.tasks = tasks
        self.source_code_files = iterate_repository(base_dir, repo)

    def _buid_windows_for_a_file(self, fpath_tuple, code):
        code_windows = []
        code_lines = code.splitlines()
        delta_size = self.window_size // 2
        for line_no in range(0, len(code_lines), self.slice_step):  # line_no starts from 0
            start_line_no = max(0, line_no - delta_size)
            end_line_no = min(len(code_lines), line_no + self.window_size - delta_size)
            window_lines = [i for i in code_lines[start_line_no:end_line_no]]
            if not window_lines:  # all empty lines
                continue
            window_text = '\n'.join(window_lines)
            code_windows.append({
                'context': window_text,
                'metadata': {
                    'fpath_tuple': fpath_tuple,
                    'line_no': line_no,
                    'start_line_no': start_line_no,
                    'end_line_no': end_line_no,
                    'window_size': self.window_size,
                    'repo': self.repo,
                    'slice_size': self.slice_size,
                }
            })
        return code_windows

    def _merge_windows_with_same_context(self, code_windows):
        merged_code_windows = defaultdict(list)
        for code_window in code_windows:
            context = code_window['context']
            metadata = code_window['metadata']
            merged_code_windows[context].append(metadata)
        json_lines = []
        for context, metadata_list in merged_code_windows.items():
            json_lines.append({
                'context': context,
                'metadata': metadata_list
            })
        return json_lines

    def build_windows(self):
        all_code_windows = []
        for fpath_tuple, code in self.source_code_files.items():
            all_code_windows += self._buid_windows_for_a_file(fpath_tuple, code)
        merged_code_windows = self._merge_windows_with_same_context(all_code_windows)
        print(f'build {len(merged_code_windows)} windows for {self.repo} with window size {self.window_size} and slice {self.slice_size}')
        ground_truth_indices = {}
        for task in self.tasks:
            fpath_tuple = tuple(task['metadata']['fpath_tuple'])
            line_no = task['metadata']['line_no']
            start_line_no = task['metadata']['context_start_lineno']
            for i, window in enumerate(merged_code_windows):
                if window["metadata"][0]["fpath_tuple"] != fpath_tuple and ' '.join(list(window["metadata"][0]["fpath_tuple"])) != ' '.join(list(fpath_tuple)):
                    continue
                if any([
                    window_overlap(
                        (sub_window["start_line_no"], sub_window["end_line_no"]),
                        (start_line_no, line_no + 1)
                    )
                    for sub_window in window["metadata"]
                ]):
                    if i not in ground_truth_indices:
                        ground_truth_indices[i] = []
                    ground_truth_indices[i].append(task["metadata"]["task_id"])
        return merged_code_windows, ground_truth_indices


def download_data(directory: str = "repoeval"):
    os.makedirs(directory, exist_ok=True)

    datasets_dir = os.path.join(directory, "datasets")
    repos_function_dir = os.path.join(directory, "repositories", "function_level")

    print(f"Start downloading the necessary `datasets` and `repositories` files.")
    if not os.path.exists(datasets_dir):
        print(f"Start downloading the `datasets`.")
        datasets_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/datasets/datasets.zip"
        r = requests.get(datasets_url, stream=True)
        z = zipfile.ZipFile(io.BytesIO(r.content))
        z.extractall(datasets_dir)
        print("Finished downloading the `datasets` files.")

    import shutil
    # remove any stale copy so that the function-level repositories are re-downloaded
    shutil.rmtree(repos_function_dir, ignore_errors=True)
    if not os.path.exists(repos_function_dir):
        print(f"Start downloading the `repositories` (function).")
        repos_function_url = "https://github.com/microsoft/CodeT/raw/main/RepoCoder/repositories/function_level.zip"
        # repos_function_url = "https://github.com/Veronicium/repoeval_debug/raw/main/function_level.zip"
        r = requests.get(repos_function_url, stream=True)
        z = zipfile.ZipFile(io.BytesIO(r.content))
        z.extractall(repos_function_dir)
        print("Finished downloading the `repositories` files.")


def repo2code(
    repo: str, tasks: list[dict], data_cache_dir: str,
    split: str, context_length: str,
    window_size: int, slice_size: int
):
    # collect queries
    queries = []
    for task in tasks:
        query_id = task["metadata"]["task_id"]
        # text = '\n'.join(task["prompt"].split('\n')[-2:])
        text = task["prompt"]
        metadata = task["metadata"]
        queries.append({"_id": query_id, "text": text, "metadata": metadata})

    base_dir = os.path.join(data_cache_dir, REPO_DIRs[split])
    repo_window_maker = RepoWindowMaker(base_dir, repo, tasks, window_size, slice_size)
    windows, ground_truth_indices = repo_window_maker.build_windows()
    corpus, qrels = [], []
    query_id2gt = {task['metadata']['task_id']: [] for task in tasks}
    for i, window in enumerate(windows):
        path = '-'.join(window["metadata"][0]["fpath_tuple"])
        line = f"{window['metadata'][0]['start_line_no']}-{window['metadata'][-1]['end_line_no']}"
        corpus_id = f"{repo}_{path}_{line}"
        corpus.append({
            "_id": corpus_id, "title": path,
            "text": window["context"], "metadata": window["metadata"]
        })
        if i in ground_truth_indices:
            for query_id in ground_truth_indices[i]:
                qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})
                query_id2gt[query_id].append({"title": corpus_id.replace('_', '/'), "text": window["context"]})

    return queries, corpus, qrels, query_id2gt


def main():
    download_data(args.data_cache_dir)

    REPOs = REPOs_function if args.split == "function" else REPOs_line_and_api

    file_name = f"{args.split}_level_completion_{args.context_length}_context_codex.test.jsonl"
    data_path = os.path.join(args.data_cache_dir, "datasets", file_name)
    data = [json.loads(l.rstrip()) for l in open(data_path, 'r')]

    # preprocess function completion data (the data in the RepoCoder repo isn't correctly formatted)
    if args.split == 'function':
        repo2idx = {}
        for task in data:
            repo = task['metadata']['task_id'].replace('--', '_').split('/')[0]
            if repo not in repo2idx:
                repo2idx[repo] = 0
            task['metadata']['task_id'] = task['metadata']['task_id'].replace('--', '_').replace('idx', str(repo2idx[repo]))
            task['metadata']['line_no'] = task['metadata']['lineno']
            repo2idx[repo] += 1

        new_data_path = data_path.replace('.test.jsonl', '.test.clean.jsonl')
        with open(new_data_path, 'w') as f:
            for task in data:
                repo = task['metadata']['task_id'].split('/')[0]
                if repo not in REPOs:
                    continue
                f.write(json.dumps(task) + '\n')

        data = [json.loads(l.rstrip()) for l in open(new_data_path, 'r')]

    # group data instances by repository
    data_dict = {}
    for ex in data:
        repo_name = ex["metadata"]["task_id"]
        repo_name = repo_name.split('/')[0]
        if repo_name not in data_dict:
            data_dict[repo_name] = []
        data_dict[repo_name].append(ex)

    # build query, docs, and qrels for each repository
    for repo in REPOs:
        queries, corpus, qrels, query_id2gt = repo2code(
            repo, data_dict[repo], args.data_cache_dir,
            args.split, args.context_length,
            args.window_size, args.slice_size
        )
        print(len(queries))
        if len(qrels) == 0:
            print(repo)
            continue
        path = os.path.join(args.output_dir, f"repoeval__{repo}")
        os.makedirs(path, exist_ok=True)
        save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
        save_file_jsonl(corpus, os.path.join(path, "corpus.jsonl"))
        save_tsv_dict(qrels, os.path.join(path, "qrels", "test.tsv"), ["query-id", "corpus-id", "score"])

        gt_data = []
        for example in data_dict[repo]:
            query_id = example['metadata']['task_id']
            gt = query_id2gt[query_id]
            new_example = {
                "prompt": example["prompt"],
                "reference": example["metadata"]["ground_truth"],
                "docs": gt[:10],
                "metadata": {k: v for k, v in example["metadata"].items() if k != "ground_truth"},
            }
            gt_data.append(new_example)

        os.makedirs(args.results_dir, exist_ok=True)

        results_file = os.path.join(args.results_dir, f"repoeval-{args.split}-{repo}-{args.context_length}-gt.jsonl")
        with open(results_file, "w") as fw:
            for ex in gt_data:
                fw.write(json.dumps(ex) + "\n")

        results_file = os.path.join(args.results_dir, f"repoeval-{args.split}-{repo}-{args.context_length}-infile.jsonl")
        with open(results_file, "w") as fw:
            for ex in gt_data:
                ex = {k: v for k, v in ex.items() if k != "docs"}
                ex["docs"] = []
                fw.write(json.dumps(ex) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="datasets")
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument("--split", type=str, default="function", choices=["function"])
    parser.add_argument("--context_length", type=str, default="2k", choices=["1k", "2k", "4k"])
    parser.add_argument("--data_cache_dir", type=str, default="output/repoeval")
    parser.add_argument("--window_size", type=int, default=50)
    parser.add_argument("--slice_size", type=int, default=5)
    args = parser.parse_args()

    main()
@ -0,0 +1,247 @@
|
||||
import os
|
||||
import re
|
||||
import chardet
|
||||
import unidiff
|
||||
import argparse
|
||||
import datasets
|
||||
import traceback
|
||||
import subprocess
|
||||
from git import Repo
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from create.utils import save_tsv_dict, save_file_jsonl
|
||||
|
||||
# %% Get oracle file contents
|
||||
|
||||
# get oracle file contents from the repo
|
||||
class ContextManager:
|
||||
def __init__(self, repo_path, base_commit, verbose=False):
|
||||
self.repo_path = Path(repo_path).resolve().as_posix()
|
||||
self.old_dir = os.getcwd()
|
||||
self.base_commit = base_commit
|
||||
self.verbose = verbose
|
||||
|
||||
def __enter__(self):
|
||||
os.chdir(self.repo_path)
|
||||
cmd = f"git reset --hard {self.base_commit} && git clean -fdxq"
|
||||
if self.verbose:
|
||||
subprocess.run(cmd, shell=True, check=True)
|
||||
else:
|
||||
subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
return self
|
||||
|
||||
def get_environment(self):
|
||||
raise NotImplementedError() # TODO: activate conda environment and return the environment file
|
||||
|
||||
def get_readme_files(self):
|
||||
files = os.listdir(self.repo_path)
|
||||
files = list(filter(lambda x: os.path.isfile(x), files))
|
||||
files = list(filter(lambda x: x.lower().startswith("readme"), files))
|
||||
return files
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
os.chdir(self.old_dir)
|
||||
|
||||
|
||||
class AutoContextManager(ContextManager):
|
||||
"""Automatically clones the repo if it doesn't exist"""
|
||||
|
||||
def __init__(self, instance, root_dir=None, verbose=False, token=None):
|
||||
if token is None:
|
||||
token = os.environ.get("GITHUB_TOKEN", "git")
|
||||
self.tempdir = None
|
||||
if root_dir is None:
|
||||
self.tempdir = TemporaryDirectory()
|
||||
root_dir = self.tempdir.name
|
||||
self.root_dir = root_dir
|
||||
repo_dir = os.path.join(self.root_dir, instance["repo"].replace("/", "__"))
|
||||
if not os.path.exists(repo_dir):
|
||||
repo_url = (
|
||||
f"https://{token}@github.com/swe-bench/"
|
||||
+ instance["repo"].replace("/", "__")
|
||||
+ ".git"
|
||||
)
|
||||
if verbose:
|
||||
print(f"Cloning {instance['repo']} to {root_dir}")
|
||||
Repo.clone_from(repo_url, repo_dir)
|
||||
super().__init__(repo_dir, instance["base_commit"], verbose=verbose)
|
||||
self.instance = instance
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.tempdir is not None:
|
||||
self.tempdir.cleanup()
|
||||
return super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
||||
def ingest_files(filenames):
|
||||
files_dict = dict()
|
||||
for filename in filenames:
|
||||
with open(filename) as f:
|
||||
content = f.read()
|
||||
files_dict[filename] = content
|
||||
return files_dict
|
||||
|
||||
def get_oracle_filenames(instance):
|
||||
"""
|
||||
Returns the filenames that are changed in the patch
|
||||
"""
|
||||
source_files = {
|
||||
patch_file.source_file.split("a/", 1)[-1]
|
||||
for patch_file in unidiff.PatchSet(instance["patch"])
|
||||
}
|
||||
gold_docs = set()
|
||||
for source_file in source_files:
|
||||
gold_docs.add(source_file)
|
||||
return gold_docs
|
||||
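To illustrate, a toy instance with a made-up single-file patch yields the set of changed source paths:

```python
# Toy instance; the patch content is invented for illustration.
toy_instance = {
    "patch": (
        "--- a/src/module.py\n"
        "+++ b/src/module.py\n"
        "@@ -1,1 +1,2 @@\n"
        " x = 1\n"
        "+y = 2\n"
    )
}
print(get_oracle_filenames(toy_instance))  # expected: {'src/module.py'}
```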
|
||||
|
||||
# get all file contents from the repo
|
||||
def is_test(name, test_phrases=None):
|
||||
if test_phrases is None:
|
||||
test_phrases = ["test", "tests", "testing"]
|
||||
words = set(re.split(r" |_|\/|\.", name.lower()))
|
||||
return any(word in words for word in test_phrases)
|
||||
|
||||
def list_files(root_dir, include_tests=False):
|
||||
files = []
|
||||
for filename in Path(root_dir).rglob("*.py"):
|
||||
if not include_tests and is_test(filename.as_posix()):
|
||||
continue
|
||||
files.append(filename.relative_to(root_dir).as_posix())
|
||||
return files
|
||||
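For example, the heuristic flags any path whose tokens include a test-like word:

```python
# Illustrative checks of the filename heuristic above.
assert is_test("tests/test_utils.py") is True
assert is_test("src/app.py") is False
```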
|
||||
def detect_encoding(filename):
|
||||
"""
|
||||
Detect the encoding of a file
|
||||
"""
|
||||
with open(filename, "rb") as file:
|
||||
rawdata = file.read()
|
||||
return chardet.detect(rawdata)["encoding"]
|
||||
|
||||
def ingest_directory_contents(root_dir, include_tests=False):
|
||||
files_content = {}
|
||||
for relative_path in list_files(root_dir, include_tests=include_tests):
|
||||
filename = os.path.join(root_dir, relative_path)
|
||||
encoding = detect_encoding(filename)
|
||||
if encoding is None:
|
||||
content = "[BINARY DATA FILE]"
|
||||
else:
|
||||
try:
|
||||
with open(filename, encoding=encoding) as file:
|
||||
content = file.read()
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
content = "[BINARY DATA FILE]"
|
||||
files_content[relative_path] = content
|
||||
return files_content
|
||||
|
||||
def get_file_contents(input_instances, verbose: bool = False, tmp_dir: str = "/scratch"):
|
||||
orig_dir = os.getcwd()
|
||||
with TemporaryDirectory(dir=tmp_dir if os.path.exists(tmp_dir) else "/tmp") as root_dir:
|
||||
for instance_id, instance in tqdm(
|
||||
input_instances.items(),
|
||||
total=len(input_instances),
|
||||
desc="Getting file contents",
|
||||
):
|
||||
try:
|
||||
with AutoContextManager(instance, root_dir, verbose=verbose) as cm:
|
||||
readmes = cm.get_readme_files()
|
||||
instance["readmes"] = ingest_files(readmes)
|
||||
instance["oracle_file_contents"] = ingest_files(get_oracle_filenames(instance))
|
||||
instance["file_contents"] = ingest_directory_contents(cm.repo_path)
|
||||
assert all([
|
||||
okey in instance["file_contents"]
|
||||
for okey in instance["oracle_file_contents"].keys()
|
||||
])
|
||||
except Exception as e:
|
||||
print(f"Failed on instance {instance_id}", e)
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# if AutoContextManager fails to exit properly future exits will return the wrong directory
|
||||
os.chdir(orig_dir)
|
||||
os.chdir(orig_dir)
|
||||
|
||||
|
||||
# %% Get queries, docs, and qrels
|
||||
|
||||
def document2code(data, split: str = "test"):
|
||||
subset = data[split]
|
||||
if args.num_examples is not None:
|
||||
import random
|
||||
indices = random.sample([i for i in range(len(subset))], args.num_examples)
|
||||
subset = subset.select(indices)
|
||||
print(subset)
|
||||
|
||||
# get queries for each example
|
||||
queries = [
|
||||
{
|
||||
"_id": item["instance_id"],
|
||||
"text": item["problem_statement"],
|
||||
"metadata": {}
|
||||
}
|
||||
for item in subset
|
||||
]
|
||||
|
||||
subset_dict = {x["instance_id"]: x for x in subset}
|
||||
get_file_contents(subset_dict, tmp_dir=args.tmp_dir)
|
||||
|
||||
# collect all docs, i.e., code chunks from the repo
|
||||
docs = []
|
||||
for instance_id, instance in subset_dict.items():
|
||||
print(f"Instance #{instance_id}: {len(instance['oracle_file_contents'])} oracle / {len(instance['file_contents'])} files")
|
||||
for filename, content in instance["file_contents"].items():
|
||||
docs.append({
|
||||
"_id": f"{instance_id}_{filename}",
|
||||
"title": filename,
|
||||
"text": content,
|
||||
"metadata": {},
|
||||
})
|
||||
|
||||
# find ground-truth docs for each example
|
||||
qrels = []
|
||||
for instance_id, instance in subset_dict.items():
|
||||
for filename, content in instance["oracle_file_contents"].items():
|
||||
qrels.append({
|
||||
"query-id": instance_id,
|
||||
"corpus-id": f"{instance_id}_{filename}",
|
||||
"score": 1
|
||||
})
|
||||
|
||||
return queries, docs, qrels
|
||||
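For reference, a sketch of the record shapes this function emits; all field values below are made up, not taken from SWE-bench:

```python
# Illustrative shapes of the three outputs.
example_query = {"_id": "owner__repo-123", "text": "Fix the crash in ...", "metadata": {}}
example_doc = {"_id": "owner__repo-123_src/module.py", "title": "src/module.py",
               "text": "<full file contents>", "metadata": {}}
example_qrel = {"query-id": "owner__repo-123", "corpus-id": "owner__repo-123_src/module.py", "score": 1}
```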
|
||||
|
||||
def main():
|
||||
dataset = datasets.load_dataset(args.dataset_name, cache_dir=args.cache_dir)
|
||||
|
||||
name = "swe-bench"
|
||||
if "lite" in args.dataset_name.lower():
|
||||
name += "-lite"
|
||||
|
||||
path = os.path.join(args.output_dir, name)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
|
||||
|
||||
queries, docs, qrels = document2code(dataset, split="test")
|
||||
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
|
||||
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
|
||||
qrels_path = os.path.join(path, "qrels", "test.tsv")
|
||||
save_tsv_dict(qrels, qrels_path, ["query-id", "corpus-id", "score"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset_name", type=str, default="princeton-nlp/SWE-bench_Lite",
|
||||
choices=["princeton-nlp/SWE-bench", "princeton-nlp/SWE-bench_Lite"])
|
||||
parser.add_argument("--cache_dir", type=str, default="/scratch/zhiruow/data")
|
||||
parser.add_argument("--tmp_dir", type=str, default="/scratch/zhiruow/tmp")
|
||||
parser.add_argument("--output_dir", type=str, default="datasets")
|
||||
parser.add_argument("--num_examples", type=int, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
main()
|
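A minimal sketch, assuming the default --output_dir of "datasets" and the Lite dataset, of reading the generated BEIR-style files back:

```python
import csv
import jsonlines

# Paths assume the defaults above; adjust them if --output_dir was changed.
with jsonlines.open("datasets/swe-bench-lite/queries.jsonl") as reader:
    queries = list(reader)
with open("datasets/swe-bench-lite/qrels/test.tsv") as f:
    qrels = list(csv.DictReader(f, delimiter="\t"))
print(len(queries), len(qrels))
```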
@ -0,0 +1,263 @@
|
||||
import os
|
||||
import re
|
||||
import chardet
|
||||
import unidiff
|
||||
import argparse
|
||||
import datasets
|
||||
import traceback
|
||||
import subprocess
|
||||
from git import Repo
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from create.utils import save_tsv_dict, save_file_jsonl
|
||||
|
||||
# %% Get oracle file contents
|
||||
|
||||
# get oracle file contents from the repo
|
||||
class ContextManager:
|
||||
def __init__(self, repo_path, base_commit, verbose=False):
|
||||
self.repo_path = Path(repo_path).resolve().as_posix()
|
||||
self.old_dir = os.getcwd()
|
||||
self.base_commit = base_commit
|
||||
self.verbose = verbose
|
||||
|
||||
def __enter__(self):
|
||||
os.chdir(self.repo_path)
|
||||
cmd = f"git reset --hard {self.base_commit} && git clean -fdxq"
|
||||
if self.verbose:
|
||||
subprocess.run(cmd, shell=True, check=True)
|
||||
else:
|
||||
subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
return self
|
||||
|
||||
def get_environment(self):
|
||||
raise NotImplementedError() # TODO: activate conda environment and return the environment file
|
||||
|
||||
def get_readme_files(self):
|
||||
files = os.listdir(self.repo_path)
|
||||
files = list(filter(lambda x: os.path.isfile(x), files))
|
||||
files = list(filter(lambda x: x.lower().startswith("readme"), files))
|
||||
return files
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
os.chdir(self.old_dir)
|
||||
|
||||
|
||||
class AutoContextManager(ContextManager):
|
||||
"""Automatically clones the repo if it doesn't exist"""
|
||||
|
||||
def __init__(self, instance, root_dir=None, verbose=False, token=None):
|
||||
if token is None:
|
||||
token = os.environ.get("GITHUB_TOKEN", "git")
|
||||
self.tempdir = None
|
||||
if root_dir is None:
|
||||
self.tempdir = TemporaryDirectory()
|
||||
root_dir = self.tempdir.name
|
||||
self.root_dir = root_dir
|
||||
repo_dir = os.path.join(self.root_dir, instance["repo"].replace("/", "__"))
|
||||
if not os.path.exists(repo_dir):
|
||||
repo_url = (
|
||||
f"https://{token}@github.com/swe-bench/"
|
||||
+ instance["repo"].replace("/", "__")
|
||||
+ ".git"
|
||||
)
|
||||
if verbose:
|
||||
print(f"Cloning {instance['repo']} to {root_dir}")
|
||||
Repo.clone_from(repo_url, repo_dir)
|
||||
super().__init__(repo_dir, instance["base_commit"], verbose=verbose)
|
||||
self.instance = instance
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.tempdir is not None:
|
||||
self.tempdir.cleanup()
|
||||
return super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
||||
def ingest_files(filenames):
|
||||
files_dict = dict()
|
||||
for filename in filenames:
|
||||
with open(filename) as f:
|
||||
content = f.read()
|
||||
files_dict[filename] = content
|
||||
return files_dict
|
||||
|
||||
def get_oracle_filenames(instance):
|
||||
"""
|
||||
Returns the filenames that are changed in the patch
|
||||
"""
|
||||
source_files = {
|
||||
patch_file.source_file.split("a/", 1)[-1]
|
||||
for patch_file in unidiff.PatchSet(instance["patch"])
|
||||
}
|
||||
gold_docs = set()
|
||||
for source_file in source_files:
|
||||
gold_docs.add(source_file)
|
||||
return gold_docs
|
||||
|
||||
|
||||
# get all file contents from the repo
|
||||
def is_test(name, test_phrases=None):
|
||||
if test_phrases is None:
|
||||
test_phrases = ["test", "tests", "testing"]
|
||||
words = set(re.split(r" |_|\/|\.", name.lower()))
|
||||
return any(word in words for word in test_phrases)
|
||||
|
||||
def list_files(root_dir, include_tests=False):
|
||||
files = []
|
||||
for filename in Path(root_dir).rglob("*.py"):
|
||||
if not include_tests and is_test(filename.as_posix()):
|
||||
continue
|
||||
files.append(filename.relative_to(root_dir).as_posix())
|
||||
return files
|
||||
|
||||
def detect_encoding(filename):
|
||||
"""
|
||||
Detect the encoding of a file
|
||||
"""
|
||||
with open(filename, "rb") as file:
|
||||
rawdata = file.read()
|
||||
return chardet.detect(rawdata)["encoding"]
|
||||
|
||||
def ingest_directory_contents(root_dir, include_tests=False):
|
||||
files_content = {}
|
||||
for relative_path in list_files(root_dir, include_tests=include_tests):
|
||||
filename = os.path.join(root_dir, relative_path)
|
||||
encoding = detect_encoding(filename)
|
||||
if encoding is None:
|
||||
content = "[BINARY DATA FILE]"
|
||||
else:
|
||||
try:
|
||||
with open(filename, encoding=encoding) as file:
|
||||
content = file.read()
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
content = "[BINARY DATA FILE]"
|
||||
files_content[relative_path] = content
|
||||
return files_content
|
||||
|
||||
def get_file_contents(input_instances, verbose: bool = False, tmp_dir: str = "/scratch"):
|
||||
orig_dir = os.getcwd()
|
||||
with TemporaryDirectory(dir=tmp_dir if os.path.exists(tmp_dir) else "/tmp") as root_dir:
|
||||
for instance_id, instance in tqdm(
|
||||
input_instances.items(),
|
||||
total=len(input_instances),
|
||||
desc="Getting file contents",
|
||||
):
|
||||
try:
|
||||
with AutoContextManager(instance, root_dir, verbose=verbose) as cm:
|
||||
readmes = cm.get_readme_files()
|
||||
instance["readmes"] = ingest_files(readmes)
|
||||
instance["oracle_file_contents"] = ingest_files(get_oracle_filenames(instance))
|
||||
instance["file_contents"] = ingest_directory_contents(cm.repo_path)
|
||||
assert all([
|
||||
okey in instance["file_contents"]
|
||||
for okey in instance["oracle_file_contents"].keys()
|
||||
])
|
||||
except Exception as e:
|
||||
print(f"Failed on instance {instance_id}", e)
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# if AutoContextManager fails to exit properly future exits will return the wrong directory
|
||||
os.chdir(orig_dir)
|
||||
os.chdir(orig_dir)
|
||||
|
||||
|
||||
import multiprocessing as mp
|
||||
from functools import partial
|
||||
|
||||
def process_single_item(item, args):
|
||||
"""处理单个数据项的函数"""
|
||||
name = "swe-bench"
|
||||
if "lite" in args.dataset_name.lower():
|
||||
name += "-lite"
|
||||
|
||||
queries = [{
|
||||
"_id": item["instance_id"],
|
||||
"text": item["problem_statement"],
|
||||
"metadata": {}
|
||||
}]
|
||||
item_dict = {item["instance_id"]: item}
|
||||
|
||||
output_path = os.path.join(args.output_dir, f"{name}_{item['instance_id']}", "qrels", "test.tsv")
|
||||
if os.path.exists(output_path):
|
||||
return
|
||||
|
||||
try:
|
||||
get_file_contents(item_dict, tmp_dir=args.tmp_dir)
|
||||
|
||||
docs = []
|
||||
for instance_id, instance in item_dict.items():
|
||||
print(f"Instance #{instance_id}: {len(instance['oracle_file_contents'])} oracle / {len(instance['file_contents'])} files")
|
||||
for filename, content in instance["file_contents"].items():
|
||||
docs.append({
|
||||
"_id": f"{instance_id}_{filename}",
|
||||
"title": filename,
|
||||
"text": content,
|
||||
"metadata": {},
|
||||
})
|
||||
|
||||
qrels = []
|
||||
for instance_id, instance in item_dict.items():
|
||||
for filename, content in instance["oracle_file_contents"].items():
|
||||
qrels.append({
|
||||
"query-id": instance_id,
|
||||
"corpus-id": f"{instance_id}_{filename}",
|
||||
"score": 1
|
||||
})
|
||||
|
||||
path = os.path.join(args.output_dir, f"{name}_{instance_id}")
|
||||
os.makedirs(path, exist_ok=True)
|
||||
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
|
||||
|
||||
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
|
||||
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
|
||||
qrels_path = os.path.join(path, "qrels", "test.tsv")
|
||||
save_tsv_dict(qrels, qrels_path, ["query-id", "corpus-id", "score"])
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing item {item['instance_id']}: {str(e)}")
|
||||
|
||||
def main():
|
||||
dataset = datasets.load_dataset(args.dataset_name, cache_dir=args.cache_dir)["test"]
|
||||
if args.num_examples is not None:
|
||||
import random
|
||||
indices = random.sample([i for i in range(len(dataset))], args.num_examples)
|
||||
dataset = dataset.select(indices)
|
||||
print(dataset)
|
||||
|
||||
# Create the process pool
|
||||
num_processes = mp.cpu_count() - 1  # leave one CPU core free
|
||||
pool = mp.Pool(processes=num_processes)
|
||||
|
||||
# Use functools.partial to bind the args parameter
|
||||
process_func = partial(process_single_item, args=args)
|
||||
|
||||
# Process items in parallel with the process pool
|
||||
list(tqdm(
|
||||
pool.imap_unordered(process_func, dataset),
|
||||
total=len(dataset),
|
||||
desc="Processing items"
|
||||
))
|
||||
|
||||
# Close the process pool
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset_name", type=str, default="princeton-nlp/SWE-bench_Lite",
|
||||
choices=["princeton-nlp/SWE-bench", "princeton-nlp/SWE-bench_Lite"])
|
||||
parser.add_argument("--cache_dir", type=str, default="/scratch/zhiruow/data")
|
||||
parser.add_argument("--tmp_dir", type=str, default="/scratch/zhiruow/tmp")
|
||||
parser.add_argument("--output_dir", type=str, default="datasets")
|
||||
parser.add_argument("--num_examples", type=int, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
main()
|
@ -0,0 +1,37 @@
import jsonlines
import csv
import os

from tqdm import tqdm


def load_jsonlines(file):
    with jsonlines.open(file, 'r') as jsonl_f:
        lst = [obj for obj in jsonl_f]
    return lst


def save_file_jsonl(data, fp):
    with jsonlines.open(fp, mode='w') as writer:
        writer.write_all(data)


def save_tsv_dict(data, fp, fields):
    # build dir
    dir_path = os.path.dirname(fp)
    os.makedirs(dir_path, exist_ok=True)

    # writing to csv file
    with open(fp, 'w') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fields, delimiter='\t')
        writer.writeheader()
        writer.writerows(data)


def cost_estimate(path):
    corpus = load_jsonlines(os.path.join(path, "corpus.jsonl"))
    queries = load_jsonlines(os.path.join(path, "queries.jsonl"))
    num_corpus_words = 0
    num_queries_words = 0
    for item in tqdm(corpus):
        num_corpus_words += len(item["text"].split(" "))
    for item in tqdm(queries):
        num_queries_words += len(item["text"].split(" "))
    print(len(corpus))
    print(len(queries))
    print(num_corpus_words)
    print(num_queries_words)
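A small usage sketch of the helpers above; file names and records are arbitrary:

```python
# Write a toy qrels file first so the "toy_dataset" directory tree is created,
# then write a toy queries file next to it.
toy_qrels = [{"query-id": "q1", "corpus-id": "q1_src/module.py", "score": 1}]
save_tsv_dict(toy_qrels, "toy_dataset/qrels/test.tsv", ["query-id", "corpus-id", "score"])
save_file_jsonl([{"_id": "q1", "text": "example query", "metadata": {}}], "toy_dataset/queries.jsonl")
```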
313
research/BGE_Coder/evaluation/coderag_eval/test/main.py
Normal file
@ -0,0 +1,313 @@
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
import pathlib
|
||||
import argparse
|
||||
import numpy as np
|
||||
from time import time
|
||||
from datasets import load_dataset
|
||||
from beir import util, LoggingHandler
|
||||
from beir.retrieval import models
|
||||
from beir.datasets.data_loader import GenericDataLoader
|
||||
from beir.retrieval.evaluation import EvaluateRetrieval
|
||||
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
|
||||
from tqdm import tqdm
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from arguments import CodeRAGEvalArgs, CodeRAGEvalModelArgs
|
||||
from prompts import get_task_def_by_task_name
|
||||
from FlagEmbedding import FlagLLMModel, FlagModel
|
||||
|
||||
|
||||
def get_model(model_args: CodeRAGEvalModelArgs):
|
||||
embedder_name_or_path = model_args.embedder_name_or_path
|
||||
|
||||
if model_args.embedder_model_class == "encoder-only-base":
|
||||
embedder = FlagModel(
|
||||
model_name_or_path=embedder_name_or_path,
|
||||
normalize_embeddings=model_args.normalize_embeddings,
|
||||
pooling_method=model_args.pooling_method,
|
||||
use_fp16=model_args.use_fp16,
|
||||
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
|
||||
query_instruction_format=model_args.query_instruction_format_for_retrieval,
|
||||
devices=model_args.devices,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
cache_dir=model_args.cache_dir,
|
||||
batch_size=model_args.embedder_batch_size,
|
||||
query_max_length=model_args.embedder_query_max_length,
|
||||
passage_max_length=model_args.embedder_passage_max_length,
|
||||
)
|
||||
elif model_args.embedder_model_class == "decoder-only-base":
|
||||
embedder = FlagLLMModel(
|
||||
model_name_or_path=embedder_name_or_path,
|
||||
normalize_embeddings=model_args.normalize_embeddings,
|
||||
pooling_method=model_args.pooling_method,
|
||||
use_fp16=model_args.use_fp16,
|
||||
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
|
||||
query_instruction_format=model_args.query_instruction_format_for_retrieval,
|
||||
devices=model_args.devices,
|
||||
examples_for_task=model_args.examples_for_task,
|
||||
examples_instruction_format=model_args.examples_instruction_format,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
cache_dir=model_args.cache_dir,
|
||||
batch_size=model_args.embedder_batch_size,
|
||||
query_max_length=model_args.embedder_query_max_length,
|
||||
passage_max_length=model_args.embedder_passage_max_length,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid model class: {model_args.embedder_model_class}")
|
||||
embedder.model.config._name_or_path = model_args.embedder_name_or_path
|
||||
|
||||
class CustomFlagModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def encode_queries(self, queries, show_progress_bar, convert_to_tensor, **kwargs):
|
||||
if isinstance(queries, str):
|
||||
queries = [queries]
|
||||
|
||||
if isinstance(queries[0], dict):
|
||||
queries = [(e.get('title') + ' ' + e['text']).strip() for e in queries]
|
||||
|
||||
return self.model.encode_queries(queries, **kwargs)
|
||||
|
||||
def encode_corpus(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
|
||||
if isinstance(corpus, str):
|
||||
corpus = [corpus]
|
||||
|
||||
if isinstance(corpus[0], dict):
|
||||
corpus = [(e.get('title') + ' ' + e['text']).strip() for e in corpus]
|
||||
|
||||
return self.model.encode_corpus(corpus, **kwargs)
|
||||
|
||||
def encode(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
|
||||
if isinstance(corpus, str):
|
||||
corpus = [corpus]
|
||||
|
||||
if isinstance(corpus[0], dict):
|
||||
corpus = [(e.get('title') + ' ' + e['text']).strip() for e in corpus]
|
||||
|
||||
return self.model.encode(corpus, **kwargs)
|
||||
|
||||
return CustomFlagModel(embedder)
|
||||
|
||||
#### Just some code to print debug information to stdout
|
||||
logging.basicConfig(format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
level=logging.INFO,
|
||||
handlers=[LoggingHandler()])
|
||||
|
||||
|
||||
def get_top_docs(results: dict, corpus: dict, task_id: str, topk: int = 10) -> list[str]:
|
||||
if task_id not in results: return []
|
||||
doc_scores = results[task_id]
|
||||
doc_scores_sorted = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
|
||||
doc_scores_sorted = doc_scores_sorted[:topk]
|
||||
doc_code_snippets = [corpus[code_id] for code_id, score in doc_scores_sorted]
|
||||
return doc_code_snippets
|
||||
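To illustrate, with a toy results/corpus pair (scores and ids are made up) the helper returns the corpus entries in descending score order:

```python
# Toy retrieval scores and corpus entries; purely illustrative.
toy_results = {"task-1": {"doc-a": 0.9, "doc-b": 0.4, "doc-c": 0.7}}
toy_corpus = {"doc-a": {"title": "a.py", "text": "..."},
              "doc-b": {"title": "b.py", "text": "..."},
              "doc-c": {"title": "c.py", "text": "..."}}
top_docs = get_top_docs(toy_results, toy_corpus, "task-1", topk=2)
# expected order: the entries for "doc-a", then "doc-c"
```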
|
||||
|
||||
def main(
|
||||
eval_args: CodeRAGEvalArgs,
|
||||
model_args: CodeRAGEvalModelArgs
|
||||
):
|
||||
args = eval_args
|
||||
|
||||
embedder = get_model(model_args)
|
||||
model = DRES(
|
||||
embedder,
|
||||
batch_size=args.batch_size,
|
||||
corpus_chunk_size=512 * 9999
|
||||
)
|
||||
retriever = EvaluateRetrieval(model, score_function="dot")
|
||||
|
||||
if args.dataset.startswith("swe-bench") or args.dataset.startswith("repoeval"):
|
||||
all_eval_results = []
|
||||
|
||||
if args.dataset.startswith("swe-bench"):
|
||||
swebench = load_dataset("princeton-nlp/SWE-bench_Lite")["test"]
|
||||
all_top_docs = [[] for _ in swebench]
|
||||
|
||||
instance_list = [i for i in os.listdir("datasets") if i.startswith(f"{args.dataset}_")]
|
||||
instance_list_filtered = []
|
||||
|
||||
for ins_dir in tqdm(instance_list):
|
||||
logging.info("Instance Repo: {}".format(ins_dir))
|
||||
# load data and perform retrieval
|
||||
corpus, queries, qrels = GenericDataLoader(
|
||||
data_folder=os.path.join("datasets", ins_dir)
|
||||
).load(split="test")
|
||||
logging.info(f"Instance #{ins_dir}: #{len(corpus)} corpus, #{len(queries)} queries")
|
||||
|
||||
start_time = time()
|
||||
if len(queries) == 1:
|
||||
queries.update({"dummy": "dummy"})
|
||||
results = retriever.retrieve(corpus, queries)
|
||||
if "dummy" in queries:
|
||||
queries.pop("dummy")
|
||||
results.pop("dummy")
|
||||
end_time = time()
|
||||
logging.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
|
||||
|
||||
# get topk retrieved docs
|
||||
if args.dataset.startswith("swe-bench"):
|
||||
indices = [i for i, ex in enumerate(swebench) if ex["instance_id"] in queries]
|
||||
for index in indices:
|
||||
instance_id = swebench[index]["instance_id"]
|
||||
all_top_docs[index] = get_top_docs(results, corpus, instance_id)
|
||||
elif args.dataset.startswith("repoeval"):
|
||||
args.dataset_path = "output/repoeval/datasets/function_level_completion_2k_context_codex.test.clean.jsonl"
|
||||
tasks = [json.loads(line.strip()) for line in open(args.dataset_path, 'r')]
|
||||
prompts, references, docs, metadatas = [], [], [], []
|
||||
for task in tasks:
|
||||
if task["metadata"]["task_id"] not in queries: continue
|
||||
prompts.append(task["prompt"]) # save full prompt
|
||||
references.append(task["metadata"]["ground_truth"])
|
||||
docs.append(get_top_docs(
|
||||
results=results, corpus=corpus, task_id=task["metadata"]["task_id"],
|
||||
))
|
||||
metadatas.append(task["metadata"])
|
||||
assert len(prompts) == len(references) == len(docs)
|
||||
dataset = [
|
||||
{"prompt": p, "reference": r, "docs": d, "metadata": m}
|
||||
for p, r, d, m in zip(prompts, references, docs, metadatas)
|
||||
]
|
||||
with open(args.results_file, "a") as fout:
|
||||
for curr in dataset:
|
||||
fout.write(json.dumps(curr) + "\n")
|
||||
else:
|
||||
raise ValueError(f"`dataset` should starts with either 'swe-bench' or 'repoeval'.")
|
||||
|
||||
# evaluate retrieval results
|
||||
if len(qrels) == 0:
|
||||
logging.info("No qrels found for this dataset.")
|
||||
return
|
||||
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
|
||||
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
|
||||
mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
|
||||
eval_results = {
|
||||
"ndcg": ndcg, "mrr": mrr,
|
||||
"recall": recall, "precision": precision,
|
||||
"time": end_time - start_time
|
||||
}
|
||||
logging.info(f"Instance #{ins_dir}: {eval_results}")
|
||||
all_eval_results.append(eval_results)
|
||||
|
||||
with open(args.output_file + "_all", "w") as f:
|
||||
json.dump(all_eval_results, f)
|
||||
|
||||
if args.dataset.startswith("swe-bench"):
|
||||
swebench = swebench.add_column("docs", all_top_docs)
|
||||
swebench.to_json(args.results_file)
|
||||
|
||||
avg_eval_results = {}
|
||||
for k, v_dict in all_eval_results[0].items():
|
||||
if isinstance(v_dict, dict):
|
||||
avg_v_dict = {}
|
||||
for vk, vv in v_dict.items():
|
||||
avg_vv = sum([e[k][vk] for e in all_eval_results]) / len(all_eval_results)
|
||||
avg_v_dict[vk] = avg_vv
|
||||
avg_eval_results.update(avg_v_dict)
|
||||
elif isinstance(v_dict, float):
|
||||
avg_v = sum([e[k] for e in all_eval_results]) / len(all_eval_results)
|
||||
avg_eval_results[k] = avg_v
|
||||
else:
|
||||
raise ValueError
|
||||
print("Average Eval Results: ", avg_eval_results)
|
||||
with open(args.output_file, "w") as f:
|
||||
json.dump(avg_eval_results, f)
|
||||
else:
|
||||
dataset = args.dataset
|
||||
corpus, queries, qrels = GenericDataLoader(data_folder=os.path.join("datasets", args.dataset)).load(
|
||||
split="test")
|
||||
#### Retrieve dense results (format of results is identical to qrels)
|
||||
start_time = time()
|
||||
results = retriever.retrieve(corpus, queries)
|
||||
end_time = time()
|
||||
print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
|
||||
|
||||
if args.dataset in ["humaneval", "mbpp", "apps"]:
|
||||
if args.dataset == "humaneval":
|
||||
ds = load_dataset("openai_humaneval")
|
||||
id_key = "task_id"
|
||||
elif args.dataset == "mbpp":
|
||||
ds = load_dataset("mbpp")
|
||||
id_key = "task_id"
|
||||
elif args.dataset == "apps":
|
||||
ds = load_dataset("codeparrot/apps")
|
||||
id_key = "problem_id"
|
||||
all_top_docs = []
|
||||
for task_id in ds["test"][id_key]:
|
||||
all_top_docs.append(get_top_docs(results, corpus, f"{task_id}_doc"))
|
||||
ds["test"] = ds["test"].add_column("docs", all_top_docs)
|
||||
ds["test"].to_json(args.results_file) # this outputs to arrow format and read as .jsonl
|
||||
elif args.dataset.startswith("odex"):
|
||||
lang = args.dataset.split("_")[-1]
|
||||
ds = load_dataset("neulab/odex", lang, trust_remote_code=True)
|
||||
all_top_docs = []
|
||||
for idx, task_id in enumerate(ds["test"]["task_id"]):
|
||||
all_top_docs.append(get_top_docs(results, corpus, f"{idx}_{task_id}"))
|
||||
ds["test"] = ds["test"].add_column("docs", all_top_docs)
|
||||
ds["test"].to_json(args.results_file) # this outputs to arrow format and read as .jsonl
|
||||
elif args.dataset.startswith("ds1000"):
|
||||
_, key, mode = args.dataset.split("_")
|
||||
key = key.capitalize()
|
||||
mode = mode.capitalize()
|
||||
from create.ds1000 import get_dataset
|
||||
source_dir = pathlib.Path(__file__).parent / "ds"
|
||||
data = get_dataset(source_dir, mode=mode, key=key)
|
||||
all_docs = []
|
||||
example_ids = []
|
||||
for item in data:
|
||||
example = item.data
|
||||
example_id = f"{example['lib']}_{example['perturbation_origin_id']}"
|
||||
all_docs.append(get_top_docs(results, corpus, example_id))
|
||||
example_ids.append(example_id)
|
||||
assert len(all_docs) == len(
|
||||
example_ids), f"length of all_docs should be {len(example_ids)}, now is {len(all_docs)}"
|
||||
with open(args.results_file, "w+") as fout:
|
||||
for idx, all_doc in enumerate(all_docs):
|
||||
fout.write(json.dumps({"example_id": example_id,
|
||||
"docs": all_doc}) + "\n")
|
||||
else:
|
||||
with open(args.results_file, 'w+') as fw:
|
||||
for curr in results:
|
||||
fw.write(json.dumps({curr: results[curr]}) + "\n")
|
||||
|
||||
#### Evaluate your retrieval using NDCG@k, MAP@K ...
|
||||
if len(qrels) == 0:
|
||||
logging.info("No qrels found for this dataset.")
|
||||
return
|
||||
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
|
||||
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
|
||||
|
||||
mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
|
||||
recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
|
||||
hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
|
||||
|
||||
all_results = {"ndcg": ndcg, "mrr": mrr, "recall": recall, "precision": precision,
|
||||
"time": end_time - start_time}
|
||||
with open(args.output_file, "w") as f:
|
||||
json.dump(all_results, f)
|
||||
#### Print top-k documents retrieved ####
|
||||
top_k = 3
|
||||
|
||||
query_id, ranking_scores = random.choice(list(results.items()))
|
||||
scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
|
||||
logging.info("Query : %s\n" % queries[query_id])
|
||||
|
||||
for rank in range(top_k):
|
||||
doc_id = scores_sorted[rank][0]
|
||||
# Format: Rank x: ID [Title] Body
|
||||
logging.info(
|
||||
"Rank %d: %s [%s] - %s\n" % (rank + 1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((
|
||||
CodeRAGEvalArgs,
|
||||
CodeRAGEvalModelArgs
|
||||
))
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
main(eval_args, model_args)
|
18
research/BGE_Coder/evaluation/coderag_eval/test/prompts.py
Normal file
@ -0,0 +1,18 @@
from typing import Dict


def get_task_def_by_task_name(task_name: str) -> str:
    task_name_to_instruct: Dict[str, str] = {
        'humaneval': 'Given a question that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
        'mbpp': 'Given a textual explanation of code functionality, retrieve the corresponding code implementation.',
        'ds1000_all_completion': 'Given a question that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
        'odex_en': 'Given a question, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
        'odex_es': 'Given a question, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
        'odex_ja': 'Given a question, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
        'odex_ru': 'Given a question, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
        'repoeval': 'Given a code snippet and a new function name, retrieve the implementation of the function.',
        # 'repoeval': 'Given a piece of code segment, retrieve the code segment that is the latter part of the code.',
        'swe-bench-lite': 'Given a code snippet containing a bug and a natural language description of the bug or error, retrieve code snippets that demonstrate solutions or fixes for similar bugs or errors (the desired documents).'
    }

    return task_name_to_instruct[task_name]
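Usage is a plain dictionary lookup, e.g.:

```python
instruction = get_task_def_by_task_name("mbpp")
print(instruction)  # the MBPP retrieval instruction defined above
```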
69
research/BGE_Coder/evaluation/coir_eval/arguments.py
Normal file
@ -0,0 +1,69 @@
|
||||
from typing import List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalModelArgs as COIREvalModelArgs,
|
||||
)
|
||||
|
||||
|
||||
def coir_tasks():
|
||||
return [
|
||||
"apps",
|
||||
"codefeedback-mt",
|
||||
"codefeedback-st",
|
||||
"CodeSearchNet-ccr-go",
|
||||
"CodeSearchNet-ccr-java",
|
||||
"CodeSearchNet-ccr-javascript",
|
||||
"CodeSearchNet-ccr-php",
|
||||
"CodeSearchNet-ccr-python",
|
||||
"CodeSearchNet-ccr-ruby",
|
||||
"CodeSearchNet-go",
|
||||
"CodeSearchNet-java",
|
||||
"CodeSearchNet-javascript",
|
||||
"CodeSearchNet-php",
|
||||
"CodeSearchNet-python",
|
||||
"CodeSearchNet-ruby",
|
||||
"codetrans-contest",
|
||||
"codetrans-dl",
|
||||
"cosqa",
|
||||
"stackoverflow-qa",
|
||||
"synthetic-text2sql"
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class COIREvalArgs:
|
||||
output_dir: str = field(
|
||||
default="./results", metadata={"help": "Path to save results."}
|
||||
)
|
||||
tasks: List[str] = field(
|
||||
default_factory=coir_tasks,
|
||||
metadata={
|
||||
"help": "Tasks to evaluate. Default: None. Available tasks: ['apps', 'codefeedback-mt', 'codefeedback-st', 'CodeSearchNet-ccr-go', 'CodeSearchNet-ccr-java', 'CodeSearchNet-ccr-javascript', 'CodeSearchNet-ccr-php', 'CodeSearchNet-ccr-python', 'CodeSearchNet-ccr-ruby', 'CodeSearchNet-go', 'CodeSearchNet-java', 'CodeSearchNet-javascript', 'CodeSearchNet-php', 'CodeSearchNet-python', 'CodeSearchNet-ruby', 'codetrans-contest', 'codetrans-dl', 'cosqa', 'stackoverflow-qa', 'synthetic-text2sql']",
|
||||
"choices": [
|
||||
"apps",
|
||||
"codefeedback-mt",
|
||||
"codefeedback-st",
|
||||
"CodeSearchNet-ccr-go",
|
||||
"CodeSearchNet-ccr-java",
|
||||
"CodeSearchNet-ccr-javascript",
|
||||
"CodeSearchNet-ccr-php",
|
||||
"CodeSearchNet-ccr-python",
|
||||
"CodeSearchNet-ccr-ruby",
|
||||
"CodeSearchNet-go",
|
||||
"CodeSearchNet-java",
|
||||
"CodeSearchNet-javascript",
|
||||
"CodeSearchNet-php",
|
||||
"CodeSearchNet-python",
|
||||
"CodeSearchNet-ruby",
|
||||
"codetrans-contest",
|
||||
"codetrans-dl",
|
||||
"cosqa",
|
||||
"stackoverflow-qa",
|
||||
"synthetic-text2sql"
|
||||
]
|
||||
}
|
||||
)
|
||||
use_special_instructions: bool = field(
|
||||
default=False, metadata={"help": "Whether to use specific instructions in `prompts.py` for evaluation. Default: False"}
|
||||
)
|
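For reference, a minimal sketch of how these dataclasses are typically parsed; this mirrors the HfArgumentParser entry point used in main.py, and the command-line values are illustrative:

```python
from transformers import HfArgumentParser

# COIREvalModelArgs is the AbsEvalModelArgs alias imported above.
parser = HfArgumentParser((COIREvalArgs, COIREvalModelArgs))
eval_args, model_args = parser.parse_args_into_dataclasses(
    ["--embedder_name_or_path", "BAAI/bge-code-v1", "--tasks", "cosqa", "apps"]
)
print(eval_args.tasks)  # ['cosqa', 'apps']
```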
16
research/BGE_Coder/evaluation/coir_eval/eval.sh
Normal file
@ -0,0 +1,16 @@
output_dir=result

python main.py \
    --output_dir ${output_dir} \
    --use_special_instructions True \
    --embedder_name_or_path BAAI/bge-code-v1 \
    --embedder_model_class decoder-only-base \
    --query_instruction_format_for_retrieval '<instruct>{}\n<query>{}' \
    --embedder_query_max_length 2048 \
    --embedder_passage_max_length 2048 \
    --trust_remote_code True \
    --pooling_method last_token \
    --embedder_batch_size 64 \
    --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
    --tasks apps codetrans-contest codetrans-dl cosqa synthetic-text2sql stackoverflow-qa codefeedback-mt codefeedback-st CodeSearchNet-ccr-go CodeSearchNet-ccr-java CodeSearchNet-ccr-javascript CodeSearchNet-ccr-php CodeSearchNet-ccr-python CodeSearchNet-ccr-ruby CodeSearchNet-go CodeSearchNet-java CodeSearchNet-javascript CodeSearchNet-php CodeSearchNet-python CodeSearchNet-ruby \
    --cache_dir ./cache
167
research/BGE_Coder/evaluation/coir_eval/main.py
Normal file
@ -0,0 +1,167 @@
|
||||
import os
|
||||
import json
|
||||
import coir
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from arguments import COIREvalArgs, COIREvalModelArgs
|
||||
from prompts import get_task_def_by_task_name
|
||||
from FlagEmbedding import FlagLLMModel, FlagModel
|
||||
|
||||
|
||||
def get_model(model_args: COIREvalModelArgs):
|
||||
embedder_name_or_path = model_args.embedder_name_or_path
|
||||
|
||||
if model_args.embedder_model_class == "encoder-only-base":
|
||||
embedder = FlagModel(
|
||||
model_name_or_path=embedder_name_or_path,
|
||||
normalize_embeddings=model_args.normalize_embeddings,
|
||||
pooling_method=model_args.pooling_method,
|
||||
use_fp16=model_args.use_fp16,
|
||||
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
|
||||
query_instruction_format=model_args.query_instruction_format_for_retrieval,
|
||||
devices=model_args.devices,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
cache_dir=model_args.cache_dir,
|
||||
batch_size=model_args.embedder_batch_size,
|
||||
query_max_length=model_args.embedder_query_max_length,
|
||||
passage_max_length=model_args.embedder_passage_max_length,
|
||||
)
|
||||
elif model_args.embedder_model_class == "decoder-only-base":
|
||||
embedder = FlagLLMModel(
|
||||
model_name_or_path=embedder_name_or_path,
|
||||
normalize_embeddings=model_args.normalize_embeddings,
|
||||
pooling_method=model_args.pooling_method,
|
||||
use_fp16=model_args.use_fp16,
|
||||
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
|
||||
query_instruction_format=model_args.query_instruction_format_for_retrieval,
|
||||
devices=model_args.devices,
|
||||
examples_for_task=model_args.examples_for_task,
|
||||
examples_instruction_format=model_args.examples_instruction_format,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
cache_dir=model_args.cache_dir,
|
||||
batch_size=model_args.embedder_batch_size,
|
||||
query_max_length=model_args.embedder_query_max_length,
|
||||
passage_max_length=model_args.embedder_passage_max_length,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid model class: {model_args.embedder_model_class}")
|
||||
embedder.model.config._name_or_path = model_args.embedder_name_or_path
|
||||
|
||||
class CustomFlagModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def encode_queries(self, queries, show_progress_bar, convert_to_tensor, **kwargs):
|
||||
if isinstance(queries, str):
|
||||
queries = [queries]
|
||||
|
||||
if isinstance(queries[0], dict):
|
||||
queries = [(e.get('title') + ' ' + e['text']).strip() for e in queries]
|
||||
|
||||
return self.model.encode_queries(queries, **kwargs)
|
||||
|
||||
def encode_corpus(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
|
||||
if isinstance(corpus, str):
|
||||
corpus = [corpus]
|
||||
|
||||
if isinstance(corpus[0], dict):
|
||||
corpus = [(e.get('title') + ' ' + e['text']).strip() for e in corpus]
|
||||
|
||||
return self.model.encode_corpus(corpus, **kwargs)
|
||||
|
||||
def encode(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
|
||||
if isinstance(corpus, str):
|
||||
corpus = [corpus]
|
||||
|
||||
if isinstance(corpus[0], dict):
|
||||
corpus = [(e.get('title') + ' ' + e['text']).strip() for e in corpus]
|
||||
|
||||
return self.model.encode(corpus, **kwargs)
|
||||
|
||||
return CustomFlagModel(embedder)
|
||||
|
||||
|
||||
def main(
|
||||
eval_args: COIREvalArgs,
|
||||
model_args: COIREvalModelArgs
|
||||
):
|
||||
model = get_model(model_args)
|
||||
|
||||
output_folder = os.path.join(eval_args.output_dir, os.path.basename(model.model.model.config._name_or_path))
|
||||
|
||||
all_task = eval_args.tasks
|
||||
if not isinstance(all_task, list):
|
||||
all_task = [all_task]
|
||||
|
||||
all_results = {}
|
||||
for task_name in all_task:
|
||||
save_path = os.path.join(output_folder, f"{task_name}.json")
|
||||
if os.path.exists(save_path):
|
||||
with open(save_path, "r", encoding="utf-8") as f:
|
||||
results = json.load(f)
|
||||
all_results[task_name] = results['metrics']
|
||||
continue
|
||||
|
||||
tmp_task = coir.get_tasks(tasks=[task_name])
|
||||
evaluation = coir.COIR(tasks=tmp_task,
|
||||
batch_size=model_args.embedder_batch_size)
|
||||
|
||||
model.model.stop_self_pool()
|
||||
|
||||
if eval_args.use_special_instructions:
|
||||
model.model.query_instruction_for_retrieval = get_task_def_by_task_name(task_name)
|
||||
|
||||
results = evaluation.run(model, output_folder=output_folder)
|
||||
all_results[task_name] = results[task_name]
|
||||
|
||||
csn_result = 0
|
||||
csn_num = 0
|
||||
csn_ccr_result = 0
|
||||
csn_ccr_num = 0
|
||||
pop_keys = []
|
||||
all_result = 0
|
||||
all_num = 0
|
||||
for k in all_results.keys():
|
||||
if 'CodeSearchNet-ccr' in k:
|
||||
csn_ccr_result += all_results[k]['NDCG']['NDCG@10']
|
||||
csn_ccr_num += 1
|
||||
pop_keys.append(k)
|
||||
elif 'CodeSearchNet' in k:
|
||||
csn_result += all_results[k]['NDCG']['NDCG@10']
|
||||
csn_num += 1
|
||||
pop_keys.append(k)
|
||||
else:
|
||||
all_result += all_results[k]['NDCG']['NDCG@10']
|
||||
all_num += 1
|
||||
if csn_num > 0:
|
||||
print('Using CodeSearchNet')
|
||||
all_result += csn_result / csn_num
|
||||
all_num += 1
|
||||
if csn_ccr_num > 0:
|
||||
print('Using CodeSearchNet-ccr')
|
||||
all_result += csn_ccr_result / csn_ccr_num
|
||||
all_num += 1
|
||||
new_results = {}
|
||||
for k in all_results:
|
||||
if k in pop_keys:
|
||||
continue
|
||||
new_results[k] = all_results[k]['NDCG']['NDCG@10']
|
||||
if csn_num > 0:
|
||||
new_results['CodeSearchNet'] = csn_result / csn_num
|
||||
if csn_ccr_num > 0:
|
||||
new_results['CodeSearchNet_CCR'] = csn_ccr_result / csn_ccr_num
|
||||
new_results['all'] = all_result / all_num
|
||||
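The aggregation above macro-averages the CodeSearchNet and CodeSearchNet-CCR sub-language scores into one entry each before taking the overall mean. A toy example of that arithmetic, with all numbers made up:

```python
# Toy NDCG@10 values; illustrative only.
csn = (0.80 + 0.60) / 2        # CodeSearchNet languages -> 0.70
csn_ccr = (0.70 + 0.90) / 2    # CodeSearchNet-CCR languages -> 0.80
overall = (0.90 + 0.50 + csn + csn_ccr) / 4  # apps, cosqa, CSN, CSN-CCR -> 0.725
```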
|
||||
print(new_results)
|
||||
|
||||
with open(os.path.join(output_folder, 'OVERALL-results.json'), 'w') as f:
|
||||
json.dump(new_results, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((
|
||||
COIREvalArgs,
|
||||
COIREvalModelArgs
|
||||
))
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
main(eval_args, model_args)
|
38
research/BGE_Coder/evaluation/coir_eval/prompts.py
Normal file
@ -0,0 +1,38 @@
from typing import Dict


def get_task_def_by_task_name(task_name: str) -> str:
    task_name_to_instruct: Dict[str, str] = {
        # Text-to-Code Retrieval
        ## Code Contest Retrieval
        'apps': 'Given a code contest problem description, retrieve relevant code that can help solve the problem.',
        ## Web Query to Code Retrieval
        'cosqa': 'Given a web search query, retrieve relevant code that can help answer the query.',
        ## Text-to-SQL Retrieval
        'synthetic-text2sql': 'Given a question in text, retrieve SQL queries that are appropriate responses to the question.',

        # Code-to-Text Retrieval
        ## Code Summary Retrieval
        'CodeSearchNet-': 'Given a piece of code, retrieve the document string that summarizes the code.',

        # Code-to-Code Retrieval
        ## Code Context Retrieval
        'CodeSearchNet-ccr-': 'Given a piece of code segment, retrieve the code segment that is the latter part of the code.',
        ## Similar Code Retrieval
        'codetrans-dl': 'Given a piece of code, retrieve code that is semantically equivalent to the input code.',
        'codetrans-contest': 'Given a piece of Python code, retrieve C++ code that is semantically equivalent to the input code.',

        # Hybrid Code Retrieval
        ## Single-turn Code QA
        'stackoverflow-qa': 'Given a question that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
        'codefeedback-st': 'Given a question that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
        ## Multi-turn Code QA
        'codefeedback-mt': 'Given a multi-turn conversation history that consists of a mix of text and code snippets, retrieve relevant answers that also consist of a mix of text and code snippets, and can help answer the question.',
    }

    special_task_names = ['CodeSearchNet-ccr-', 'CodeSearchNet-']
    for special_task_name in special_task_names:
        if special_task_name in task_name:
            return task_name_to_instruct[special_task_name]

    return task_name_to_instruct[task_name]
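Note that the CodeSearchNet entries are resolved by substring match on the 'CodeSearchNet-ccr-' and 'CodeSearchNet-' keys (the CCR key is checked first), so every sub-language task maps to the appropriate instruction:

```python
# Both resolve through the shared keys above rather than exact task names.
print(get_task_def_by_task_name("CodeSearchNet-python"))   # code summary instruction
print(get_task_def_by_task_name("CodeSearchNet-ccr-go"))   # code context instruction
```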