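# Evaluate a language model on MMLU. Each question is expanded into four
# prompts (one per answer option), optionally prefixed with retrieved
# "knowledge" passages; the option whose completion has the lowest negative
# log-likelihood is taken as the prediction.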
import os
import copy
import json
import logging
import datasets
from typing import List
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict
from collections import defaultdict

from src.lm import (
    LM,
    LMArgs
)
from src.retrieval import (
    RetrievalArgs,
    RetrievalMetric,
)
from src.utils.util import makedirs, remove_eos, DefaultDataCollator, DatasetProcessFn, FileLogger
from .eval_retrieval import main as retrieval_main

logger = logging.getLogger(__name__)

import transformers
transformers.logging.set_verbosity_error()
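
# Maps each MMLU subject to the category under which its accuracy is reported
# (STEM, Social Sciences, Humanities, or "others").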
SUBJECT_2_CATEGORY={"abstract_algebra": "STEM", "anatomy": "others", "astronomy": "STEM", "business_ethics": "others", "clinical_knowledge": "others", "college_biology": "STEM", "college_chemistry": "STEM", "college_computer_science": "STEM", "college_mathematics": "STEM", "college_medicine": "others", "college_physics": "STEM", "computer_security": "STEM", "conceptual_physics": "STEM", "econometrics": "Social Sciences", "electrical_engineering": "STEM", "elementary_mathematics": "STEM", "formal_logic": "Humanities", "global_facts": "others", "high_school_biology": "STEM", "high_school_chemistry": "STEM", "high_school_computer_science": "STEM", "high_school_european_history": "Humanities", "high_school_geography": "Social Sciences", "high_school_government_and_politics": "Social Sciences", "high_school_macroeconomics": "Social Sciences", "high_school_mathematics": "STEM", "high_school_microeconomics": "Social Sciences", "high_school_physics": "STEM", "high_school_psychology": "Social Sciences", "high_school_statistics": "STEM", "high_school_us_history": "Humanities", "high_school_world_history": "Humanities", "human_aging": "others", "human_sexuality": "Social Sciences", "international_law": "Humanities", "jurisprudence": "Humanities", "logical_fallacies": "Humanities", "machine_learning": "STEM", "management": "others", "marketing": "others", "medical_genetics": "others", "miscellaneous": "others", "moral_disputes": "Humanities", "moral_scenarios": "Humanities", "nutrition": "others", "philosophy": "Humanities", "prehistory": "Humanities", "professional_accounting": "others", "professional_law": "Humanities", "professional_medicine": "others", "professional_psychology": "Social Sciences", "public_relations": "Social Sciences", "security_studies": "Social Sciences", "sociology": "Social Sciences", "us_foreign_policy": "Social Sciences", "virology": "others", "world_religions": "Humanities"}
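
# Arguments for MMLU evaluation. LMArgs and RetrievalArgs (from src.lm and
# src.retrieval) supply the model- and retrieval-related fields referenced in
# main(), e.g. model_name_or_path and retrieval_method.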
@dataclass
class MMLUArgs(LMArgs, RetrievalArgs):
    output_dir: str = field(
        default="data/results/mmlu",
    )
    eval_data: str = field(
        default="llm-embedder:qa/mmlu/test.json",
        metadata={'help': 'Path to the test file.'}
    )
    lm_batch_size: int = field(
        default=2,
        metadata={'help': 'Evaluation batch size.'},
    )

    few_shot: int = field(
        default=0,
        metadata={'help': 'How many few-shot training samples to use.'},
    )
    train_data: str = field(
        default="llm-embedder:qa/mmlu/dev.json",
        metadata={'help': 'Path to the file containing training examples.'}
    )

    corpus: str = field(
        default="llm-embedder:qa/msmarco/corpus.json",
        metadata={'help': 'Corpus path for retrieval.'}
    )
    key_template: str = field(
        default="{title} {text}",
        metadata={'help': 'How to concatenate columns in the corpus to form one key.'}
    )
    key_max_length: int = field(
        default=128,
        metadata={'help': 'Maximum number of tokens in a key.'}
    )
    hits: int = field(
        default=10,
        metadata={'help': 'How many hits to retrieve per query.'},
    )
    key_num: int = field(
        default=3,
        metadata={'help': 'How many retrieved documents to include in the prompt.'},
    )
    metrics: List[str] = field(
        default_factory=lambda: ["collate_key"],
    )
    save_to_output: bool = field(
        default=True,
        metadata={'help': 'Save the result/key/negative files to output_dir? If False, they are saved next to eval_data.'}
    )

    log_path: str = field(
        default="data/results/mmlu/mmlu.log",
        metadata={'help': 'Path to the file for logging.'}
    )
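
# Returns a dataset.map() processing function that turns each MMLU record into
# four tokenized sequences (one per answer option). Only the option tokens are
# kept as labels, so the NLLs of "Answer: A/B/C/D" can be compared directly.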
def process_mmlu(tokenizer, context_max_length=2048, key_num=3, few_shot=0, train_data=None, cache_dir=None, is_encoder_decoder=False, add_llama_inst=False):
    tokenizer.truncation_side = 'right'
    left_truncation_tokenizer = copy.deepcopy(tokenizer)
    left_truncation_tokenizer.truncation_side = 'left'

    # probe the tokenizer once to check whether it adds BOS/EOS special tokens
    test = tokenizer("test", return_special_tokens_mask=True)["special_tokens_mask"]

    has_bos = has_eos = False
    if test[0] == 1:
        has_bos = True
    if test[-1] == 1:
        has_eos = True

    if few_shot > 0:
        assert train_data is not None
        train_data = datasets.load_dataset("json", data_files=train_data, cache_dir=cache_dir, split="train")
        train_df = train_data.to_pandas()
        # group the dataframe by subject and keep the first `few_shot` rows of each
        train_df = {k: v[:few_shot] for k, v in train_df.groupby("subject")}

    options = ['A', 'B', 'C', 'D']
    def _prepare_sample(query, choices, answer):
        """
        <Question>
        A. <Choice 1>
        B. <Choice 2>
        C. <Choice 3>
        D. <Choice 4>
        Answer: <Answer>
        """
        # answer may be a str ('A'-'D') or an int/numpy int64 index
        if not isinstance(answer, str):
            answer = options[answer]

        sample = f"{query}\n{chr(10).join([f'{option}. {choice}' for option, choice in zip(options, choices)])}\nAnswer: {answer}"
        return sample
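
    # Join the top key_num retrieved passages into a "Knowledge:" block,
    # truncated to max_length tokens so the full prompt still fits in the
    # context window.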
    def _prepare_knowledge(key, max_length=None):
        if key is not None:
            key = key[:key_num]
            key = "\n".join(key)
            key = f"Knowledge:\n{key}"
            if max_length is not None:
                # truncate key if necessary
                key = tokenizer.decode(tokenizer.encode(key, add_special_tokens=False, truncation=True, max_length=max_length))
        else:
            key = ""
        return key
    @DatasetProcessFn(augment=True)
    def _process(query, choices, query_id, subject, answer, key=None, **kwds):
        """Build four prompts (one per answer option) for a single sample, optionally prefixed with retrieved knowledge and few-shot examples."""
        output = defaultdict(list)
        query = query.strip()

        head = f"The following are multiple choice questions (with answers) about {' '.join(subject.split('_'))}.\n\n"

        if few_shot > 0:
            train_samples = ""
            for i in range(few_shot):
                if i >= len(train_df[subject]):
                    break
                train_sample = train_df[subject].iloc[i][['query', 'choices', 'answer']]
                train_sample = _prepare_sample(**train_sample) + "\n\n"
                train_samples += train_sample
        else:
            train_samples = ""

        # budget for the knowledge block: whatever context is left after the
        # header, the few-shot examples, and the question itself
        knowledge_max_length = context_max_length - len(tokenizer.encode(head + train_samples + _prepare_sample(query, choices, 'A'))) - int(has_bos) - int(has_eos)
        if knowledge_max_length < 0:
            knowledge = ""
        else:
            knowledge = _prepare_knowledge(key, knowledge_max_length)
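
        # Score each option separately: the prompt text is identical across the
        # four copies except for the final answer letter, and only that
        # letter's tokens contribute to the label NLL.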
        for option in options:
            left = knowledge
            right = head + train_samples + _prepare_sample(query, choices, option)
            # \n\n to split knowledge and prompts
            if len(left):
                right = "\n\n" + right

            # TODO: add llama instruction
            # if add_llama_inst:
            #     left = "[INST]" + left
            #     right = right + "[/INST]"

            inputs = left_truncation_tokenizer(left + right, truncation=True, max_length=context_max_length, return_token_type_ids=False)

            if has_eos and not is_encoder_decoder:
                inputs = remove_eos(inputs, tokenizer.eos_token_id)

            # find answer length
            option_seq = tokenizer.encode("Answer: " + option, add_special_tokens=False)
            option_length = len(option_seq) - len(tokenizer.encode("Answer:", add_special_tokens=False))

            if is_encoder_decoder:
                labels = inputs["input_ids"].copy()[-option_length:]
                for k, v in inputs.items():
                    inputs[k] = v[:-option_length]
                inputs["labels"] = labels

            else:
                # take care of padded tokens
                labels = inputs["input_ids"].copy()
                labels = [x if inputs["attention_mask"][i] == 1 else -100 for i, x in enumerate(labels)]
                labels[:-option_length] = [-100] * (len(labels) - option_length)
                inputs["labels"] = labels

            inputs["query_id"] = query_id
            for k, v in inputs.items():
                output[k].append(v)
        return output

    return _process
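
# Returns a compute_metric(eval_preds) closure. eval_preds is a pair of
# (query_ids, nlls) with four NLLs per question, in option order A-D; the
# option with the highest log-likelihood is written out as the prediction and
# accuracy is aggregated per subject category.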
def evaluate_mmlu(eval_data, save_path, **kwds):
    def compute_metric(eval_preds):
        makedirs(save_path)

        tasks = defaultdict(list)
        results = defaultdict(list)
        samples = {}

        with open(eval_data) as f:
            for line in f:
                sample = json.loads(line.strip())
                samples[sample["query_id"]] = sample

        # the NLLs of each query must come in the order of A, B, C, and D
        for query_id, nll in zip(*eval_preds):
            # store log likelihood
            results[query_id].append(-nll)

        with open(makedirs(save_path), "w") as f:
            for k, v in results.items():
                output = max(enumerate(v), key=lambda x: x[1])[0]
                sample = samples[k]
                sample["output"] = output
                tasks[sample["subject"]].append((output, sample["answer"]))
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")

        metrics = defaultdict(list)
        for task_name, task_eval_preds in tasks.items():
            accuracy = 0
            for pred, label in task_eval_preds:
                accuracy += int(pred == label)
            accuracy /= len(task_eval_preds)

            category = SUBJECT_2_CATEGORY[task_name]
            metrics[f"{category}"].append(accuracy)
            metrics["all"].append(accuracy)

        for k, v in metrics.items():
            metrics[k] = sum(v) / len(v)

        metrics = {
            "STEM": metrics["STEM"],
            "Social Sciences": metrics["Social Sciences"],
            "Humanities": metrics["Humanities"],
            "Others": metrics["others"],
            "All": metrics["all"],
        }

        return dict(metrics)
    return compute_metric
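
# End-to-end pipeline: parse arguments, optionally run retrieval to attach
# "key" passages to each question, build the per-option dataset, score it with
# the LM, and log category-level accuracy.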
def main():
    parser = HfArgumentParser([MMLUArgs])
    args, = parser.parse_args_into_dataclasses()

    accelerator = Accelerator(cpu=args.cpu)

    # modify the output_dir for retrieval
    if args.retrieval_method == "dense":
        output_dir = os.path.join(args.output_dir, args.query_encoder.strip(os.sep).replace(os.sep, "--"))
    else:
        output_dir = os.path.join(args.output_dir, args.retrieval_method)
    args.output_dir = output_dir

    # run retrieval first when requested; eval_data is then switched to the
    # retrieval output that carries a "key" field per question
    if args.retrieval_method != "no":
        retrieval_main(args=args, accelerator=accelerator, log=False)
        eval_data = RetrievalMetric._get_save_path(args.eval_data, args.output_dir, field="key", save_name=args.save_name)
    else:
        eval_data = args.eval_data

    lm = LM(
        model_name_or_path=args.model_name_or_path,
        dtype=args.lm_dtype,
        device_map=args.lm_device_map,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        accelerator=accelerator
    )

    tokenizer = lm.tokenizer

    with accelerator.main_process_first():
        logger.info(f"Loading data from {eval_data}...")
        dataset = datasets.load_dataset("json", data_files=eval_data, split="train", cache_dir=args.dataset_cache_dir)
        dataset = dataset.map(process_mmlu(
            tokenizer,
            context_max_length=args.context_max_length,
            key_num=args.key_num,
            few_shot=args.few_shot,
            train_data=args.train_data,
            cache_dir=args.dataset_cache_dir,
            is_encoder_decoder=lm.model.config.is_encoder_decoder,
            add_llama_inst=args.add_llama_inst
        ), remove_columns=dataset.column_names, batched=True, num_proc=32)

    data_collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=args.add_position_ids)
    dataloader = DataLoader(
        dataset,
        batch_size=args.lm_batch_size,
        collate_fn=data_collator,
        pin_memory=True,
    )
    dataloader = accelerator.prepare(dataloader)

    results = lm.compute_nlls(dataloader)

    if accelerator.process_index == 0:
        file_logger = FileLogger(makedirs(args.log_path))
        result_path = os.path.join(args.output_dir, args.model_name_or_path.strip(os.sep).replace(os.sep, "--") + ".json")
        metrics = evaluate_mmlu(eval_data, result_path)(results)
        file_logger.log(metrics, Args=asdict(args))
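
# Example invocation (hypothetical module path; flag names mirror the dataclass
# fields parsed by HfArgumentParser, with further flags from LMArgs/RetrievalArgs):
#   python -m evaluation.eval_mmlu --retrieval_method no --few_shot 5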
if __name__ == "__main__":
    main()