mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
248 lines
8.6 KiB
Python
248 lines
8.6 KiB
Python
import os
|
|
import re
|
|
import chardet
|
|
import unidiff
|
|
import argparse
|
|
import datasets
|
|
import traceback
|
|
import subprocess
|
|
from git import Repo
|
|
from tqdm import tqdm
|
|
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
from create.utils import save_tsv_dict, save_file_jsonl
|
|
|
|
# %% Get oracle file contents
|
|
|
|
# get oracle file contents from the repo
|
|
class ContextManager:
    """Check out ``base_commit`` in a local git repository and work inside it.

    On entry the repository is hard-reset to ``base_commit``, untracked and
    ignored files are removed (``git clean -fdxq``), and the process working
    directory is switched to the repository. On exit the working directory
    that was current when the object was constructed is restored.

    Args:
        repo_path: Path to an existing git working tree.
        base_commit: Commit (or any revision git accepts) to reset to.
        verbose: When true, git's output is shown; otherwise it is discarded.
    """

    def __init__(self, repo_path, base_commit, verbose=False):
        self.repo_path = Path(repo_path).resolve().as_posix()
        # Remembered at construction time so __exit__ can switch back to it.
        self.old_dir = os.getcwd()
        self.base_commit = base_commit
        self.verbose = verbose

    def __enter__(self):
        """Reset the repository to ``base_commit`` and chdir into it.

        Raises:
            subprocess.CalledProcessError: If a git command exits non-zero.
        """
        quiet = {}
        if not self.verbose:
            quiet = {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}
        # Argument lists (no shell) keep the commit string from being
        # interpreted by a shell. Running git with ``cwd=`` *before* changing
        # directory means a failed checkout leaves the caller's working
        # directory untouched (``__exit__`` is not called when ``__enter__``
        # raises).
        for git_args in (["reset", "--hard", self.base_commit], ["clean", "-fdxq"]):
            subprocess.run(["git", *git_args], cwd=self.repo_path, check=True, **quiet)
        os.chdir(self.repo_path)
        return self

    def get_environment(self):
        raise NotImplementedError()  # TODO: activate conda environment and return the environment file

    def get_readme_files(self):
        """Return names of regular files in the repo root starting with "readme".

        The match is case-insensitive. Names are returned relative to
        ``repo_path`` in ``os.listdir`` order.
        """
        # Join with repo_path so the isfile check does not depend on the
        # current working directory being the repository.
        return [
            name
            for name in os.listdir(self.repo_path)
            if os.path.isfile(os.path.join(self.repo_path, name))
            and name.lower().startswith("readme")
        ]

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the original working directory; exceptions are not suppressed."""
        os.chdir(self.old_dir)
|
|
|
|
|
|
class AutoContextManager(ContextManager):
    """Automatically clones the repo if it doesn't exist.

    The clone lives at ``<root_dir>/<owner>__<repo>`` and is fetched from
    ``github.com/swe-bench/<owner>__<repo>.git``. An existing directory at
    that path is reused as-is.

    Args:
        instance: Mapping with at least the keys ``"repo"`` (``owner/name``)
            and ``"base_commit"``.
        root_dir: Directory that holds the clones. When ``None`` a temporary
            directory is created and removed again on exit.
        verbose: Print a message when cloning and show git output on entry.
        token: Credential embedded in the clone URL. Defaults to the
            ``GITHUB_TOKEN`` environment variable, or ``"git"`` if unset.
    """

    def __init__(self, instance, root_dir=None, verbose=False, token=None):
        if token is None:
            token = os.environ.get("GITHUB_TOKEN", "git")
        self.tempdir = None
        if root_dir is None:
            self.tempdir = TemporaryDirectory()
            root_dir = self.tempdir.name
        self.root_dir = root_dir
        repo_name = instance["repo"].replace("/", "__")
        repo_dir = os.path.join(self.root_dir, repo_name)
        if not os.path.exists(repo_dir):
            repo_url = f"https://{token}@github.com/swe-bench/" + repo_name + ".git"
            if verbose:
                print(f"Cloning {instance['repo']} to {root_dir}")
            Repo.clone_from(repo_url, repo_dir)
        super().__init__(repo_dir, instance["base_commit"], verbose=verbose)
        self.instance = instance

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the working directory, then remove the temporary clone root."""
        # Leave the repository *before* deleting it: removing the directory the
        # process is currently in is fragile, and a cleanup error must not
        # prevent the working directory from being restored.
        try:
            return super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            if self.tempdir is not None:
                self.tempdir.cleanup()
|
|
|
|
|
|
def ingest_files(filenames):
    """Read each file and return a ``{filename: text}`` dictionary.

    Args:
        filenames: Iterable of paths; relative paths are resolved against the
            current working directory.

    Returns:
        dict mapping every given filename to the file's full text.

    Raises:
        OSError: If a file cannot be opened.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    files_dict = {}
    for filename in filenames:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(filename, encoding="utf-8") as f:
            files_dict[filename] = f.read()
    return files_dict
|
|
|
|
def get_oracle_filenames(instance):
    """
    Returns the filenames that are changed in the patch

    The names are the pre-image ("source") paths of every file in
    ``instance["patch"]`` (a unified diff), with the leading ``a/`` that git
    puts on source paths removed.

    Returns:
        set of str: one entry per distinct source path.
    """
    gold_docs = set()
    for patch_file in unidiff.PatchSet(instance["patch"]):
        source_file = patch_file.source_file
        # Strip only a *leading* "a/". Splitting on the first "a/" anywhere
        # would truncate paths such as "data/x.py" when the prefix is absent.
        if source_file.startswith("a/"):
            source_file = source_file[2:]
        gold_docs.add(source_file)
    return gold_docs
|
|
|
|
|
|
# get all file contents from the repo
|
|
def is_test(name, test_phrases=None):
    """Tell whether ``name`` contains a test-related word.

    ``name`` is lower-cased and split on spaces, underscores, slashes and
    dots; the result is true when any of ``test_phrases`` equals one of the
    resulting pieces. ``test_phrases`` defaults to
    ``["test", "tests", "testing"]``.
    """
    phrases = ["test", "tests", "testing"] if test_phrases is None else test_phrases
    tokens = set(re.split(r" |_|\/|\.", name.lower()))
    for phrase in phrases:
        if phrase in tokens:
            return True
    return False
|
|
|
|
def list_files(root_dir, include_tests=False):
    """List ``*.py`` paths under ``root_dir``, relative to it, in POSIX form.

    Args:
        root_dir: Directory to search recursively.
        include_tests: When false, paths for which ``is_test`` is true are
            left out.

    Returns:
        list of str: relative paths such as ``"pkg/module.py"``.
    """
    root = Path(root_dir)
    files = []
    for filename in root.rglob("*.py"):
        relative = filename.relative_to(root).as_posix()
        # Classify the path *relative to the root*: checking the full path
        # would drop every file whenever a parent directory (e.g. the temp
        # directory the repo was cloned into) contains a word like "test".
        if not include_tests and is_test(relative):
            continue
        files.append(relative)
    return files
|
|
|
|
def detect_encoding(filename):
    """
    Guess the text encoding of a file.

    The whole file is read as bytes and handed to ``chardet.detect``; the
    ``"encoding"`` entry of its result is returned.
    """
    raw_bytes = Path(filename).read_bytes()
    guess = chardet.detect(raw_bytes)
    return guess["encoding"]
|
|
|
|
def ingest_directory_contents(root_dir, include_tests=False):
    """Read every file reported by ``list_files`` into a dictionary.

    Each file is decoded with the encoding reported by ``detect_encoding``.
    When no encoding is reported, or decoding with it fails, the content is
    replaced by the literal marker ``"[BINARY DATA FILE]"``.

    Returns:
        dict mapping each path (relative to ``root_dir``) to its text.
    """
    placeholder = "[BINARY DATA FILE]"
    contents_by_path = {}
    for rel_path in list_files(root_dir, include_tests=include_tests):
        full_path = os.path.join(root_dir, rel_path)
        guessed = detect_encoding(full_path)
        text = placeholder
        if guessed is not None:
            try:
                with open(full_path, encoding=guessed) as handle:
                    text = handle.read()
            except (UnicodeDecodeError, LookupError):
                # Undecodable bytes or an encoding name Python does not know.
                text = placeholder
        contents_by_path[rel_path] = text
    return contents_by_path
|
|
|
|
def get_file_contents(input_instances, verbose: bool = False, tmp_dir: str = "/scratch"):
    """Attach repository file contents to every instance, in place.

    For each instance the repository is cloned/checked out under a temporary
    directory and three keys are filled in on the instance dict:
    ``"readmes"``, ``"oracle_file_contents"`` and ``"file_contents"``, each a
    ``{path: text}`` mapping.

    Failures are reported (message plus traceback) and processing continues
    with the next instance. A failed instance still ends up with all three
    keys — empty for whatever could not be collected — so that consumers can
    index them unconditionally.

    Args:
        input_instances: Mapping of instance id to a mutable instance dict.
        verbose: Forwarded to ``AutoContextManager``.
        tmp_dir: Parent directory for the temporary clone root; ``"/tmp"`` is
            used when it does not exist.
    """
    content_keys = ("readmes", "oracle_file_contents", "file_contents")
    orig_dir = os.getcwd()
    with TemporaryDirectory(dir=tmp_dir if os.path.exists(tmp_dir) else "/tmp") as root_dir:
        for instance_id, instance in tqdm(
            input_instances.items(),
            total=len(input_instances),
            desc="Getting file contents",
        ):
            try:
                with AutoContextManager(instance, root_dir, verbose=verbose) as cm:
                    readmes = cm.get_readme_files()
                    instance["readmes"] = ingest_files(readmes)
                    instance["oracle_file_contents"] = ingest_files(get_oracle_filenames(instance))
                    instance["file_contents"] = ingest_directory_contents(cm.repo_path)
                    # Explicit check instead of ``assert`` so it is not
                    # stripped under ``python -O`` and names the culprits.
                    missing = [
                        okey
                        for okey in instance["oracle_file_contents"]
                        if okey not in instance["file_contents"]
                    ]
                    if missing:
                        raise ValueError(f"Oracle files missing from file contents: {missing}")
            except Exception as e:
                print(f"Failed on instance {instance_id}", e)
                traceback.print_exc()
            finally:
                # Guarantee the keys exist even when the instance failed part
                # way through; data that was already collected is kept.
                for key in content_keys:
                    instance.setdefault(key, {})
                # if AutoContextManager fails to exit properly future exits will return the wrong directory
                os.chdir(orig_dir)
    os.chdir(orig_dir)
|
|
|
|
|
|
# %% Get queries, docs, and qrels
|
|
|
|
def document2code(data, split: str = "test"):
    """Build retrieval queries, corpus documents and relevance labels.

    Reads the module-level ``args`` (``num_examples``, ``tmp_dir``).

    Args:
        data: Dataset object indexable by split name; each example needs the
            fields ``"instance_id"`` and ``"problem_statement"`` plus whatever
            ``get_file_contents`` requires.
        split: Name of the split to convert.

    Returns:
        tuple ``(queries, docs, qrels)``:
            * queries — one ``{"_id", "text", "metadata"}`` dict per example,
              with the problem statement as text;
            * docs — one ``{"_id", "title", "text", "metadata"}`` dict per
              repository file, with ``_id`` ``"<instance_id>_<filename>"``;
            * qrels — one ``{"query-id", "corpus-id", "score"}`` dict per
              oracle file, with score 1.
    """
    subset = data[split]
    if args.num_examples is not None:
        import random

        # Clamp so that asking for more examples than the split holds uses
        # the whole split instead of raising ValueError in random.sample.
        num_examples = min(args.num_examples, len(subset))
        indices = random.sample(range(len(subset)), num_examples)
        subset = subset.select(indices)
    print(subset)

    # get queries for each example
    queries = [
        {
            "_id": item["instance_id"],
            "text": item["problem_statement"],
            "metadata": {},
        }
        for item in subset
    ]

    subset_dict = {x["instance_id"]: x for x in subset}
    get_file_contents(subset_dict, tmp_dir=args.tmp_dir)

    # collect all docs, i.e., code chunks from the repo
    docs = []
    for instance_id, instance in subset_dict.items():
        # ``.get`` keeps one instance whose repository could not be processed
        # from aborting the whole export with a KeyError.
        oracle_files = instance.get("oracle_file_contents", {})
        repo_files = instance.get("file_contents", {})
        print(f"Instance #{instance_id}: {len(oracle_files)} oracle / {len(repo_files)} files")
        for filename, content in repo_files.items():
            docs.append({
                "_id": f"{instance_id}_{filename}",
                "title": filename,
                "text": content,
                "metadata": {},
            })

    # find ground-truth docs for each example
    qrels = []
    for instance_id, instance in subset_dict.items():
        for filename in instance.get("oracle_file_contents", {}):
            qrels.append({
                "query-id": instance_id,
                "corpus-id": f"{instance_id}_{filename}",
                "score": 1,
            })

    return queries, docs, qrels
|
|
|
|
|
|
def main():
    """Load the dataset named in ``args`` and write the retrieval files.

    Output goes to ``<output_dir>/swe-bench`` (or ``swe-bench-lite`` when the
    dataset name contains "lite", case-insensitively):
    ``queries.jsonl``, ``corpus.jsonl`` and ``qrels/test.tsv``.
    """
    dataset = datasets.load_dataset(args.dataset_name, cache_dir=args.cache_dir)

    is_lite = "lite" in args.dataset_name.lower()
    name = "swe-bench" + ("-lite" if is_lite else "")

    path = os.path.join(args.output_dir, name)
    for directory in (path, os.path.join(path, "qrels")):
        os.makedirs(directory, exist_ok=True)

    queries, docs, qrels = document2code(dataset, split="test")

    queries_file = os.path.join(path, "queries.jsonl")
    save_file_jsonl(queries, queries_file)

    corpus_file = os.path.join(path, "corpus.jsonl")
    save_file_jsonl(docs, corpus_file)

    qrels_path = os.path.join(path, "qrels", "test.tsv")
    qrels_columns = ["query-id", "corpus-id", "score"]
    save_tsv_dict(qrels, qrels_path, qrels_columns)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments),
    # registered in this order.
    cli_options = [
        ("--dataset_name", {
            "type": str,
            "default": "princeton-nlp/SWE-bench_Lite",
            "choices": ["princeton-nlp/SWE-bench", "princeton-nlp/SWE-bench_Lite"],
        }),
        ("--cache_dir", {"type": str, "default": "/scratch/zhiruow/data"}),
        ("--tmp_dir", {"type": str, "default": "/scratch/zhiruow/tmp"}),
        ("--output_dir", {"type": str, "default": "datasets"}),
        ("--num_examples", {"type": int, "default": None}),
    ]

    parser = argparse.ArgumentParser()
    for flag, flag_kwargs in cli_options:
        parser.add_argument(flag, **flag_kwargs)

    # ``args`` is deliberately module-level: main() and document2code() read it.
    args = parser.parse_args()

    main()
|