2025-05-28 14:12:16 +08:00

77 lines
2.6 KiB
Python

import os
import re
import random
import argparse
import datasets
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset
from create.utils import save_tsv_dict, save_file_jsonl
def document2code(data, split="test"):
    """Build retrieval queries, corpus documents and relevance labels.

    The corpus is the ``train`` split of the ``docs`` configuration of
    ``neulab/docprompting-conala``; each corpus ``_id`` is the row position
    rendered as a string. Queries come from ``data[split]`` and are paired,
    by position, with the rows of the JSON file at ``args.canonical_file``.

    NOTE: this function reads the module-level ``args`` namespace, which is
    only created when the file runs as a script.

    Args:
        data: Mapping from split name to an iterable of items exposing the
            ``"intent"`` and ``"task_id"`` keys.
        split: Name of the split in ``data`` to convert.

    Returns:
        A ``(queries, docs, qrels)`` tuple of lists of dicts:
        queries have ``_id``/``text``/``metadata``, docs have
        ``_id``/``title``/``text``/``metadata`` and qrels have
        ``query-id``/``corpus-id``/``score``.

    Raises:
        ValueError: If a canonical document title is not a ``doc_id`` of the
            corpus.
    """
    data = data[split]
    queries, docs, qrels = [], [], []

    # Build the doc corpus and, in the same pass, a doc_id -> corpus-id
    # lookup. The lookup replaces a per-document linear scan of the whole
    # ``doc_id`` column that used to run inside the query loop.
    code_docs = load_dataset("neulab/docprompting-conala", "docs")["train"]
    corpus_id_by_doc_id = {}
    for position, row in enumerate(code_docs):
        doc_id = row["doc_id"]
        corpus_id = str(position)
        docs.append({
            "_id": corpus_id,
            "title": doc_id,
            "text": row["doc_content"],
            "metadata": {},
        })
        # Keep the first occurrence so duplicated doc_ids resolve exactly
        # like ``list.index`` did.
        corpus_id_by_doc_id.setdefault(doc_id, corpus_id)

    # Load the canonical (ground-truth) docs; row ``idx`` is used for the
    # ``idx``-th item of ``data``.
    odex = load_dataset("json", data_files={"test": args.canonical_file})["test"]

    # Collect queries and the query-doc matching.
    for idx, item in enumerate(tqdm(data)):
        query_id = f"{idx}_{item['task_id']}"
        queries.append({"_id": query_id, "text": item["intent"], "metadata": {}})

        for canonical_doc in odex[idx]["docs"]:
            doc_id = canonical_doc["title"]
            try:
                corpus_id = corpus_id_by_doc_id[doc_id]
            except KeyError:
                # Same exception type as the ``list.index`` lookup raised.
                raise ValueError(
                    f"canonical doc {doc_id!r} for query {query_id!r} "
                    "is not in the documentation corpus"
                ) from None
            qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})

    return queries, docs, qrels
def main():
    """Convert the dataset named by ``args.dataset_name`` into retrieval files.

    A dataset name containing ``_`` is split on it: the first piece is the
    dataset to load and the second piece is the language configuration;
    otherwise the language is ``'en'``. Every ``'en'`` in
    ``args.output_name`` is replaced by that language to form the output
    folder under ``args.output_dir``.

    The per-split qrels are handed to ``save_tsv_dict`` (target
    ``qrels/<split>.tsv``), and the accumulated queries and docs are handed
    to ``save_file_jsonl`` (targets ``queries.jsonl`` and ``corpus.jsonl``).
    Reads the module-level ``args`` namespace.
    """
    name_parts = args.dataset_name.split('_')
    if len(name_parts) > 1:
        dataset_name, language = name_parts[0], name_parts[1]
    else:
        # english version by default
        dataset_name, language = args.dataset_name, 'en'
    dataset = datasets.load_dataset(dataset_name, language)

    out_root = os.path.join(args.output_dir, args.output_name.replace('en', language))
    qrels_dir = os.path.join(out_root, "qrels")
    os.makedirs(out_root, exist_ok=True)
    os.makedirs(qrels_dir, exist_ok=True)

    all_docs, all_queries = [], []
    for split in ("test",):
        split_queries, split_docs, split_qrels = document2code(dataset, split)
        all_docs.extend(split_docs)
        all_queries.extend(split_queries)
        # Relevance labels are stored per split; queries and corpus are
        # stored once, after the loop.
        qrels_file = os.path.join(qrels_dir, "{}.tsv".format(split))
        save_tsv_dict(split_qrels, qrels_file, ["query-id", "corpus-id", "score"])

    save_file_jsonl(all_queries, os.path.join(out_root, "queries.jsonl"))
    save_file_jsonl(all_docs, os.path.join(out_root, "corpus.jsonl"))
if __name__ == "__main__":
    # Command-line flags and their defaults, registered in this order.
    cli_options = (
        ("--dataset_name", "neulab/odex"),
        ("--output_name", "odex_en"),
        ("--canonical_file", "datasets/canonical/odex_docs.json"),
        ("--output_dir", "datasets"),
    )
    parser = argparse.ArgumentParser()
    for flag, default_value in cli_options:
        parser.add_argument(flag, type=str, default=default_value)
    # `args` must stay a module-level name: main() and document2code()
    # read it as a global instead of receiving it as a parameter.
    args = parser.parse_args()
    main()