mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-07-24 17:40:42 +00:00
77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
import os
|
|
import re
|
|
import random
|
|
import argparse
|
|
import datasets
|
|
from tqdm import tqdm
|
|
from collections import Counter
|
|
from datasets import load_dataset
|
|
from create.utils import save_tsv_dict, save_file_jsonl
|
|
|
|
|
|
def document2code(data, split="test"):
|
|
data = data[split]
|
|
queries, docs, qrels = [], [], []
|
|
|
|
# build doc corpus
|
|
code_docs = load_dataset("neulab/docprompting-conala", "docs")["train"]
|
|
for i in range(len(code_docs)):
|
|
docs.append({
|
|
"_id": str(i),
|
|
"title": code_docs[i]["doc_id"],
|
|
"text": code_docs[i]["doc_content"],
|
|
"metadata": {}
|
|
})
|
|
|
|
# load canonical docs
|
|
odex = load_dataset("json", data_files={"test": args.canonical_file})["test"]
|
|
# collect queries and query-doc matching
|
|
for idx,item in enumerate(tqdm(data)):
|
|
query = item["intent"]
|
|
query_id = f"{idx}_{item['task_id']}"
|
|
queries.append({"_id": query_id, "text": query, "metadata": {}})
|
|
|
|
doc_ids = [doc["title"] for doc in odex[idx]["docs"]]
|
|
for doc_id in doc_ids:
|
|
corpus_id = code_docs["doc_id"].index(doc_id)
|
|
corpus_id = str(corpus_id)
|
|
qrels.append({"query-id": query_id, "corpus-id": corpus_id, "score": 1})
|
|
|
|
return queries, docs, qrels
|
|
|
|
|
|
def main():
|
|
if '_' in args.dataset_name:
|
|
dataset_name = args.dataset_name.split('_')[0]
|
|
language = args.dataset_name.split('_')[1]
|
|
else:
|
|
dataset_name = args.dataset_name
|
|
language = 'en'
|
|
dataset = datasets.load_dataset(dataset_name, language) # english version by default
|
|
|
|
path = os.path.join(args.output_dir, args.output_name.replace('en', language))
|
|
os.makedirs(path, exist_ok=True)
|
|
os.makedirs(os.path.join(path, "qrels"), exist_ok=True)
|
|
|
|
docs, queries = [], []
|
|
for split in ["test"]:
|
|
queries_split, docs_split, qrels_split = document2code(dataset, split)
|
|
docs += docs_split
|
|
queries += queries_split
|
|
|
|
save_tsv_dict(qrels_split, os.path.join(path, "qrels", "{}.tsv".format(split)), ["query-id", "corpus-id", "score"])
|
|
|
|
save_file_jsonl(queries, os.path.join(path, "queries.jsonl"))
|
|
save_file_jsonl(docs, os.path.join(path, "corpus.jsonl"))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--dataset_name", type=str, default="neulab/odex")
|
|
parser.add_argument("--output_name", type=str, default="odex_en")
|
|
parser.add_argument("--canonical_file", type=str, default="datasets/canonical/odex_docs.json")
|
|
parser.add_argument("--output_dir", type=str, default="datasets")
|
|
args = parser.parse_args()
|
|
|
|
main()
|