# Mirror of https://github.com/FlagOpen/FlagEmbedding.git
# Synced 2025-07-03 15:11:31 +00:00 (79 lines, 2.9 KiB, Python)
import argparse
import os

import datasets
from tqdm import tqdm

from create.utils import save_tsv_dict, save_file_jsonl
def get_function_name(code: str) -> str:
    """Return the name of the first function defined in *code*.

    Scans *code* line by line for the first line whose stripped text
    starts with ``def `` and returns the identifier between ``def`` and
    the opening parenthesis.

    Args:
        code: A Python source snippet.

    Returns:
        The first function name found, or ``""`` when the snippet
        contains no ``def`` line.  (The original fell through its loop
        without a match and mis-parsed the snippet's last line.)
    """
    for line in code.split('\n'):
        stripped = line.lstrip()
        if stripped.startswith("def "):
            # Text after "def " and before "(" is the function name.
            return stripped[4:].split('(')[0].strip()
    return ""
def document2code(data, split="test"):
    """Build BEIR-style (queries, corpus, qrels) lists from one dataset split.

    Args:
        data: A ``datasets.DatasetDict``-like mapping of split name to rows;
            each row must provide ``task_id``, ``text``, and ``code``.
        split: Name of the split to convert (default ``"test"``).

    Returns:
        A tuple ``(queries, docs, qrels)`` where queries/docs are lists of
        BEIR record dicts and qrels links each query to its code doc with
        score 1.
    """
    queries, docs, qrels = [], [], []

    for item in tqdm(data[split]):
        doc_id = "{task_id}_doc".format_map(item)
        code_id = "{task_id}_code".format_map(item)
        # Prefix the code with its natural-language description as a comment.
        annotated_code = "# " + item["text"] + '\n' + item["code"]

        queries.append({"_id": doc_id, "text": item["text"], "metadata": {}})
        docs.append({
            "_id": code_id,
            "title": get_function_name(item["code"]),
            "text": annotated_code,
            "metadata": {},
        })
        qrels.append({"query-id": doc_id, "corpus-id": code_id, "score": 1})

    return queries, docs, qrels
def main():
    """Convert the configured dataset into BEIR retrieval files on disk.

    Reads options from the module-level ``args`` namespace.  For every
    requested split it writes a qrels TSV under ``<output>/qrels/`` and
    accumulates query/corpus records, which are written once at the end
    as ``queries.jsonl`` and ``corpus.jsonl``.  When processing the test
    split, it also materializes a canonical-solutions JSON file if one
    does not exist yet.
    """
    dataset = datasets.load_dataset(args.dataset_name)

    out_path = os.path.join(args.output_dir, args.output_name)
    # Creating the nested qrels dir also creates out_path itself.
    os.makedirs(out_path, exist_ok=True)
    os.makedirs(os.path.join(out_path, "qrels"), exist_ok=True)

    all_docs, all_queries = [], []
    for split in args.splits:
        split_queries, split_docs, split_qrels = document2code(dataset, split)
        all_docs.extend(split_docs)
        all_queries.extend(split_queries)

        save_tsv_dict(
            split_qrels,
            os.path.join(out_path, "qrels", f"{split}.tsv"),
            ["query-id", "corpus-id", "score"],
        )

        # Create the canonical file for the test split if not existent yet.
        if split == "test" and not os.path.exists(args.canonical_file):
            canonical_solutions = [
                [{"text": d["text"], "title": d["title"]}] for d in split_docs
            ]
            canonical_dataset = dataset["test"].add_column(
                "docs", canonical_solutions
            )
            canonical_dataset.to_json(args.canonical_file)

    save_file_jsonl(all_queries, os.path.join(out_path, "queries.jsonl"))
    save_file_jsonl(all_docs, os.path.join(out_path, "corpus.jsonl"))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert MBPP into BEIR-style retrieval files."
    )
    parser.add_argument("--dataset_name", type=str,
                        default="google-research-datasets/mbpp")
    # Alternative hub copy: "code-rag-bench/mbpp"
    # Fix: the original declared --splits as a bare `type=str` with a list
    # default, so any value passed on the CLI arrived as ONE string and
    # main() then iterated it character by character.  nargs="+" accepts
    # one or more split names, each validated against `choices`; the
    # default list is unchanged, so no-flag invocations behave as before.
    parser.add_argument("--splits", type=str, nargs="+",
                        default=["train", "validation", "test"],
                        choices=["train", "validation", "test", "prompt"])
    parser.add_argument("--output_name", type=str, default="mbpp")
    parser.add_argument("--output_dir", type=str, default="datasets")
    parser.add_argument("--canonical_file", type=str,
                        default="datasets/canonical/mbpp_solutions.json")

    args = parser.parse_args()
    main()