mirror of https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
commit c597c2de1a
research/BGE_Coder/data_generation/constant.py
Normal file
1301
research/BGE_Coder/data_generation/constant.py
Normal file
File diff suppressed because it is too large
Load Diff
104 research/BGE_Coder/data_generation/corpus_generator.py Normal file
@@ -0,0 +1,104 @@
import os
import random
import datasets
from tqdm import tqdm
from typing import List, Tuple

from utils import clean_code
from constant import DocLength


class CorpusGenerator:
    def __init__(
        self,
        cache_dir: str = None,
    ):
        self.cache_dir = cache_dir

    def _load_corpus(self, corpus_dir: str, doc_length: List[str], external_path: List[str],
                     source_language: str, stop_threshold: int = -1):
        """
        Load available documents for a given task from the CoIR-Retrieval dataset.
        """
        corpus_list = []

        if corpus_dir is not None and os.path.exists(corpus_dir):
            file_list = os.listdir(corpus_dir)
            random.shuffle(file_list)

            for file in file_list:
                # Skip non-jsonl files and files whose names match none of the
                # requested document-length buckets.
                if not file.endswith('.jsonl'):
                    continue
                flag = False
                for d_length in doc_length:
                    d_length = DocLength[d_length].value
                    if d_length in file:
                        flag = True
                if flag is False:
                    continue
                file_path = os.path.join(corpus_dir, file)
                corpus = datasets.load_dataset('json', data_files=file_path, cache_dir=self.cache_dir)['train']
                for data in tqdm(corpus, desc="Loading corpus"):
                    if source_language is None:
                        lang = os.path.basename(corpus_dir)
                        data['language'] = lang
                    else:
                        data['language'] = source_language

                    text = clean_code(data["text"], data["language"], length_threshold=200)
                    data["text"] = text
                    if text != '':
                        corpus_list.append(data)

                    if stop_threshold > 0 and len(corpus_list) > stop_threshold:
                        break
                # Only the first matching (randomly shuffled) file is loaded.
                break

        for ep in external_path:
            if os.path.exists(ep):
                corpus = datasets.load_dataset('json', data_files=ep, cache_dir=self.cache_dir)['train']
                for data in tqdm(corpus, desc="Loading corpus"):
                    if source_language is None:
                        lang = os.path.basename(os.path.dirname(ep))
                        data['language'] = lang
                    else:
                        data['language'] = source_language

                    # useful when the text is not present in the data
                    if "text" not in data:
                        data["text"] = data["pos"][0]

                    text = clean_code(data["text"], data["language"], length_threshold=200)
                    data["text"] = text
                    if text != '':
                        corpus_list.append(data)

        return corpus_list

    def run(
        self,
        num_samples: int = -1,
        max_corpus: int = -1,
        corpus_dir: str = None,
        doc_length: List[str] = ["len_0_500"],
        external_path: List[str] = None,
        source_language: str = None
    ) -> Tuple[List[dict], List[dict]]:
        stop_threshold = max(num_samples * 10, max_corpus * 2)
        corpus_list = self._load_corpus(
            corpus_dir, doc_length, external_path, source_language, stop_threshold
        )

        if num_samples > 0 and num_samples < len(corpus_list):
            small_corpus_list = random.sample(corpus_list, num_samples)
        else:
            small_corpus_list = corpus_list

        if max_corpus > 0 and max_corpus < len(corpus_list):
            corpus_list = random.sample(corpus_list, max_corpus)

        return small_corpus_list, corpus_list
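A minimal usage sketch for CorpusGenerator (the corpus directory, shard naming, and language tag below are assumptions for illustration, not part of the file):

# Hypothetical example: sample a small pool of Python documents.
# Assumes ./corpus/python holds *.jsonl shards whose filenames contain a
# DocLength bucket tag such as "len_0_500".
from corpus_generator import CorpusGenerator

generator = CorpusGenerator(cache_dir=None)
sampled, full_pool = generator.run(
    num_samples=100,               # positives to sample for triplet generation
    max_corpus=10000,              # cap on the full corpus pool
    corpus_dir="./corpus/python",
    doc_length=["len_0_500"],
    external_path=[],              # no extra jsonl files
    source_language="python",
)
print(len(sampled), len(full_pool))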
127 research/BGE_Coder/data_generation/format_generated_examples.py Normal file
@@ -0,0 +1,127 @@
import os
import json
from constant import Language, CodeLanguage, TaskType, CODE_TRANSLATION_RETRIEVAL_PAIRS, \
    get_pos_as_input_by_task_type


def format_generated_examples(
    file_path: str,
    save_path: str,
    task_type: TaskType
):
    if os.path.exists(save_path):
        return

    if not os.path.exists(file_path):
        print("====================================")
        print("Warning: file not found! You may need to generate it first.")
        print(f"file_path: {file_path}")
        return

    pos_as_input = get_pos_as_input_by_task_type(task_type)

    data_list = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)

            if pos_as_input:
                _input = data["pos"][0]
                _output = data["query"]
            else:
                _input = data["query"]
                _output = data["pos"][0]

            if 'provided' in _input:
                continue
            if len(_input) > 12000 or len(_output) > 12000:
                continue

            data_list.append({
                "input": _input,
                "output": _output
            })

    if len(data_list) == 0:
        print("====================================")
        print("Warning: no data found!")
        print(f"file_path: {file_path}")
        return

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(data_list, f, indent=4, ensure_ascii=False)


def main():
    original_gen_examples_dir = "./examples"

    formatted_examples_dir = "./filtered_for_generation"

    for language in Language:
        for task_type in TaskType:
            if task_type == TaskType.code_translation_retrieval:
                # Format each translation pair in its listed direction.
                for code_language_pair in CODE_TRANSLATION_RETRIEVAL_PAIRS:
                    code_language, tgt_code_language = code_language_pair

                    file_path = os.path.join(
                        original_gen_examples_dir,
                        language.name, task_type.name, f"{language.name}-{code_language.name}-to-{tgt_code_language.name}-triplets.jsonl"
                    )
                    save_path = os.path.join(
                        formatted_examples_dir,
                        language.name, task_type.name, f"{code_language.name}-to-{tgt_code_language.name}_sample_examples.json"
                    )

                    format_generated_examples(file_path, save_path, task_type)

                # Then the same pairs in the reverse direction.
                for code_language_pair in CODE_TRANSLATION_RETRIEVAL_PAIRS:
                    tgt_code_language, code_language = code_language_pair

                    file_path = os.path.join(
                        original_gen_examples_dir,
                        language.name, task_type.name, f"{language.name}-{code_language.name}-to-{tgt_code_language.name}-triplets.jsonl"
                    )
                    save_path = os.path.join(
                        formatted_examples_dir,
                        language.name, task_type.name, f"{code_language.name}-to-{tgt_code_language.name}_sample_examples.json"
                    )

                    format_generated_examples(file_path, save_path, task_type)

            elif task_type == TaskType.text2sql_retrieval:
                file_path = os.path.join(
                    original_gen_examples_dir,
                    language.name, task_type.name, f"{language.name}-sql-triplets.jsonl"
                )
                save_path = os.path.join(
                    formatted_examples_dir,
                    language.name, task_type.name, "sql_sample_examples.json"
                )

                format_generated_examples(file_path, save_path, task_type)

            elif task_type == TaskType.code_context_retrieval:
                continue

            else:
                for code_language in CodeLanguage:
                    if code_language == CodeLanguage.null:
                        continue

                    file_path = os.path.join(
                        original_gen_examples_dir,
                        language.name, task_type.name, f"{language.name}-{code_language.name}-triplets.jsonl"
                    )
                    save_path = os.path.join(
                        formatted_examples_dir,
                        language.name, task_type.name, f"{code_language.name}_sample_examples.json"
                    )

                    format_generated_examples(file_path, save_path, task_type)

    print("All done!")


if __name__ == "__main__":
    main()
134 research/BGE_Coder/data_generation/llm.py Normal file
@@ -0,0 +1,134 @@
import os
import time
import openai
import random
import tiktoken
import threading
from openai import OpenAI, AzureOpenAI
from typing import Tuple


class LLM:
    def __init__(
        self,
        model: str = "Qwen2-5-Coder-32B-Instruct",
        model_type: str = "open-source",
        port: int = 8000,
    ):
        if model_type == "open-source":
            self.client = OpenAI(
                api_key="EMPTY",
                base_url=f"http://localhost:{port}/v1/"
            )
        elif model_type == "azure":
            self.client = AzureOpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                api_version=os.getenv("AZURE_API_VERSION", "2024-02-01"),
                azure_endpoint=os.getenv("AZURE_ENDPOINT"),
                azure_deployment=os.getenv("OPENAI_DEPLOYMENT_NAME", 'gpt-35-turbo')
            )
        elif model_type == "openai":
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_BASE_URL", None)
            )
        else:
            raise ValueError("model_type must be one of ['open-source', 'azure', 'openai']")

        self.model = model
        self.tokenizer = tiktoken.get_encoding("o200k_base")

    def split_text(self, text: str, anchor_points: Tuple[float, float] = (0.4, 0.7)):
        # Split the text at a random token position within the given relative range.
        token_ids = self.tokenizer.encode(text)
        anchor_point = random.uniform(anchor_points[0], anchor_points[1])
        split_index = int(len(token_ids) * anchor_point)
        return self.tokenizer.decode(token_ids[:split_index]), self.tokenizer.decode(token_ids[split_index:])

    def chat(
        self,
        prompt: str,
        max_tokens: int = 8192,
        logit_bias: dict = None,
        n: int = 1,
        temperature: float = 1.0,
        top_p: float = 0.6,
        repetition_penalty: float = 1.0,
        remove_thinking: bool = True,
        timeout: int = 90,
    ):
        endure_time = 0
        endure_time_limit = timeout * 2

        def create_completion(results):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    logit_bias=logit_bias if logit_bias is not None else {},
                    n=n,
                    temperature=temperature,
                    top_p=top_p,
                    extra_body={'repetition_penalty': repetition_penalty},
                    timeout=timeout,
                )
                results["content"] = [x.message.content for x in completion.choices[:n]]
            except openai.BadRequestError as e:
                # The response was filtered due to the prompt triggering Azure OpenAI's content management policy.
                results["content"] = [None for _ in range(n)]
            except openai.APIConnectionError as e:
                results["error"] = f'APIConnectionError({e})'
            except openai.RateLimitError as e:
                results["error"] = f'RateLimitError({e})'
            except Exception as e:
                results["error"] = f"Error: {e}"

        while True:
            results = {"content": None, "error": None}
            completion_thread = threading.Thread(target=create_completion, args=(results,))
            completion_thread.start()

            start_time = time.time()
            while completion_thread.is_alive():
                elapsed_time = time.time() - start_time
                if elapsed_time > endure_time_limit:
                    print("Completion timeout exceeded. Aborting...")
                    return [None for _ in range(n)]
                time.sleep(1)

            # If an error occurred during result processing
            if results["error"]:
                if endure_time >= endure_time_limit:
                    print(f'{results["error"]} - Skip this prompt.')
                    return [None for _ in range(n)]
                print(f"{results['error']} - Waiting for 5 seconds...")
                endure_time += 5
                time.sleep(5)
                continue

            content_list = results["content"]
            if remove_thinking:
                content_list = [x.split('</think>')[-1].strip('\n').strip() if x is not None else None for x in content_list]
            return content_list


if __name__ == "__main__":
    llm = LLM(
        model="gpt-4o-mini-2024-07-18",
        model_type="openai"
    )

    prompt = "hello, who are you?"
    response = llm.chat(prompt)[0]
    print(response)
368 research/BGE_Coder/data_generation/run_generation.py Normal file
@@ -0,0 +1,368 @@
import os
import json
import time
import gc
import torch
import argparse
import random
from hashlib import md5
import multiprocessing as mp
from typing import List, Optional

from constant import TaskType, Language, CodeLanguage, NUM_HARD_NEGATIVES
from corpus_generator import CorpusGenerator
from triplet_generator import TripletGenerator
from search import get_top1


def compute_md5(text: str):
    return md5(text.encode()).hexdigest()


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--task_type',
        type=str,
        required=True,
        help='The task type to generate data for',
        choices=[t.name for t in TaskType]
    )
    parser.add_argument(
        '--code_language',
        type=str,
        required=True,
        help='The code language to generate questions for.',
        choices=[c.name for c in CodeLanguage]
    )
    parser.add_argument(
        '--corpus_root',
        type=str,
        required=True,
        help='The root directory of the corpus data.'
    )
    parser.add_argument(
        '--save_dir',
        type=str,
        required=True,
        help='The path to save the generated data'
    )
    parser.add_argument(
        '--examples_dir',
        type=str,
        default=None,
        help='The path to the examples directory. If not None, the examples will be used for few-shot generation.'
    )
    parser.add_argument(
        '--num_examples',
        type=int,
        default=3,
        help='The number of examples to use for few-shot generation. Default: 3'
    )
    parser.add_argument(
        '--cache_dir',
        type=str,
        default=None,
        help='The cache directory'
    )
    parser.add_argument(
        '--language',
        type=str,
        default='en',
        help='The language to generate for. ISO 639-1 code. Default: en',
        choices=[l.name for l in Language]
    )
    parser.add_argument(
        '--tgt_code_language',
        type=str,
        default=None,
        help='The target code language to generate code translations for.',
        choices=[c.name for c in CodeLanguage]
    )
    parser.add_argument(
        '--num_samples',
        type=int,
        default=-1,
        help='The number of examples to use for generation. Default: -1. Use all available examples.'
    )
    parser.add_argument(
        '--model',
        type=str,
        default='Qwen2.5-72B-Instruct',
        help='The model to use for generation. Default: Qwen2.5-72B-Instruct'
    )
    parser.add_argument(
        '--model_type',
        type=str,
        default='open-source',
        help='The type of model to use for generation. Default: open-source',
    )
    parser.add_argument(
        '--port',
        type=int,
        default=8000,
        help='The port for vllm.'
    )
    parser.add_argument(
        '--num_processes',
        type=int,
        default=1,
        help='The number of processes to use for generation. Default: 1'
    )
    parser.add_argument(
        '--doc_length',
        type=str,
        default='len_0_500',
        help='The corpus length used to load dataset. Default: len_0_500'
    )
    parser.add_argument(
        '--external_path',
        type=str,
        default='',
        help='Space-separated paths to external corpus files. Default: empty'
    )
    parser.add_argument(
        '--sim_model_name',
        type=str,
        default=None,
        help='The embedding model used to retrieve similar documents.'
    )
    parser.add_argument(
        '--max_corpus',
        type=int,
        default=500000,
        help='The max num of corpus to load.'
    )
    parser.add_argument(
        '--overwrite',
        action='store_true',
        help='Whether to overwrite the existing data.'
    )
    parser.add_argument(
        '--debug_mode',
        action='store_true',
        help='Whether to open debug mode.'
    )
    parser.add_argument(
        '--gen_hard_neg',
        action='store_true',
        help='Whether to generate hard negatives.'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=None,
        help='Random seed for generating triplets using the same positive. Default: None'
    )
    args = parser.parse_args()
    return args


def gen_triplets(
    model: str,
    model_type: str,
    port: int,
    positives: List[dict],
    task_type: str,
    language: str,
    code_language: str,
    tgt_code_language: str,
    examples_pool: Optional[List[dict]] = None,
    num_examples: int = 3,
    tqdm_desc: str = "Generating triplets",
    thread_count: int = 1,
    gen_cache_dir: Optional[str] = None,
    debug_mode: bool = False,
    gen_hard_neg: bool = False,
):
    triplet_generator = TripletGenerator(model, model_type, port, cache_dir=gen_cache_dir)
    triplets = triplet_generator.run(
        positives=positives,
        task_type=task_type,
        language=language,
        code_language=code_language,
        tgt_code_language=tgt_code_language,
        examples_pool=examples_pool,
        num_examples=num_examples,
        tqdm_desc=tqdm_desc,
        thread_count=thread_count,
        debug_mode=debug_mode,
        gen_hard_neg=gen_hard_neg,
        num_negatives=NUM_HARD_NEGATIVES,
    )
    return triplets


def get_save_path(
    save_dir: str,
    task_type: str,
    language: str,
    code_language: str,
    tgt_code_language: Optional[str] = None
):
    save_dir = os.path.join(save_dir, language, task_type)
    if tgt_code_language is not None:
        file_name = f"{language}-{code_language}-to-{tgt_code_language}-triplets.jsonl"
    else:
        file_name = f"{language}-{code_language}-triplets.jsonl"
    save_path = os.path.join(save_dir, file_name)
    os.makedirs(save_dir, exist_ok=True)
    return save_path


def save_triplets(
    triplets: list,
    save_dir: str,
    task_type: str,
    language: str,
    code_language: str,
    tgt_code_language: Optional[str] = None
):
    if len(triplets) == 0:
        print(f"No triplets to save: {task_type} | {language} | {code_language} | {tgt_code_language}")
        return

    save_path = get_save_path(save_dir, task_type, language, code_language, tgt_code_language)
    query_md5s = set()
    pos_md5s = set()
    old_triplets = []
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            for line in f:
                triplet = json.loads(line)
                old_triplets.append(triplet)
                query_md5s.add(compute_md5(triplet['query']))
                pos_md5s.add(compute_md5(triplet['pos'][0]))

    with open(save_path, 'w', encoding='utf-8') as f:
        for triplet in old_triplets:
            f.write(json.dumps(triplet, ensure_ascii=False) + '\n')

        # Append only new triplets whose query and positive are both unseen.
        for triplet in triplets:
            _query_md5 = compute_md5(triplet['query'])
            _pos_md5 = compute_md5(triplet['pos'][0])
            if _query_md5 in query_md5s or _pos_md5 in pos_md5s:
                continue
            f.write(json.dumps(triplet, ensure_ascii=False) + '\n')
    print(f"Triplets saved to {save_path}")


def main(args):
    # set seed
    seed = args.seed
    if seed is not None:
        print(f"------------------- Seed set to {seed} -------------------")
        random.seed(seed)

    model = args.model
    model_type = args.model_type
    port = args.port

    num_samples = args.num_samples

    task_type = args.task_type
    language = args.language
    code_language = args.code_language
    tgt_code_language = args.tgt_code_language

    corpus_root = args.corpus_root
    corpus_dir = os.path.join(corpus_root, code_language)
    doc_length = args.doc_length.split()
    external_path = args.external_path.split()

    save_dir = args.save_dir
    cache_dir = args.cache_dir
    num_processes = min(args.num_processes, int(mp.cpu_count() * 0.8))
    overwrite = args.overwrite
    debug_mode = args.debug_mode
    gen_hard_neg = args.gen_hard_neg

    save_path = get_save_path(save_dir, task_type, language, code_language, tgt_code_language)
    # if os.path.exists(save_path) and not overwrite:
    #     data = []
    #     with open(save_path) as f:
    #         for line in f:
    #             data.append(json.loads(line))
    #     if len(data) >= num_samples * 0.8:
    #         print(f"Triplets already exist at {save_path}. Skipping generation.")
    #         return
    #     else:
    #         print(f"Triplets already exist at {save_path}. But samples is really small, continue generation.")
    #         num_samples = int((num_samples - len(data)) * 1.25)  # consider the filtered samples

    corpus_generator = CorpusGenerator(cache_dir)

    examples_dir = args.examples_dir
    num_examples = args.num_examples
    if examples_dir is not None:
        # if task_type in ["single_turn_code_qa", "multi_turn_code_qa"]:
        #     examples_path = os.path.join(examples_dir, language, task_type, "sample_examples.json")
        if task_type in ["code_translation_retrieval"]:
            examples_path = os.path.join(examples_dir, language, task_type,
                                         f"{code_language}-to-{tgt_code_language}_sample_examples.json")
        else:
            examples_path = os.path.join(examples_dir, language, task_type, f"{code_language}_sample_examples.json")
        try:
            with open(examples_path, 'r', encoding='utf-8') as f:
                examples_pool = json.load(f)
            examples_pool = random.sample(examples_pool,
                                          min(30, len(examples_pool)))  # sample 30 examples for few-shot generation
        except Exception:
            print(f'Error loading examples from {examples_path}')
            examples_pool = None
    else:
        examples_pool = None

    positives, large_positives = corpus_generator.run(
        num_samples=num_samples,
        max_corpus=args.max_corpus,
        corpus_dir=corpus_dir,
        doc_length=doc_length,
        external_path=external_path,
        source_language=code_language
    )

    if task_type in ["code_modification_retrieval", "code_comparison_retrieval"]:
        # These tasks need a second, closely related snippet, so attach the most
        # similar document from the large pool to each positive.
        top1_docs = get_top1([e['text'] for e in positives], args.sim_model_name, [e['text'] for e in large_positives])
        for i in range(len(top1_docs)):
            positives[i]['similar'] = top1_docs[i]
        gc.collect()
        torch.cuda.empty_cache()

    print("=================== Generate training data ===================")
    print(f'Task Type: {task_type} | Language: {language} | Code Language: {code_language} | Target Code Language: {tgt_code_language}')
    start_time = time.time()
    triplets = gen_triplets(
        model=model,
        model_type=model_type,
        port=port,
        positives=positives,
        task_type=task_type,
        language=language,
        code_language=code_language,
        tgt_code_language=tgt_code_language,
        examples_pool=examples_pool,
        num_examples=num_examples,
        thread_count=num_processes,
        gen_cache_dir=os.path.join(save_dir, language, task_type, "gen_cache_dir"),
        debug_mode=debug_mode,
        gen_hard_neg=gen_hard_neg,
    )
    save_triplets(
        triplets=triplets,
        save_dir=save_dir,
        task_type=task_type,
        language=language,
        code_language=code_language,
        tgt_code_language=tgt_code_language
    )
    end_time = time.time()
    print("=============================================================")
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    print("=============================================================")
    print("DONE!")


if __name__ == "__main__":
    args = get_args()
    main(args)
71 research/BGE_Coder/data_generation/search.py Normal file
@@ -0,0 +1,71 @@
from typing import Optional, List

import faiss
import numpy as np
from tqdm import tqdm
from FlagEmbedding import FlagModel


def create_index(embeddings: np.ndarray, use_gpu: bool = False):
    index = faiss.IndexFlatIP(len(embeddings[0]))
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(embeddings)
    return index


def search(
    faiss_index: faiss.Index,
    k: int = 100,
    query_embeddings: Optional[np.ndarray] = None,
    load_path: Optional[str] = None
):
    if query_embeddings is None:
        query_embeddings = np.load(load_path)

    query_size = len(query_embeddings)

    all_scores = []
    all_indices = []

    # Search in batches of 32 queries.
    for i in tqdm(range(0, query_size, 32), desc="Searching"):
        j = min(i + 32, query_size)
        query_embedding = query_embeddings[i: j]
        score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)
    return all_scores, all_indices


def get_top1(
    small_docs: List[str],
    encoder_name: str,
    docs: List[str],
    top: int = 1
):
    encoder = FlagModel(encoder_name, trust_remote_code=True)
    doc_emb = encoder.encode_corpus(docs, max_length=512, batch_size=256)
    small_doc_emb = encoder.encode_corpus(small_docs, max_length=512, batch_size=256)
    faiss_index = create_index(doc_emb, True)
    all_scores, all_indices = search(faiss_index, 1000, small_doc_emb)
    return_docs = []
    for i in range(len(all_indices)):
        return_docs.append([])
        # Skip the 20 nearest hits (likely the document itself or near copies),
        # then drop remaining candidates with token-Jaccard overlap above 0.95.
        for idx, score in zip(all_indices[i][20:], all_scores[i][20:]):
            d1 = set(docs[idx].split())
            d2 = set(small_docs[i].split())
            if len(d1 & d2) / len(d1 | d2) > 0.95:
                continue
            return_docs[-1].append(docs[idx])
            if len(return_docs[-1]) >= top:
                break
        if len(return_docs[-1]) == 0:
            print(all_indices[i], all_scores[i])
    # print(return_docs)
    del faiss_index
    return return_docs
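A self-contained sketch of the FAISS round trip above, using random unit vectors in place of FlagModel embeddings (the dimensions and data are assumptions for illustration):

import numpy as np
from search import create_index, search

rng = np.random.default_rng(0)
docs = rng.normal(size=(1000, 128)).astype(np.float32)
docs /= np.linalg.norm(docs, axis=1, keepdims=True)  # unit norm: inner product == cosine
queries = docs[:8]                                   # reuse a few docs as queries

index = create_index(docs, use_gpu=False)
scores, indices = search(index, k=5, query_embeddings=queries)
print(indices[:, 0])  # each query should retrieve itself first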
654 research/BGE_Coder/data_generation/triplet_generator.py Normal file
@@ -0,0 +1,654 @@
import os
import json
import random
from tqdm import tqdm
from hashlib import md5
from warnings import warn
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor

from llm import LLM
from utils import clean_content
from constant import TaskType, Task, SPECIAL_TASK_STEPS, \
    get_task, get_generation_prompt, get_quality_control_prompt, \
    get_gen_hard_neg_prompt


def compute_md5(text: str):
    return md5(text.encode()).hexdigest()


class TripletGenerator(LLM):
    def __init__(
        self,
        model: str = "Qwen2-5-Coder-32B-Instruct",
        model_type: str = "open-source",
        port: int = 8000,
        cache_dir: Optional[str] = None
    ):
        super().__init__(model, model_type, port)
        self.cache_dir = cache_dir
        if self.cache_dir is not None:
            os.makedirs(self.cache_dir, exist_ok=True)

    def _gen_for_code_modification_retrieval(
        self,
        task: Task,
        text: str,
        text_b: Optional[str] = None,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        **kwargs
    ):
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            text_b=text_b,
            examples=examples,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        diff = clean_content(response)
        gen_prompt = get_generation_prompt(
            task=task,
            text=diff,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        modification_instr = clean_content(response)

        query = f"{modification_instr}\n```\n{text}\n```"
        pos = text_b

        if debug_mode:
            result = {
                "generation_prompt": gen_prompt,
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        else:
            result = {
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        return result

    def _gen_for_code_comparison_retrieval(
        self,
        task: Task,
        text: str,
        text_b: Optional[str] = None,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        **kwargs
    ):
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            text_b=text_b,
            examples=examples,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        diff_question = clean_content(response)
        query = f"{diff_question}\n\nInput Code:\n```\n{text}\n```\n\nOutput Code:\n```\n{text_b}\n```"
        gen_prompt = get_generation_prompt(
            task=task,
            text=query,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        pos = clean_content(response)

        if debug_mode:
            result = {
                "generation_prompt": gen_prompt,
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        else:
            result = {
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        return result

    def _gen_for_code_context_retrieval(
        self,
        task: Task,
        text: str,
        anchor_points: Optional[tuple] = (0.4, 0.7),
        **kwargs
    ):
        # No LLM call needed: the leading part of the code is the query and the
        # trailing part is the positive.
        former_part, latter_part = self.split_text(
            text,
            anchor_points=anchor_points
        )
        result = {
            "prompt": task.task_instruction,
            "query": former_part,
            "pos": [latter_part],
            "neg": []
        }
        return result

    @staticmethod
    def _arrange_query_and_pos(task: Task, input_text: str, response: str):
        """
        Arrange the query and positive example based on the task type.

        Args:
        - task: Task
        - input_text: str
        - response: str

        Returns:
        - query: str
        - pos: str
        """
        # TODO: support more task types, including some special task types.
        if task.main_task_type in ["text2code", "hybrid"]:
            query = clean_content(response)
            pos = input_text
        else:
            query = input_text
            pos = clean_content(response)
        return query, pos

    def _gen_for_normal_task(
        self,
        task: Task,
        text: str,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        **kwargs
    ):
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            examples=examples
        )
        response = self.chat(gen_prompt, **kwargs)[0]

        # Arrange the query and positive example based on the task type.
        query, pos = self._arrange_query_and_pos(
            task=task,
            input_text=text,
            response=response
        )

        if debug_mode:
            result = {
                "generation_prompt": gen_prompt,
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": [],
                "response": response
            }
        else:
            result = {
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        return result

    def _gen_for_bug_desc_retrieval(
        self,
        task: Task,
        text: str,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        **kwargs
    ):
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            examples=examples,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        if response is None:
            raise ValueError("Response is None.")
        buggy_code = response
        gen_prompt = get_generation_prompt(
            task=task,
            text=buggy_code,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        query = clean_content(response)
        pos = text

        if debug_mode:
            result = {
                "generation_prompt": gen_prompt,
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        else:
            result = {
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        return result

    def _gen_for_two_step_not_use_last(
        self,
        task: Task,
        text: str,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        reverse_query_pos: bool = False,
        **kwargs
    ):
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        query = clean_content(response)
        gen_prompt = get_generation_prompt(
            task=task,
            text=query,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        pos = clean_content(response)
        if reverse_query_pos:
            query, pos = pos, query

        if debug_mode:
            result = {
                "generation_prompt": gen_prompt,
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        else:
            result = {
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        return result

    def _gen_for_two_step_use_last(
        self,
        task: Task,
        text: str,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        reverse_query_pos: bool = False,
        **kwargs
    ):
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        # Unlike the variant above, the original code snippet is appended to the query.
        query = clean_content(response) + f"\n```\n{text}\n```"
        gen_prompt = get_generation_prompt(
            task=task,
            text=query,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        pos = clean_content(response)
        if reverse_query_pos:
            query, pos = pos, query

        if debug_mode:
            result = {
                "generation_prompt": gen_prompt,
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        else:
            result = {
                "prompt": task.task_instruction,
                "query": query,
                "pos": [pos],
                "neg": []
            }
        return result

    def generate_triplets(
        self,
        data: dict,
        task: Task,
        examples_pool: Optional[List[dict]] = None,
        num_examples: int = 3,
        debug_mode: bool = False,
        **kwargs
    ):
        kwargs["remove_thinking"] = not debug_mode

        result_list = []

        examples = None
        if examples_pool is not None:
            examples = random.sample(examples_pool, min(num_examples, len(examples_pool)))

        try:
            if task.task_type in SPECIAL_TASK_STEPS:
                text = data["text"]

                if task.task_type == TaskType.code_modification_retrieval:
                    text_b = data["similar"][0]

                    result = self._gen_for_code_modification_retrieval(
                        task=task,
                        text=text,
                        text_b=text_b,
                        examples=examples,
                        debug_mode=debug_mode
                    )
                elif task.task_type == TaskType.code_comparison_retrieval:
                    text_b = data["similar"][0]

                    result = self._gen_for_code_comparison_retrieval(
                        task=task,
                        text=text,
                        text_b=text_b,
                        examples=examples,
                        debug_mode=debug_mode
                    )
                elif task.task_type == TaskType.bug_desc_retrieval:
                    result = self._gen_for_bug_desc_retrieval(
                        task=task,
                        text=text,
                        examples=examples,
                        debug_mode=debug_mode
                    )
                elif task.task_type in [
                    # cf - updated
                    TaskType.code_issue_discussion_retrieval,
                    TaskType.code_version_update_retrieval,
                    TaskType.code_bug_fix_example_retrieval,
                ]:
                    result = self._gen_for_two_step_not_use_last(
                        task=task,
                        text=text,
                        examples=examples,
                        debug_mode=debug_mode,
                        reverse_query_pos=False
                    )
                elif task.task_type in [
                    # cf - updated
                    TaskType.code_refactoring_pattern_retrieval,
                    TaskType.code_style_guideline_example_retrieval,
                    TaskType.code_migration_retrieval,
                    # jl - updated
                    TaskType.code_optimization_hybrid_retrieval,
                    TaskType.code_best_practices_retrieval,
                    TaskType.security_vulnerability_fix_retrieval,
                ]:
                    result = self._gen_for_two_step_use_last(
                        task=task,
                        text=text,
                        examples=examples,
                        debug_mode=debug_mode,
                        reverse_query_pos=False
                    )
                else:
                    raise NotImplementedError(f"Task type {task.task_type} not implemented.")
            elif task.task_type == TaskType.code_context_retrieval:
                text = data["text"]

                result = self._gen_for_code_context_retrieval(
                    task=task,
                    text=text,
                    **kwargs
                )
                # NOTE: no need to do quality control for code context retrieval task
                result_list.append(result)
                return result_list
            else:
                text = data["text"]

                result = self._gen_for_normal_task(
                    task=task,
                    text=text,
                    examples=examples,
                    debug_mode=debug_mode,
                    **kwargs
                )

            qc_prompt = get_quality_control_prompt(
                task=task,
                query=result["query"],
                pos=result["pos"][0]
            )
            response = self.chat(qc_prompt, **kwargs)[0]
            judge = clean_content(response)
            # Keep the pair only if the judge labels it "1"; in debug mode keep
            # rejected pairs too, with the judge output attached for inspection.
            if "1" in judge:
                if debug_mode:
                    result["judge"] = judge
                    result["judge_response"] = response
                result_list.append(result)
            else:
                if debug_mode:
                    result["judge"] = judge
                    result["judge_response"] = response
                    result_list.append(result)
        except Exception as e:
            warn(f"Error: {e}")

        return result_list

    def gen_hard_negatives(self, result: dict, task: Task, num_negatives: int = 7, **kwargs):
        gen_hard_neg_prompt = get_gen_hard_neg_prompt(
            task=task,
            query=result["query"],
            pos=result["pos"][0]
        )
        response_list = self.chat(gen_hard_neg_prompt, n=num_negatives, **kwargs)
        for response in response_list:
            if response is None:
                continue
            hard_neg = clean_content(response)
            result["neg"].append(hard_neg)
        result["neg"] = list(set(result["neg"]))
        return result

    def run_single(
        self,
        data: dict,
        task: Task,
        examples_pool: Optional[List[dict]] = None,
        num_examples: int = 3,
        debug_mode: bool = False,
        gen_hard_neg: bool = False,
        num_negatives: int = 7,
        **kwargs
    ):
        result_list = []

        docid = compute_md5(data["text"])
        if self.cache_dir is not None:
            gen_data_cache_path = os.path.join(self.cache_dir, f"{docid}.json")
            if os.path.exists(gen_data_cache_path):
                with open(gen_data_cache_path, "r", encoding="utf-8") as f:
                    result_list = json.load(f)

        if len(result_list) > 0:
            if gen_hard_neg:
                for i in range(len(result_list)):
                    if len(result_list[i]["neg"]) == 0:
                        result_list[i] = self.gen_hard_negatives(
                            result=result_list[i],
                            task=task,
                            num_negatives=num_negatives,
                            **kwargs
                        )
                # overwrite the cache file
                with open(gen_data_cache_path, "w", encoding="utf-8") as f:
                    json.dump(result_list, f, indent=4, ensure_ascii=False)
            return result_list

        triplets = self.generate_triplets(
            data,
            task=task,
            examples_pool=examples_pool,
            num_examples=num_examples,
            debug_mode=debug_mode,
            **kwargs
        )
        if len(triplets) == 0:
            return []

        result = triplets[0]
        if debug_mode:
            result["docid"] = docid

        if gen_hard_neg:
            result = self.gen_hard_negatives(
                result,
                task=task,
                num_negatives=num_negatives,
                **kwargs
            )

        result_list.append(result)

        if self.cache_dir is not None:
            gen_data_cache_path = os.path.join(self.cache_dir, f"{docid}.json")
            with open(gen_data_cache_path, "w", encoding="utf-8") as f:
                json.dump(result_list, f, indent=4, ensure_ascii=False)

        return result_list

    def run(
        self,
        positives: List[dict],
        task_type: str,
        language: str = "en",
        code_language: str = "python",
        tgt_code_language: Optional[str] = None,
        examples_pool: Optional[List[dict]] = None,
        num_examples: int = 3,
        tqdm_desc: str = "Generating triplets",
        debug_mode: bool = False,
        gen_hard_neg: bool = False,
        num_negatives: int = 7,
        thread_count: int = 1,
        **kwargs
    ):
        task = get_task(
            task_type=task_type,
            language=language,
            code_language=code_language,
            tgt_code_language=tgt_code_language
        )

        result_list = []

        def process_positive(positive):
            return self.run_single(
                data=positive,
                task=task,
                examples_pool=examples_pool,
                num_examples=num_examples,
                debug_mode=debug_mode,
                gen_hard_neg=gen_hard_neg,
                num_negatives=num_negatives,
                **kwargs
            )

        # Use thread pool for parallel processing with tqdm progress bar.
        with ThreadPoolExecutor(max_workers=thread_count) as executor:
            results = list(tqdm(executor.map(
                process_positive,
                positives
            ), total=len(positives), desc=tqdm_desc))

        # Collect results into result_list.
        for res in results:
            if isinstance(res, list):
                result_list.extend(res)
            else:
                result_list.append(res)

        return result_list

    def run_for_gen_neg(
        self,
        pairs: List[dict],
        task_type: str,
        language: str = "en",
        code_language: str = "python",
        tgt_code_language: Optional[str] = None,
        examples_pool: Optional[List[dict]] = None,
        num_examples: int = 3,
        tqdm_desc: str = "Generating triplets",
        debug_mode: bool = False,
        gen_hard_neg: bool = False,
        num_negatives: int = 7,
        thread_count: int = 1,
        **kwargs
    ):
        task = get_task(
            task_type=task_type,
            language=language,
            code_language=code_language,
            tgt_code_language=tgt_code_language
        )

        result_list = []

        def gen_single_negative(pair):
            result = self.gen_hard_negatives(
                pair,
                task=task,
                num_negatives=num_negatives,
                **kwargs
            )
            return [result]

        # Use thread pool for parallel processing with tqdm progress bar.
        with ThreadPoolExecutor(max_workers=thread_count) as executor:
            results = list(tqdm(executor.map(
                gen_single_negative,
                pairs
            ), total=len(pairs), desc=tqdm_desc))

        # Collect results into result_list.
        for res in results:
            if isinstance(res, list):
                result_list.extend(res)
            else:
                result_list.append(res)

        return result_list
128 research/BGE_Coder/data_generation/utils.py Normal file
@@ -0,0 +1,128 @@
import re


def clean_content(content: str):
    if content is None:
        raise ValueError("content is None.")

    content = content.split('</think>')[-1].strip('\n').strip()

    if content.startswith('\"') and content.endswith('\"'):
        content = content[1:-1]

    if content.startswith("```\n") and content.endswith("\n```"):
        content = content[4:-4]

    return content


def clean_code(code: str, lang: str, length_threshold: int = 30) -> str:
    cleaned_code = code.strip('\ufeff').strip()
    if not cleaned_code:
        return ''

    def clean_empty_lines(text: str) -> str:
        return re.sub(r'\n\s*\n', '\n', text).strip()

    # Regex patterns for detecting function/class definitions per language.
    # NOTE: keys should match the language tags passed in (CodeLanguage member
    # names), hence "csharp" rather than "c#".
    function_patterns = {
        "java": r"(?m)^(?!\s*(import|package)\b).*\b(public\s+class|class\s+\w+|void\s+main|new\s+\w+\(|@Override)\b",
        "python": r"(?m)^(?!\s*(import|from\s+\S+\s+import)\b).*\b(def\s+\w+|class\s+\w+|=\s*\S+|if\s+[:\w]|print\s+)",
        "javascript": r"(?m)^(?!\s*(import|require\(|export\s)).*\b(function\s+\w+|const\s+\w+|=>|\(\)\s*=>|console\.log)",
        "php": r"(?m)^(?!\s*(include|require|use)\b).*\b(function\s+\w+|echo\s+\S+|class\s+\w+)",
        "ruby": r"(?m)^(?!\s*(require|load)\b).*\b(class\s+\w+|def\s+\w+|puts\s+\S+)",
        "go": r"(?m)^(?!\s*import\b).*\bfunc\s+main\s*\(|type\s+\w+\s+struct",
        "csharp": r"(?m)^(?!\s*using\b).*\b(class\s+\w+|void\s+Main\s*\()",
        "cplusplus": r"(?m)^(?!#include\b).*\b(int\s+main\s*\(|class\s+\w+|void\s+\w+\s*\(.*\)\s*{)",
        "c": r"(?m)^(?!#include\b).*\b(int\s+main\s*\(|void\s+\w+\s*\(.*\)\s*{)",
        "rust": r"(?m)^(?!\s*use\b).*\b(fn\s+main\s*\(|struct\s+\w+|impl\s+\w+)",
        "typescript": r"(?m)^(?!\s*(import|require\(|export\s)).*\b(interface\s+\w+|class\s+\w+|function\s+\w+)",
        "perl": r"(?m)^(?!\s*(use|require)\b).*\b(sub\s+\w+|my\s+\$\w+|print\s+\S+)",
        "shell": r"(?m)^(?!\s*(source|\.)\s).*\b(function\s+\w+|if\s+\[|\$\(|echo\s+\S+)",
        "sql": r"(?i)\b(CREATE\s+TABLE|SELECT\s+\*|INSERT\s+INTO|UPDATE\s+\w+|DELETE\s+FROM)\b",
        "batchfile": r"(?m)^(?!\s*@?call\b).*\b(echo\s+\S+|set\s+\w+|if\s+.*\s+==\s+)",
        "fortran": r"(?mi)^(?!\s*use\b).*\b(program\s+\w+|subroutine\s+\w+|do\s+\d+\s*,\s*\d+)",
        "haskell": r"(?m)^(?!\s*import\b).*\b(main\s*=\s*do|data\s+\w+|putStrLn\s+\S+)",
        "lua": r"(?m)^(?!\s*require\b).*\b(function\s+\w+|local\s+\w+|print\s*\()",
        "powershell": r"(?m)^(?!\s*Import-Module\b).*\b(function\s+\w+|Write-Host\s+\S+|\$\w+\s*=)",
        "visual_basic": r"(?m)^(?!\s*Imports\b).*\b(Module\s+\w+|Sub\s+Main|Class\s+\w+)"
    }

    # Comment-stripping rules per language. The doubled alternatives such as
    # \\/\\/ also match comments whose slashes arrive JSON-escaped (\/\/), as
    # in the test data below.
    comment_patterns = {
        'java': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'python': (r'#.*?$', re.MULTILINE),
        'javascript': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'php': (r'//.*?$|#.*?$|/\*.*?\*/|\\/\\/.*?$|#.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'ruby': (r'#.*', re.MULTILINE),
        'go': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'csharp': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'cplusplus': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'c': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'rust': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'typescript': (r'//.*?$|/\*.*?\*/|\\/\\/.*?$|\\/\*.*?\*\\/', re.DOTALL | re.MULTILINE),
        'perl': (r'#.*', re.MULTILINE),
        'shell': (r'#.*', re.MULTILINE),
        'sql': (r'--.*?$|/\*.*?\*/', re.DOTALL),
        'batchfile': (r'^\s*(REM|@REM|::).*', re.MULTILINE | re.IGNORECASE),
        'fortran': (r'!.*', re.MULTILINE),
        'haskell': (r'--.*', re.MULTILINE),
        'lua': (r'--.*?$|--\[\[.*?\]\]', re.DOTALL),
        'powershell': (r'<#.*?#>|#.*', re.DOTALL),
        'visual_basic': (r"'.*", re.MULTILINE),
    }

    # Strip comments.
    if lang in comment_patterns:
        pattern, flags = comment_patterns[lang]
        cleaned_code = re.sub(pattern, '', cleaned_code, flags=flags)
        cleaned_code = clean_empty_lines(cleaned_code)

    # Language-specific special handling.
    if lang == 'fortran':
        cleaned_code = re.sub(r'^[Cc*].*', '', cleaned_code, flags=re.MULTILINE)
    elif lang == 'sql':
        cleaned_code = re.sub(r'/\*.*?\*/', '', cleaned_code, flags=re.DOTALL)
    elif lang == 'python':
        cleaned_code = re.sub(r'^\s*#.*', '', cleaned_code, flags=re.MULTILINE)

    # Function-definition detection and content validation.
    def has_valid_code(text: str, lang: str) -> bool:
        pattern = function_patterns.get(lang)
        if not pattern:
            return len(text.strip()) > 0

        # Stricter checks for languages whose generic patterns are weak.
        if lang == 'batchfile':
            return bool(re.search(r'^\s*@?echo\b|:\w+', text, re.MULTILINE))
        if lang == 'shell':
            return bool(re.search(r'^\s*(if|for|while|case|echo|export|shopt|source)\b', text, re.MULTILINE))
        if lang == 'python':
            if re.search(r'^\s*(def|class)\s+\w+', text, re.MULTILINE):
                return bool(re.search(r'^\s+[^\s#]', text, re.MULTILINE))
            return False
        if lang == 'ruby':
            return bool(re.search(r'(def\s+\w+|class\s+\w+).*?\n\s+[^\s#]', text, re.MULTILINE))
        return bool(re.search(pattern, text, re.DOTALL | re.MULTILINE))

    # Final validity check.
    if not has_valid_code(cleaned_code, lang):
        return ''

    cleaned_code = cleaned_code.strip('\ufeff').strip()

    if len(cleaned_code) < length_threshold:
        return ''

    return cleaned_code


if __name__ == "__main__":
    test_text = "\/\/ ----------------------------------------------------------------------\n\/\/ ----------------------------------------------------------------------\n\/\/\n\/\/ File: StrMaxProjection.h\n\/\/ Author: mgrosso \n\/\/ Created: Mon Jul 17 14:39:22 PDT 2006 on caliban\n\/\/ Project: \n\/\/ Purpose: \n\/\/ \n\/\/ $Id$\n\/\/ ----------------------------------------------------------------------\n\/\/ ----------------------------------------------------------------------\n\n#ifndef STRMAXPROJECTION_H\n#define STRMAXPROJECTION_H 1\n\n#include \"StrMinProjection.h\"\n\nclass StrMaxProjection : public StrMinProjection\n{\n    public:\n    StrMaxProjection(ExpressionPtr &operand);\n    virtual ~StrMaxProjection();\n    virtual AbstractProjectionPtr copy();\n\n    protected:\n    int compare(const char *lhs, const char *rhs);\n\n    private:\n    \/\/not implemented\n    StrMaxProjection();\n    StrMaxProjection( const StrMaxProjection &rhs );\n    StrMaxProjection &operator=( const StrMaxProjection &rhs );\n};\n\n#endif \/* STRMAXPROJECTION_H *\/"

    result = clean_code(test_text, "c", 200)
    print(result)

    test_text = "\/**\n * Copyright (c) Microsoft Corporation. All rights reserved.\n * Licensed under the MIT License. See License.txt in the project root for\n * license information.\n *\n * Code generated by Microsoft (R) AutoRest Code Generator.\n *\/\n\npackage com.microsoft.azure.management.datafactory.v2018_06_01;\n\nimport com.fasterxml.jackson.annotation.JsonProperty;\nimport com.fasterxml.jackson.annotation.JsonTypeInfo;\nimport com.fasterxml.jackson.annotation.JsonTypeName;\n\n\/**\n * The location of Google Cloud Storage dataset.\n *\/\n@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = \"type\", defaultImpl = GoogleCloudStorageLocation.class)\n@JsonTypeName(\"GoogleCloudStorageLocation\")\npublic class GoogleCloudStorageLocation extends DatasetLocation {\n    \/**\n     * Specify the bucketName of Google Cloud Storage. Type: string (or\n     * Expression with resultType string).\n     *\/\n    @JsonProperty(value = \"bucketName\")\n    private Object bucketName;\n\n    \/**\n     * Specify the version of Google Cloud Storage. Type: string (or Expression\n     * with resultType string).\n     *\/\n    @JsonProperty(value = \"version\")\n    private Object version;\n\n    \/**\n     * Get specify the bucketName of Google Cloud Storage. Type: string (or Expression with resultType string).\n     *\n     * @return the bucketName value\n     *\/\n    public Object bucketName() {\n        return this.bucketName;\n    }\n\n    \/**\n     * Set specify the bucketName of Google Cloud Storage. Type: string (or Expression with resultType string).\n     *\n     * @param bucketName the bucketName value to set\n     * @return the GoogleCloudStorageLocation object itself.\n     *\/\n    public GoogleCloudStorageLocation withBucketName(Object bucketName) {\n        this.bucketName = bucketName;\n        return this;\n    }\n\n    \/**\n     * Get specify the version of Google Cloud Storage. Type: string (or Expression with resultType string).\n     *\n     * @return the version value\n     *\/\n    public Object version() {\n        return this.version;\n    }\n\n    \/**\n     * Set specify the version of Google Cloud Storage. Type: string (or Expression with resultType string).\n     *\n     * @param version the version value to set\n     * @return the GoogleCloudStorageLocation object itself.\n     *\/\n    public GoogleCloudStorageLocation withVersion(Object version) {\n        this.version = version;\n        return this;\n    }\n\n}"
    result = clean_code(test_text, "java", 200)
    print(result)
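A quick illustration of clean_content on a hypothetical LLM response: it drops everything up to a closing </think> tag, then unwraps surrounding quotes or a bare ``` fence.

raw = "<think>some reasoning...</think>\n```\nprint('hi')\n```"
print(clean_content(raw))  # -> print('hi')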