Merge pull request #1449 from 545999961/master

update code
This commit is contained in:
chaofan 2025-05-16 18:32:04 +08:00 committed by GitHub
commit c597c2de1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 2887 additions and 0 deletions

File diff suppressed because it is too large

View File

@@ -0,0 +1,104 @@
import os
import random
import datasets
from tqdm import tqdm
from typing import List, Tuple
from utils import clean_code
from constant import DocLength
class CorpusGenerator:
def __init__(
self,
cache_dir: str = None,
):
self.cache_dir = cache_dir
def _load_corpus(self, corpus_dir: str, doc_length: List[str], external_path: List[str],
source_language: str, stop_threshold: int = -1):
"""
Load available documents for a given task from the CoIR-Retrieval dataset.
"""
corpus_list = []
if corpus_dir is not None and os.path.exists(corpus_dir):
file_list = os.listdir(corpus_dir)
random.shuffle(file_list)
for file in file_list:
if not file.endswith('.jsonl'):
    continue
flag = False
for d_length in doc_length:
    if DocLength[d_length].value in file:
        flag = True
if not flag:
    continue
file_path = os.path.join(corpus_dir, file)
corpus = datasets.load_dataset('json', data_files=file_path, cache_dir=self.cache_dir)['train']
for data in tqdm(corpus, desc="Loading corpus"):
if source_language is None:
lang = os.path.basename(corpus_dir)
data['language'] = lang
else:
data['language'] = source_language
text = clean_code(data["text"], data["language"], length_threshold=200)
data["text"] = text
if text != '':
corpus_list.append(data)
if stop_threshold > 0 and len(corpus_list) > stop_threshold:
break
break
for ep in external_path:
if os.path.exists(ep):
corpus = datasets.load_dataset('json', data_files=ep, cache_dir=self.cache_dir)['train']
for data in tqdm(corpus, desc="Loading corpus"):
    if source_language is None:
        data['language'] = os.path.basename(os.path.dirname(ep))
    else:
        data['language'] = source_language
    # fall back to the first positive when the "text" field is absent
    if "text" not in data:
        data["text"] = data["pos"][0]
    text = clean_code(data["text"], data["language"], length_threshold=200)
    data["text"] = text
    if text != '':
        corpus_list.append(data)
return corpus_list
def run(
self,
num_samples: int = -1,
max_corpus: int = -1,
corpus_dir: str = None,
doc_length: List[str] = ["len_0_500"],
external_path: List[str] = None,
source_language: str = None
) -> Tuple[List[dict], List[dict]]:
stop_threshold = max(num_samples * 10, max_corpus * 2)
corpus_list = self._load_corpus(
corpus_dir, doc_length, external_path, source_language, stop_threshold
)
if num_samples > 0 and num_samples < len(corpus_list):
small_corpus_list = random.sample(corpus_list, num_samples)
else:
small_corpus_list = corpus_list
if max_corpus > 0 and max_corpus < len(corpus_list):
    corpus_list = random.sample(corpus_list, max_corpus)
return small_corpus_list, corpus_list
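A minimal usage sketch of the generator follows; the paths, sizes, and cache directory are hypothetical, and the corpus layout (one sub-directory per code language holding *.jsonl shards) is an assumption based on how _load_corpus matches file names:

# Sketch: drive CorpusGenerator directly (hypothetical paths and sizes).
from corpus_generator import CorpusGenerator

generator = CorpusGenerator(cache_dir="./hf_cache")  # hypothetical cache dir
sampled, pool = generator.run(
    num_samples=1000,              # positives to sample for triplet generation
    max_corpus=50000,              # cap on the full candidate pool
    corpus_dir="./corpus/python",  # hypothetical: one subdir per code language
    doc_length=["len_0_500"],      # DocLength member names matched in file names
    external_path=[],              # optional extra .jsonl corpora
    source_language="python",
)
print(len(sampled), len(pool))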

View File

@@ -0,0 +1,127 @@
import os
import json
from constant import Language, CodeLanguage, TaskType, CODE_TRANSLATION_RETRIEVAL_PAIRS, \
get_pos_as_input_by_task_type
def format_generated_examples(
file_path: str,
save_path: str,
task_type: TaskType
):
if os.path.exists(save_path):
return
if not os.path.exists(file_path):
print("====================================")
print("Warning: file not found! Maybe need to generate it first.")
print(f"file_path: {file_path}")
return
pos_as_input = get_pos_as_input_by_task_type(task_type)
data_list = []
with open(file_path, "r", encoding="utf-8") as f:
for line in f.readlines():
data = json.loads(line)
if pos_as_input:
_input = data["pos"][0]
_output = data["query"]
else:
_input = data["query"]
_output = data["pos"][0]
if 'provided' in _input:
continue
if len(_input) > 12000 or len(_output) > 12000:
continue
data_list.append({
"input": _input,
"output": _output
})
if len(data_list) == 0:
print("====================================")
print("Warning: no data found!")
print(f"file_path: {file_path}")
return
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf-8") as f:
json.dump(data_list, f, indent=4, ensure_ascii=False)
def main():
original_gen_examples_dir = "./examples"
formatted_examples_dir = "./filtered_for_generation"
for language in Language:
for task_type in TaskType:
if task_type == TaskType.code_translation_retrieval:
for code_language_pair in CODE_TRANSLATION_RETRIEVAL_PAIRS:
code_language, tgt_code_language = code_language_pair
file_path = os.path.join(
original_gen_examples_dir,
language.name, task_type.name, f"{language.name}-{code_language.name}-to-{tgt_code_language.name}-triplets.jsonl"
)
save_path = os.path.join(
formatted_examples_dir,
language.name, task_type.name, f"{code_language.name}-to-{tgt_code_language.name}_sample_examples.json"
)
format_generated_examples(file_path, save_path, task_type)
for code_language_pair in CODE_TRANSLATION_RETRIEVAL_PAIRS:
tgt_code_language, code_language = code_language_pair
file_path = os.path.join(
original_gen_examples_dir,
language.name, task_type.name, f"{language.name}-{code_language.name}-to-{tgt_code_language.name}-triplets.jsonl"
)
save_path = os.path.join(
formatted_examples_dir,
language.name, task_type.name, f"{code_language.name}-to-{tgt_code_language.name}_sample_examples.json"
)
format_generated_examples(file_path, save_path, task_type)
elif task_type == TaskType.text2sql_retrieval:
file_path = os.path.join(
original_gen_examples_dir,
language.name, task_type.name, f"{language.name}-sql-triplets.jsonl"
)
save_path = os.path.join(
formatted_examples_dir,
language.name, task_type.name, "sql_sample_examples.json"
)
format_generated_examples(file_path, save_path, task_type)
elif task_type == TaskType.code_context_retrieval:
continue
else:
for code_language in CodeLanguage:
if code_language == CodeLanguage.null:
continue
file_path = os.path.join(
original_gen_examples_dir,
language.name, task_type.name, f"{language.name}-{code_language.name}-triplets.jsonl"
)
save_path = os.path.join(
formatted_examples_dir,
language.name, task_type.name, f"{code_language.name}_sample_examples.json"
)
format_generated_examples(file_path, save_path, task_type)
print("All done!")
if __name__ == "__main__":
main()
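Each saved file is a JSON array of {"input": ..., "output": ...} records. A quick sketch of reading one back, assuming the en/text2sql_retrieval layout built in main():

import json

# Hypothetical path following the save layout above
path = "./filtered_for_generation/en/text2sql_retrieval/sql_sample_examples.json"
with open(path, "r", encoding="utf-8") as f:
    examples = json.load(f)
print(examples[0]["input"][:80], "->", examples[0]["output"][:80])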

View File

@@ -0,0 +1,134 @@
import os
import time
import openai
import random
import tiktoken
import threading
from openai import OpenAI, AzureOpenAI
from typing import Tuple
class LLM:
def __init__(
self,
model: str="Qwen2-5-Coder-32B-Instruct",
model_type: str = "open-source",
port: int = 8000,
):
if model_type == "open-source":
self.client = OpenAI(
api_key="EMPTY",
base_url=f"http://localhost:{port}/v1/"
)
elif model_type == "azure":
self.client = AzureOpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
api_version=os.getenv("AZURE_API_VERSION", "2024-02-01"),
azure_endpoint=os.getenv("AZURE_ENDPOINT"),
azure_deployment=os.getenv("OPENAI_DEPLOYMENT_NAME", 'gpt-35-turbo')
)
elif model_type == "openai":
self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
base_url=os.getenv("OPENAI_BASE_URL", None)
)
else:
raise ValueError("model_type must be one of ['open-source', 'azure', 'openai']")
self.model = model
self.tokenizer = tiktoken.get_encoding("o200k_base")
def split_text(self, text: str, anchor_points: Tuple[float, float] = (0.4, 0.7)):
token_ids = self.tokenizer.encode(text)
anchor_point = random.uniform(anchor_points[0], anchor_points[1])
split_index = int(len(token_ids) * anchor_point)
return self.tokenizer.decode(token_ids[:split_index]), self.tokenizer.decode(token_ids[split_index:])
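    # e.g. with the default anchor_points=(0.4, 0.7), a 1000-token text is cut
    # at a uniformly random boundary between tokens 400 and 700.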
def chat(
self,
prompt: str,
max_tokens: int = 8192,
logit_bias: dict = None,
n: int = 1,
temperature: float = 1.0,
top_p: float = 0.6,
repetition_penalty: float = 1.0,
remove_thinking: bool = True,
timeout: int = 90,
):
endure_time = 0
endure_time_limit = timeout * 2
def create_completion(results):
try:
completion = self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
logit_bias=logit_bias if logit_bias is not None else {},
n=n,
temperature=temperature,
top_p=top_p,
extra_body={'repetition_penalty': repetition_penalty},
timeout=timeout,
)
results["content"] = [x.message.content for x in completion.choices[:n]]
except openai.BadRequestError as e:
# The response was filtered due to the prompt triggering Azure OpenAI's content management policy.
results["content"] = [None for _ in range(n)]
except openai.APIConnectionError as e:
results["error"] = f'APIConnectionError({e})'
except openai.RateLimitError as e:
results["error"] = f'RateLimitError({e})'
except Exception as e:
results["error"] = f"Error: {e}"
while True:
results = {"content": None, "error": None}
completion_thread = threading.Thread(target=create_completion, args=(results,))
completion_thread.start()
start_time = time.time()
while completion_thread.is_alive():
elapsed_time = time.time() - start_time
if elapsed_time > endure_time_limit:
print("Completion timeout exceeded. Aborting...")
return [None for _ in range(n)]
time.sleep(1)
# If an error occurred during result processing
if results["error"]:
if endure_time >= endure_time_limit:
print(f'{results["error"]} - Skip this prompt.')
return [None for _ in range(n)]
print(f"{results['error']} - Waiting for 5 seconds...")
endure_time += 5
time.sleep(5)
continue
content_list = results["content"]
if remove_thinking:
content_list = [x.split('</think>')[-1].strip('\n').strip() if x is not None else None for x in content_list]
return content_list
if __name__ == "__main__":
llm = LLM(
model="gpt-4o-mini-2024-07-18",
model_type="openai"
)
prompt = "hello, who are you?"
response = llm.chat(prompt)[0]
print(response)

View File

@@ -0,0 +1,368 @@
import os
import json
import time
import gc
import torch
import argparse
import random
from hashlib import md5
import multiprocessing as mp
from typing import List, Optional
from constant import TaskType, Language, CodeLanguage, NUM_HARD_NEGATIVES
from corpus_generator import CorpusGenerator
from triplet_generator import TripletGenerator
from search import get_top1
def compute_md5(text: str):
return md5(text.encode()).hexdigest()
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--task_type',
type=str,
required=True,
help='The task type to generate data for',
choices=[t.name for t in TaskType]
)
parser.add_argument(
'--code_language',
type=str,
required=True,
help='The code language to generate questions for.',
choices=[c.name for c in CodeLanguage]
)
parser.add_argument(
'--corpus_root',
type=str,
required=True,
help='The root directory of the corpus data.'
)
parser.add_argument(
'--save_dir',
type=str,
required=True,
help='The path to save the generated data'
)
parser.add_argument(
'--examples_dir',
type=str,
default=None,
help='The path to the examples directory. If not None, the examples will be used for few-shot generation.'
)
parser.add_argument(
'--num_examples',
type=int,
default=3,
help='The number of examples to use for few-shot generation. Default: 3'
)
parser.add_argument(
'--cache_dir',
type=str,
default=None,
help='The cache directory'
)
parser.add_argument(
'--language',
type=str,
default='en',
help='The language to generate for. ISO 639-1 code. Default: en',
choices=[l.name for l in Language]
)
parser.add_argument(
'--tgt_code_language',
type=str,
default=None,
help='The target code language to generate code translations for.',
choices=[c.name for c in CodeLanguage]
)
parser.add_argument(
'--num_samples',
type=int,
default=-1,
help='The number of examples to use for generation. Default: -1. Use all available examples.'
)
parser.add_argument(
'--model',
type=str,
default='Qwen2.5-72B-Instruct',
help='The model to use for generation. Default: Qwen2.5-72B-Instruct'
)
parser.add_argument(
'--model_type',
type=str,
default='open-source',
help='The type of model to use for generation. Default: open-source',
)
parser.add_argument(
'--port',
type=int,
default=8000,
help='The port for vllm.'
)
parser.add_argument(
'--num_processes',
type=int,
default=1,
help='The number of processes to use for generation. Default: 1'
)
parser.add_argument(
'--doc_length',
type=str,
default='len_0_500',
help='The corpus length used to load dataset. Default: len_0_500'
)
parser.add_argument(
'--external_path',
type=str,
default='',
help='Space-separated paths to extra corpus files to load.'
)
parser.add_argument(
'--sim_model_name',
type=str,
default=None,
help='The embedding model used to retrieve similar documents (for code modification/comparison tasks).'
)
parser.add_argument(
'--max_corpus',
type=int,
default=500000,
help='The max num of corpus to load.'
)
parser.add_argument(
'--overwrite',
action='store_true',
help='Whether to overwrite the existing data.'
)
parser.add_argument(
'--debug_mode',
action='store_true',
help='Whether to open debug mode.'
)
parser.add_argument(
'--gen_hard_neg',
action='store_true',
help='Whether to generate hard negatives.'
)
parser.add_argument(
'--seed',
type=int,
default=None,
help='Random seed for generating triplets using the same positives. Default: None'
)
args = parser.parse_args()
return args
def gen_triplets(
model: str,
model_type: str,
port: int,
positives: List[dict],
task_type: str,
language: str,
code_language: str,
tgt_code_language: str,
examples_pool: Optional[List[dict]] = None,
num_examples: int = 3,
tqdm_desc: str = "Generating triplets",
thread_count: int = 1,
gen_cache_dir: Optional[str] = None,
debug_mode: bool = False,
gen_hard_neg: bool = False,
):
triplet_generator = TripletGenerator(model, model_type, port, cache_dir=gen_cache_dir)
triplets = triplet_generator.run(
positives=positives,
task_type=task_type,
language=language,
code_language=code_language,
tgt_code_language=tgt_code_language,
examples_pool=examples_pool,
num_examples=num_examples,
tqdm_desc=tqdm_desc,
thread_count=thread_count,
debug_mode=debug_mode,
gen_hard_neg=gen_hard_neg,
num_negatives=NUM_HARD_NEGATIVES,
)
return triplets
def get_save_path(
save_dir: str,
task_type: str,
language: str,
code_language: str,
tgt_code_language: Optional[str] = None
):
save_dir = os.path.join(save_dir, language, task_type)
if tgt_code_language is not None:
file_name = f"{language}-{code_language}-to-{tgt_code_language}-triplets.jsonl"
else:
file_name = f"{language}-{code_language}-triplets.jsonl"
save_path = os.path.join(save_dir, file_name)
os.makedirs(save_dir, exist_ok=True)
return save_path
def save_triplets(
triplets: list,
save_dir: str,
task_type: str,
language: str,
code_language: str,
tgt_code_language: Optional[str] = None
):
if len(triplets) == 0:
print(f"No triplets to save: {task_type} | {language} | {code_language} | {tgt_code_language}")
return
save_path = get_save_path(save_dir, task_type, language, code_language, tgt_code_language)
query_md5s = set()
pos_md5s = set()
old_triplets = []
if os.path.exists(save_path):
with open(save_path, "r", encoding="utf-8") as f:
for line in f.readlines():
triplet = json.loads(line)
old_triplets.append(triplet)
query_md5s.add(compute_md5(triplet['query']))
pos_md5s.add(compute_md5(triplet['pos'][0]))
with open(save_path, 'w', encoding='utf-8') as f:
for triplet in old_triplets:
f.write(json.dumps(triplet, ensure_ascii=False) + '\n')
for triplet in triplets:
_query_md5 = compute_md5(triplet['query'])
_pos_md5 = compute_md5(triplet['pos'][0])
if _query_md5 in query_md5s or _pos_md5 in pos_md5s:
continue
f.write(json.dumps(triplet, ensure_ascii=False) + '\n')
print(f"Triplets saved to {save_path}")
def main(args):
# set seed
seed = args.seed
if seed is not None:
print(f"------------------- Seed set to {seed} -------------------")
random.seed(seed)
model = args.model
model_type = args.model_type
port = args.port
num_samples = args.num_samples
task_type = args.task_type
language = args.language
code_language = args.code_language
tgt_code_language = args.tgt_code_language
corpus_root = args.corpus_root
corpus_dir = os.path.join(corpus_root, code_language)
doc_length = args.doc_length.split()
external_path = args.external_path.split()
save_dir = args.save_dir
cache_dir = args.cache_dir
num_processes = min(args.num_processes, int(mp.cpu_count() * 0.8))
overwrite = args.overwrite
debug_mode = args.debug_mode
gen_hard_neg = args.gen_hard_neg
save_path = get_save_path(save_dir, task_type, language, code_language, tgt_code_language)
# if os.path.exists(save_path) and not overwrite:
# data = []
# with open(save_path) as f:
# for line in f:
# data.append(json.loads(line))
# if len(data) >= num_samples * 0.8:
# print(f"Triplets already exist at {save_path}. Skipping generation.")
# return
# else:
# print(f"Triplets already exist at {save_path}. But samples is really small, continue generation.")
# num_samples = int((num_samples - len(data)) * 1.25) # consider the filtered samples
corpus_generator = CorpusGenerator(cache_dir)
examples_dir = args.examples_dir
num_examples = args.num_examples
if examples_dir is not None:
# if task_type in ["single_turn_code_qa", "multi_turn_code_qa"]:
# examples_path = os.path.join(examples_dir, language, task_type, "sample_examples.json")
if task_type in ["code_translation_retrieval"]:
examples_path = os.path.join(examples_dir, language, task_type,
f"{code_language}-to-{tgt_code_language}_sample_examples.json")
else:
examples_path = os.path.join(examples_dir, language, task_type, f"{code_language}_sample_examples.json")
try:
with open(examples_path, 'r', encoding='utf-8') as f:
examples_pool = json.load(f)
examples_pool = random.sample(examples_pool,
min(30, len(examples_pool))) # sample 30 examples for few-shot generation
except Exception:
    print(f'Error loading examples from {examples_path}')
examples_pool = None
else:
examples_pool = None
positives, large_positives = corpus_generator.run(
num_samples=num_samples,
max_corpus=args.max_corpus,
corpus_dir=corpus_dir,
doc_length=doc_length,
external_path=external_path,
source_language=code_language
)
if task_type in ["code_modification_retrieval", "code_comparison_retrieval"]:
top1_docs = get_top1([e['text'] for e in positives], args.sim_model_name, [e['text'] for e in large_positives])
for i in range(len(top1_docs)):
positives[i]['similar'] = top1_docs[i]
gc.collect()
torch.cuda.empty_cache()
print("=================== Generate training data ===================")
print(f'Task Type: {task_type} | Language: {language} | Code Language: {code_language} | Target Code Language: {tgt_code_language}')
start_time = time.time()
triplets = gen_triplets(
model=model,
model_type=model_type,
port=port,
positives=positives,
task_type=task_type,
language=language,
code_language=code_language,
tgt_code_language=tgt_code_language,
examples_pool=examples_pool,
num_examples=num_examples,
thread_count=num_processes,
gen_cache_dir=os.path.join(save_dir, language, task_type, "gen_cache_dir"),
debug_mode=debug_mode,
gen_hard_neg=gen_hard_neg,
)
save_triplets(
triplets=triplets,
save_dir=save_dir,
task_type=task_type,
language=language,
code_language=code_language,
tgt_code_language=tgt_code_language
)
end_time = time.time()
print("=============================================================")
print(f"Time taken: {end_time - start_time:.2f} seconds")
print("=============================================================")
print("DONE!")
if __name__ == "__main__":
args = get_args()
main(args)
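A sketch of a typical invocation, shown programmatically; all paths are hypothetical, the enum member names are assumptions drawn from the code above, and an OpenAI-compatible server (e.g. vLLM) is assumed on the default port 8000:

# Sketch: drive the script programmatically instead of via the CLI.
import sys

sys.argv = [
    "generate.py",
    "--task_type", "code_context_retrieval",  # a TaskType member name
    "--code_language", "python",              # assumed CodeLanguage member name
    "--corpus_root", "./corpus",              # hypothetical
    "--save_dir", "./triplets",               # hypothetical
    "--language", "en",
    "--num_samples", "100",
]
main(get_args())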

View File

@@ -0,0 +1,71 @@
from typing import Optional, List
import faiss
import numpy as np
from tqdm import tqdm
from FlagEmbedding import FlagModel
def create_index(embeddings: np.ndarray, use_gpu: bool = False):
index = faiss.IndexFlatIP(len(embeddings[0]))
embeddings = np.asarray(embeddings, dtype=np.float32)
if use_gpu:
co = faiss.GpuMultipleClonerOptions()
co.shard = True
co.useFloat16 = True
index = faiss.index_cpu_to_all_gpus(index, co=co)
index.add(embeddings)
return index
def search(
faiss_index: faiss.Index,
k: int = 100,
query_embeddings: Optional[np.ndarray] = None,
load_path: Optional[str] = None
):
if query_embeddings is None:
query_embeddings = np.load(load_path)
query_size = len(query_embeddings)
all_scores = []
all_indices = []
for i in tqdm(range(0, query_size, 32), desc="Searching"):
j = min(i + 32, query_size)
query_embedding = query_embeddings[i: j]
score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
all_scores.append(score)
all_indices.append(indice)
all_scores = np.concatenate(all_scores, axis=0)
all_indices = np.concatenate(all_indices, axis=0)
return all_scores, all_indices
def get_top1(
small_docs,
encoder_name,
docs: List[str],
top: int = 1
):
encoder = FlagModel(encoder_name, trust_remote_code=True)
doc_emb = encoder.encode_corpus(docs, max_length=512, batch_size=256)
small_doc_emb = encoder.encode_corpus(small_docs, max_length=512, batch_size=256)
faiss_index = create_index(doc_emb, True)
all_scores, all_indices = search(faiss_index, 1000, small_doc_emb)
return_docs = []
for i in range(len(all_indices)):
return_docs.append([])
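        # Skip the 20 nearest neighbors, then drop near-duplicates whose
        # token-set Jaccard similarity with the query document exceeds 0.95.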
for idx, score in zip(all_indices[i][20:], all_scores[i][20:]):
d1 = set(docs[idx].split())
d2 = set(small_docs[i].split())
if len(d1 & d2) / len(d1 | d2) > 0.95:
continue
return_docs[-1].append(docs[idx])
if len(return_docs[-1]) >= top:
break
if len(return_docs[-1]) == 0:
print(all_indices[i], all_scores[i])
# print(return_docs)
del faiss_index
return return_docs
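A self-contained sketch of create_index and search on random vectors (sizes are arbitrary, CPU only, no encoder needed):

import numpy as np

rng = np.random.default_rng(0)
doc_emb = rng.standard_normal((1000, 64)).astype(np.float32)
query_emb = rng.standard_normal((10, 64)).astype(np.float32)

index = create_index(doc_emb, use_gpu=False)
scores, indices = search(index, k=5, query_embeddings=query_emb)
print(scores.shape, indices.shape)  # (10, 5) (10, 5)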

View File

@@ -0,0 +1,654 @@
import os
import json
import random
from tqdm import tqdm
from hashlib import md5
from warnings import warn
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor
from llm import LLM
from utils import clean_content
from constant import TaskType, Task, SPECIAL_TASK_STEPS, \
get_task, get_generation_prompt, get_quality_control_prompt, \
get_gen_hard_neg_prompt
def compute_md5(text: str):
return md5(text.encode()).hexdigest()
class TripletGenerator(LLM):
def __init__(
self,
model: str = "Qwen2-5-Coder-32B-Instruct",
model_type: str = "open-source",
port: int = 8000,
cache_dir: Optional[str] = None
):
super().__init__(model, model_type, port)
self.cache_dir = cache_dir
if self.cache_dir is not None:
os.makedirs(self.cache_dir, exist_ok=True)
def _gen_for_code_modification_retrieval(
self,
task: Task,
text: str,
text_b: Optional[str] = None,
examples: Optional[List[dict]] = None,
debug_mode: bool = False,
**kwargs
):
gen_prompt = get_generation_prompt(
task=task,
text=text,
text_b=text_b,
examples=examples,
idx=0
)
response = self.chat(gen_prompt, **kwargs)[0]
diff = clean_content(response)
gen_prompt = get_generation_prompt(
task=task,
text=diff,
examples=examples,
idx=1
)
response = self.chat(gen_prompt, **kwargs)[0]
modification_instr = clean_content(response)
query = f"{modification_instr}\n```\n{text}\n```"
pos = text_b
if debug_mode:
result = {
"generation_prompt": gen_prompt,
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
else:
result = {
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
return result
def _gen_for_code_comparison_retrieval(
self,
task: Task,
text: str,
text_b: Optional[str] = None,
examples: Optional[List[dict]] = None,
debug_mode: bool = False,
**kwargs
):
gen_prompt = get_generation_prompt(
task=task,
text=text,
text_b=text_b,
examples=examples,
idx=0
)
response = self.chat(gen_prompt, **kwargs)[0]
diff_question = clean_content(response)
query = f"{diff_question}\n\nInput Code:\n```\n{text}\n```\n\nOutput Code:\n```\n{text_b}\n```"
gen_prompt = get_generation_prompt(
task=task,
text=query,
examples=examples,
idx=1
)
response = self.chat(gen_prompt, **kwargs)[0]
pos = clean_content(response)
if debug_mode:
result = {
"generation_prompt": gen_prompt,
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
else:
result = {
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
return result
def _gen_for_code_context_retrieval(
self,
task: Task,
text: str,
anchor_points: Optional[tuple] = (0.4, 0.7),
**kwargs
):
former_part, latter_part = self.split_text(
text,
anchor_points=anchor_points
)
result = {
"prompt": task.task_instruction,
"query": former_part,
"pos": [latter_part],
"neg": []
}
return result
@staticmethod
def _arrange_query_and_pos(task: Task, input_text: str, response: str):
"""
Arrange the query and positive example based on the task type.
Args:
- task: Task
- input_text: str
- response: str
Returns:
- query: str
- pos: str
"""
# TODO: support more task types, including some special task types.
if task.main_task_type in ["text2code", "hybrid"]:
query = clean_content(response)
pos = input_text
else:
query = input_text
pos = clean_content(response)
return query, pos
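    # Illustration: for "text2code" and "hybrid" main task types the LLM
    # response becomes the query and the input code is the positive; for the
    # remaining (code2text / code2code style) tasks the input code is the
    # query and the cleaned response is the positive.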
def _gen_for_normal_task(
self,
task: Task,
text: str,
examples: Optional[List[dict]] = None,
debug_mode: bool = False,
**kwargs
):
gen_prompt = get_generation_prompt(
task=task,
text=text,
examples=examples
)
response = self.chat(gen_prompt, **kwargs)[0]
# Arrange the query and positive example based on the task type.
query, pos = self._arrange_query_and_pos(
task=task,
input_text=text,
response=response
)
if debug_mode:
result = {
"generation_prompt": gen_prompt,
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": [],
"response": response
}
else:
result = {
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
return result
def _gen_for_bug_desc_retrieval(
self,
task: Task,
text: str,
examples: Optional[List[dict]] = None,
debug_mode: bool = False,
**kwargs
):
gen_prompt = get_generation_prompt(
task=task,
text=text,
examples=examples,
idx=0
)
response = self.chat(gen_prompt, **kwargs)[0]
if response is None:
raise ValueError("Response is None.")
buggy_code = response
gen_prompt = get_generation_prompt(
task=task,
text=buggy_code,
examples=examples,
idx=1
)
response = self.chat(gen_prompt, **kwargs)[0]
query = clean_content(response)
pos = text
if debug_mode:
result = {
"generation_prompt": gen_prompt,
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
else:
result = {
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
return result
def _gen_for_two_step_not_use_last(
self,
task: Task,
text: str,
examples: Optional[List[dict]] = None,
debug_mode: bool = False,
reverse_query_pos: bool = False,
**kwargs
):
gen_prompt = get_generation_prompt(
task=task,
text=text,
idx=0
)
response = self.chat(gen_prompt, **kwargs)[0]
query = clean_content(response)
gen_prompt = get_generation_prompt(
task=task,
text=query,
examples=examples,
idx=1
)
response = self.chat(gen_prompt, **kwargs)[0]
pos = clean_content(response)
if reverse_query_pos:
query, pos = pos, query
if debug_mode:
result = {
"generation_prompt": gen_prompt,
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
else:
result = {
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
return result
def _gen_for_two_step_use_last(
self,
task: Task,
text: str,
examples: Optional[List[dict]] = None,
debug_mode: bool = False,
reverse_query_pos: bool = False,
**kwargs
):
gen_prompt = get_generation_prompt(
task=task,
text=text,
idx=0
)
response = self.chat(gen_prompt, **kwargs)[0]
query = clean_content(response) + f"\n```\n{text}\n```"
gen_prompt = get_generation_prompt(
task=task,
text=query,
examples=examples,
idx=1
)
response = self.chat(gen_prompt, **kwargs)[0]
pos = clean_content(response)
if reverse_query_pos:
query, pos = pos, query
if debug_mode:
result = {
"generation_prompt": gen_prompt,
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
else:
result = {
"prompt": task.task_instruction,
"query": query,
"pos": [pos],
"neg": []
}
return result
def generate_triplets(
self,
data: dict,
task: Task,
examples_pool: Optional[List[dict]] = None,
num_examples: int = 3,
debug_mode: bool = False,
**kwargs
):
kwargs["remove_thinking"] = not debug_mode
result_list = []
examples = None
if examples_pool is not None:
examples = random.sample(examples_pool, min(num_examples, len(examples_pool)))
try:
if task.task_type in SPECIAL_TASK_STEPS:
text = data["text"]
if task.task_type == TaskType.code_modification_retrieval:
text_b = data["similar"][0]
result = self._gen_for_code_modification_retrieval(
task=task,
text=text,
text_b=text_b,
examples=examples,
debug_mode=debug_mode
)
elif task.task_type == TaskType.code_comparison_retrieval:
text_b = data["similar"][0]
result = self._gen_for_code_comparison_retrieval(
task=task,
text=text,
text_b=text_b,
examples=examples,
debug_mode=debug_mode
)
elif task.task_type == TaskType.bug_desc_retrieval:
result = self._gen_for_bug_desc_retrieval(
task=task,
text=text,
examples=examples,
debug_mode=debug_mode
)
elif task.task_type in [
TaskType.code_issue_discussion_retrieval,
TaskType.code_version_update_retrieval,
TaskType.code_bug_fix_example_retrieval,
]:
result = self._gen_for_two_step_not_use_last(
task=task,
text=text,
examples=examples,
debug_mode=debug_mode,
reverse_query_pos=False
)
elif task.task_type in [
TaskType.code_refactoring_pattern_retrieval,
TaskType.code_style_guideline_example_retrieval,
TaskType.code_migration_retrieval,
TaskType.code_optimization_hybrid_retrieval,
TaskType.code_best_practices_retrieval,
TaskType.security_vulnerability_fix_retrieval,
]:
result = self._gen_for_two_step_use_last(
task=task,
text=text,
examples=examples,
debug_mode=debug_mode,
reverse_query_pos=False
)
else:
raise NotImplementedError(f"Task type {task.task_type} not implemented.")
elif task.task_type == TaskType.code_context_retrieval:
text = data["text"]
result = self._gen_for_code_context_retrieval(
task=task,
text=text,
**kwargs
)
# NOTE: no need to do quality control for code context retrieval task
result_list.append(result)
return result_list
else:
text = data["text"]
result = self._gen_for_normal_task(
task=task,
text=text,
examples=examples,
debug_mode=debug_mode,
**kwargs
)
# print(gen_prompt)
# print('================================================')
qc_prompt = get_quality_control_prompt(
task=task,
query=result["query"],
pos=result["pos"][0]
)
# print(qc_prompt)
# print('*********************************************************************')
response = self.chat(qc_prompt, **kwargs)[0]
judge = clean_content(response)
# print(response, judge)
if "1" in judge:
if debug_mode:
result["judge"] = judge
result["judge_response"] = response
result_list.append(result)
else:
if debug_mode:
result["judge"] = judge
result["judge_response"] = response
result_list.append(result)
except Exception as e:
warn(f"Error: {e}")
return result_list
def gen_hard_negatives(self, result: dict, task: Task, num_negatives: int = 7, **kwargs):
gen_hard_neg_prompt = get_gen_hard_neg_prompt(
task=task,
query=result["query"],
pos=result["pos"][0]
)
response_list = self.chat(gen_hard_neg_prompt, n=num_negatives, **kwargs)
for response in response_list:
if response is None:
continue
hard_neg = clean_content(response)
result["neg"].append(hard_neg)
result["neg"] = list(set(result["neg"]))
return result
def run_single(
self,
data: dict,
task: Task,
examples_pool: Optional[List[dict]] = None,
num_examples: int = 3,
debug_mode: bool = False,
gen_hard_neg: bool = False,
num_negatives: int = 7,
**kwargs
):
result_list = []
docid = compute_md5(data["text"])
if self.cache_dir is not None:
gen_data_cache_path = os.path.join(self.cache_dir, f"{docid}.json")
if os.path.exists(gen_data_cache_path):
with open(gen_data_cache_path, "r", encoding="utf-8") as f:
result_list = json.load(f)
if len(result_list) > 0:
if gen_hard_neg:
for i in range(len(result_list)):
if len(result_list[i]["neg"]) == 0:
result_list[i] = self.gen_hard_negatives(
result=result_list[i],
task=task,
num_negatives=num_negatives,
**kwargs
)
# overwrite the cache file
with open(gen_data_cache_path, "w", encoding="utf-8") as f:
json.dump(result_list, f, indent=4, ensure_ascii=False)
return result_list
triplets = self.generate_triplets(
data,
task=task,
examples_pool=examples_pool,
num_examples=num_examples,
debug_mode=debug_mode,
**kwargs
)
if len(triplets) == 0:
return []
result = triplets[0]
if debug_mode:
result["docid"] = docid
if gen_hard_neg:
result = self.gen_hard_negatives(
result,
task=task,
num_negatives=num_negatives,
**kwargs
)
result_list.append(result)
if self.cache_dir is not None:
gen_data_cache_path = os.path.join(self.cache_dir, f"{docid}.json")
with open(gen_data_cache_path, "w", encoding="utf-8") as f:
json.dump(result_list, f, indent=4, ensure_ascii=False)
return result_list
def run(
self,
positives: List[dict],
task_type: str,
language: str = "en",
code_language: str = "python",
tgt_code_language: Optional[str] = None,
examples_pool: Optional[List[dict]] = None,
num_examples: int = 3,
tqdm_desc: str = "Generating triplets",
debug_mode: bool = False,
gen_hard_neg: bool = False,
num_negatives: int = 7,
thread_count: int = 1,
**kwargs
):
task = get_task(
task_type=task_type,
language=language,
code_language=code_language,
tgt_code_language=tgt_code_language
)
result_list = []
def process_positive(positive):
return self.run_single(
data=positive,
task=task,
examples_pool=examples_pool,
num_examples=num_examples,
debug_mode=debug_mode,
gen_hard_neg=gen_hard_neg,
num_negatives=num_negatives,
**kwargs
)
# Use thread pool for parallel processing with tqdm progress bar.
with ThreadPoolExecutor(max_workers=thread_count) as executor:
results = list(tqdm(executor.map(
process_positive,
positives
), total=len(positives), desc=tqdm_desc))
# Collect results into result_list.
for res in results:
if isinstance(res, list):
result_list.extend(res)
else:
result_list.append(res)
# result_list.extend(results)
return result_list
def run_for_gen_neg(
self,
pairs: List[dict],
task_type: str,
language: str = "en",
code_language: str = "python",
tgt_code_language: Optional[str] = None,
examples_pool: Optional[List[dict]] = None,
num_examples: int = 3,
tqdm_desc: str = "Generating triplets",
debug_mode: bool = False,
gen_hard_neg: bool = False,
num_negatives: int = 7,
thread_count: int = 1,
**kwargs
):
task = get_task(
task_type=task_type,
language=language,
code_language=code_language,
tgt_code_language=tgt_code_language
)
result_list = []
def gen_single_negative(pair):
result = self.gen_hard_negatives(
pair,
task=task,
num_negatives=num_negatives,
**kwargs
)
return [result]
# Use thread pool for parallel processing with tqdm progress bar.
with ThreadPoolExecutor(max_workers=thread_count) as executor:
results = list(tqdm(executor.map(
gen_single_negative,
pairs
), total=len(pairs), desc=tqdm_desc))
# Collect results into result_list.
for res in results:
if isinstance(res, list):
result_list.extend(res)
else:
result_list.append(res)
# result_list.extend(results)
return result_list
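A minimal end-to-end sketch with toy positives. Choosing code_context_retrieval means documents are only split locally with the tokenizer, so no completion request is issued, though the constructor still points at the default localhost:8000 endpoint:

# Sketch: generate triplets for two toy positive documents.
positives = [
    {"text": "def add(a, b):\n    return a + b"},
    {"text": "def mul(a, b):\n    return a * b"},
]
generator = TripletGenerator(port=8000)
triplets = generator.run(
    positives=positives,
    task_type="code_context_retrieval",  # a TaskType member name
    language="en",
    code_language="python",
    thread_count=2,
)
for t in triplets:
    print(repr(t["query"][:60]), "->", repr(t["pos"][0][:60]))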

View File

@@ -0,0 +1,128 @@
import re
def clean_content(content: str):
if content is None:
raise ValueError("content is None.")
content = content.split('</think>')[-1].strip('\n').strip()
if content.startswith('\"') and content.endswith('\"'):
content = content[1:-1]
if content.startswith("```\n") and content.endswith("\n```"):
content = content[4:-4]
return content
def clean_code(code: str, lang: str, length_threshold: int = 30) -> str:
cleaned_code = code.strip('\ufeff').strip()
if not cleaned_code:
return ''
def clean_empty_lines(text: str) -> str:
return re.sub(r'\n\s*\n', '\n', text).strip()
# Per-language regexes for detecting function/class definitions
function_patterns = {
"java": r"(?m)^(?!\s*(import|package)\b).*\b(public\s+class|class\s+\w+|void\s+main|new\s+\w+\(|@Override)\b",
"python": r"(?m)^(?!\s*(import|from\s+\S+\s+import)\b).*\b(def\s+\w+|class\s+\w+|=\s*\S+|if\s+[:\w]|print\s+)",
"javascript": r"(?m)^(?!\s*(import|require\(|export\s)).*\b(function\s+\w+|const\s+\w+|=>|\(\)\s*=>|console\.log)",
"php": r"(?m)^(?!\s*(include|require|use)\b).*\b(function\s+\w+|echo\s+\S+|class\s+\w+)",
"ruby": r"(?m)^(?!\s*(require|load)\b).*\b(class\s+\w+|def\s+\w+|puts\s+\S+)",
"go": r"(?m)^(?!\s*import\b).*\bfunc\s+main\s*\(|type\s+\w+\s+struct",
"c#": r"(?m)^(?!\s*using\b).*\b(class\s+\w+|void\s+Main\s*\()",
"cplusplus": r"(?m)^(?!#include\b).*\b(int\s+main\s*\(|class\s+\w+|void\s+\w+\s*\(.*\)\s*{)",
"c": r"(?m)^(?!#include\b).*\b(int\s+main\s*\(|void\s+\w+\s*\(.*\)\s*{)",
"rust": r"(?m)^(?!\s*use\b).*\b(fn\s+main\s*\(|struct\s+\w+|impl\s+\w+)",
"typescript": r"(?m)^(?!\s*(import|require\(|export\s)).*\b(interface\s+\w+|class\s+\w+|function\s+\w+)",
"perl": r"(?m)^(?!\s*(use|require)\b).*\b(sub\s+\w+|my\s+\$\w+|print\s+\S+)",
"shell": r"(?m)^(?!\s*(source|\.)\s).*\b(function\s+\w+|if\s+\[|\$\(|echo\s+\S+)",
"sql": r"(?i)\b(CREATE\s+TABLE|SELECT\s+\*|INSERT\s+INTO|UPDATE\s+\w+|DELETE\s+FROM)\b",
"batchfile": r"(?m)^(?!\s*@?call\b).*\b(echo\s+\S+|set\s+\w+|if\s+.*\s+==\s+)",
"fortran": r"(?mi)^(?!\s*use\b).*\b(program\s+\w+|subroutine\s+\w+|do\s+\d+\s*,\s*\d+)",
"haskell": r"(?m)^(?!\s*import\b).*\b(main\s*=\s*do|data\s+\w+|putStrLn\s+\S+)",
"lua": r"(?m)^(?!\s*require\b).*\b(function\s+\w+|local\s+\w+|print\s*\()",
"powershell": r"(?m)^(?!\s*Import-Module\b).*\b(function\s+\w+|Write-Host\s+\S+|\$\w+\s*=)",
"visual_basic": r"(?m)^(?!\s*Imports\b).*\b(Module\s+\w+|Sub\s+Main|Class\s+\w+)"
}
# Per-language comment-stripping rules
comment_patterns = {
'java': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'python': (r'#.*?$', re.MULTILINE),
'javascript': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'php': (r'//.*?$|#.*?$|/\*.*?\*/|\\/\\/.*?$|#.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'ruby': (r'#.*', re.MULTILINE),
'go': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'csharp': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'cplusplus': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'c': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'rust': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'typescript': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
'perl': (r'#.*', re.MULTILINE),
'shell': (r'#.*', re.MULTILINE),
'sql': (r'--.*?$|/\*.*?\*/', re.DOTALL),
'batchfile': (r'^\s*(REM|@REM|::).*', re.MULTILINE | re.IGNORECASE),
'fortran': (r'!.*', re.MULTILINE),
'haskell': (r'--.*', re.MULTILINE),
'lua': (r'--.*?$|--\[\[.*?\]\]', re.DOTALL),
'powershell': (r'<#.*?#>|#.*', re.DOTALL),
'visual_basic': (r"'.*", re.MULTILINE),
}
# Strip comments
if lang in comment_patterns:
pattern, flags = comment_patterns[lang]
cleaned_code = re.sub(pattern, '', cleaned_code, flags=flags)
cleaned_code = clean_empty_lines(cleaned_code)
# Language-specific extra cleanup
if lang == 'fortran':
cleaned_code = re.sub(r'^[Cc*].*', '', cleaned_code, flags=re.MULTILINE)
elif lang == 'sql':
cleaned_code = re.sub(r'/\*.*?\*/', '', cleaned_code, flags=re.DOTALL)
elif lang == 'python':
cleaned_code = re.sub(r'^\s*#.*', '', cleaned_code, flags=re.MULTILINE)
# Detect function/class definitions and validate content
def has_valid_code(text: str, lang: str) -> bool:
pattern = function_patterns.get(lang)
if not pattern:
return len(text.strip()) > 0
# Stronger checks for specific languages
if lang == 'batchfile':
return bool(re.search(r'^\s*@?echo\b|:\w+', text, re.MULTILINE))
if lang == 'shell':
return bool(re.search(r'^\s*(if|for|while|case|echo|export|shopt|source)\b', text, re.MULTILINE))
if lang == 'python':
if re.search(r'^\s*(def|class)\s+\w+', text, re.MULTILINE):
return bool(re.search(r'^\s+[^\s#]', text, re.MULTILINE))
return False
if lang == 'ruby':
return bool(re.search(r'(def\s+\w+|class\s+\w+).*?\n\s+[^\s#]', text, re.MULTILINE))
return bool(re.search(pattern, text, re.DOTALL | re.MULTILINE))
# Final validity checks
if not has_valid_code(cleaned_code, lang):
return ''
cleaned_code = cleaned_code.strip('\ufeff').strip()
if len(cleaned_code) < length_threshold:
return ''
return cleaned_code
if __name__ == "__main__":
test_text = "\/\/ ----------------------------------------------------------------------\n\/\/ ----------------------------------------------------------------------\n\/\/\n\/\/ File: StrMaxProjection.h\n\/\/ Author: mgrosso \n\/\/ Created: Mon Jul 17 14:39:22 PDT 2006 on caliban\n\/\/ Project: \n\/\/ Purpose: \n\/\/ \n\/\/ $Id$\n\/\/ ----------------------------------------------------------------------\n\/\/ ----------------------------------------------------------------------\n\n#ifndef STRMAXPROJECTION_H\n#define STRMAXPROJECTION_H 1\n\n#include \"StrMinProjection.h\"\n\nclass StrMaxProjection : public StrMinProjection\n{\n public:\n StrMaxProjection(ExpressionPtr &operand);\n virtual ~StrMaxProjection();\n virtual AbstractProjectionPtr copy();\n\n protected:\n int compare(const char *lhs, const char *rhs);\n\n private:\n \/\/not implemented\n StrMaxProjection();\n StrMaxProjection( const StrMaxProjection &rhs );\n StrMaxProjection &operator=( const StrMaxProjection &rhs );\n};\n\n#endif \/* STRMAXPROJECTION_H *\/"
result = clean_code(test_text, "c", 200)
print(result)
test_text = "\/**\n * Copyright (c) Microsoft Corporation. All rights reserved.\n * Licensed under the MIT License. See License.txt in the project root for\n * license information.\n *\n * Code generated by Microsoft (R) AutoRest Code Generator.\n *\/\n\npackage com.microsoft.azure.management.datafactory.v2018_06_01;\n\nimport com.fasterxml.jackson.annotation.JsonProperty;\nimport com.fasterxml.jackson.annotation.JsonTypeInfo;\nimport com.fasterxml.jackson.annotation.JsonTypeName;\n\n\/**\n * The location of Google Cloud Storage dataset.\n *\/\n@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = \"type\", defaultImpl = GoogleCloudStorageLocation.class)\n@JsonTypeName(\"GoogleCloudStorageLocation\")\npublic class GoogleCloudStorageLocation extends DatasetLocation {\n \/**\n * Specify the bucketName of Google Cloud Storage. Type: string (or\n * Expression with resultType string).\n *\/\n @JsonProperty(value = \"bucketName\")\n private Object bucketName;\n\n \/**\n * Specify the version of Google Cloud Storage. Type: string (or Expression\n * with resultType string).\n *\/\n @JsonProperty(value = \"version\")\n private Object version;\n\n \/**\n * Get specify the bucketName of Google Cloud Storage. Type: string (or Expression with resultType string).\n *\n * @return the bucketName value\n *\/\n public Object bucketName() {\n return this.bucketName;\n }\n\n \/**\n * Set specify the bucketName of Google Cloud Storage. Type: string (or Expression with resultType string).\n *\n * @param bucketName the bucketName value to set\n * @return the GoogleCloudStorageLocation object itself.\n *\/\n public GoogleCloudStorageLocation withBucketName(Object bucketName) {\n this.bucketName = bucketName;\n return this;\n }\n\n \/**\n * Get specify the version of Google Cloud Storage. Type: string (or Expression with resultType string).\n *\n * @return the version value\n *\/\n public Object version() {\n return this.version;\n }\n\n \/**\n * Set specify the version of Google Cloud Storage. Type: string (or Expression with resultType string).\n *\n * @param version the version value to set\n * @return the GoogleCloudStorageLocation object itself.\n *\/\n public GoogleCloudStorageLocation withVersion(Object version) {\n this.version = version;\n return this;\n }\n\n}"
result = clean_code(test_text, "java", 200)
print(result)