import os
import logging
import datasets
import transformers

from copy import deepcopy
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict

from src.lm import SRLMArgs, SelfRetrievalLM
from src.retrieval import Retriever, RetrievalArgs, TASK_CONFIG
from src.utils.util import makedirs, remove_eos, DefaultDataCollator, DatasetProcessFn, FileLogger

logger = logging.getLogger(__name__)

# disable the warning about overly long inputs
transformers.logging.set_verbosity_error()


# merge the retrieval args and the self-retrieval LM args into one unified set
@dataclass
class LRLMArgs(RetrievalArgs, SRLMArgs):
    eval_data: str = field(
        default="llm-embedder:lrlm/books3/test.json",
        metadata={'help': 'Evaluation json file.'},
    )
    lm_batch_size: int = field(
        default=1,
        metadata={'help': 'Batch size for language model evaluation.'},
    )

    context_max_length: int = field(
        default=32768,
        metadata={'help': 'Maximum number of tokens to keep from each context.'},
    )
    anchor_length: int = field(
        default=160000,
        metadata={'help': 'Truncate each raw text to this many characters before tokenization.'}
    )
    chunk_size: int = field(
        default=128,
        metadata={'help': 'How many tokens in a chunk?'}
    )
    key_num: int = field(
        default=8,
        metadata={'help': 'How many chunks to retrieve at a time?'}
    )
    chunk_batch_size: int = field(
        default=1,
        metadata={'help': 'How many retrieval & generation passes to execute in parallel?'}
    )

    log_path: str = field(
        default="data/results/lrlm",
        metadata={'help': 'Path to the file for logging.'}
    )
    debug_retrieval: bool = field(
        default=False,
        metadata={'help': 'Check retrieval queries and values?'}
    )

    def __post_init__(self):
        super().__post_init__()
        if self.retrieval_method == "bm25":
            # NOTE: only naive bm25 is supported for self retrieval
            self.retrieval_method = "naive-bm25"
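

# Because main() feeds LRLMArgs to HfArgumentParser, every field above becomes
# a CLI flag. A hypothetical single-process invocation (the script name and
# data path are placeholders):
#
#   python eval_lrlm.py --eval_data llm-embedder:lrlm/books3/test.json \
#       --chunk_size 128 --key_num 8 --retrieval_method bm25
#
# Note that "--retrieval_method bm25" is silently rewritten to "naive-bm25"
# by __post_init__ above.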


def process_lrlm(tokenizer, context_max_length=4096, target_length=1024, anchor_length=160000):
    # probe whether the tokenizer appends an EOS token; unused below because
    # tokenization is done with add_special_tokens=False
    test = tokenizer("test", return_special_tokens_mask=True)["special_tokens_mask"]
    has_eos = False
    if test[-1] == 1:
        has_eos = True

    # keep the *last* context_max_length tokens, so the evaluation targets sit
    # at the end of the window
    left_truncation_tokenizer = deepcopy(tokenizer)
    left_truncation_tokenizer.truncation_side = "left"

    @DatasetProcessFn()
    def _process(text, **kwds):
        output = {}
        # truncate the raw string first to avoid tokenizing extremely long texts
        text = text[:anchor_length]

        inputs = left_truncation_tokenizer(text, max_length=context_max_length, truncation=True, return_token_type_ids=False, add_special_tokens=False)

        # skip texts that are shorter than the evaluation target
        if len(inputs.input_ids) < target_length:
            return None

        # mask out everything except the last target_length tokens, so only
        # they contribute to the loss
        labels = inputs["input_ids"].copy()
        inputs_length = len(labels)
        labels[:-target_length] = [-100 for _ in range(inputs_length - target_length)]
        inputs["labels"] = labels

        for k, v in inputs.items():
            output[k] = v
        return output
    return _process
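

# A minimal sketch of what _process yields, assuming a hypothetical 6-token
# input and target_length=2:
#
#   input_ids = [11, 12, 13, 14, 15, 16]
#   labels    = [-100, -100, -100, -100, 15, 16]
#
# so the loss (and hence perplexity) is computed only over the last
# target_length tokens, while everything before them serves as context.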


def main():
    parser = HfArgumentParser([LRLMArgs])
    args, = parser.parse_args_into_dataclasses()

    accelerator = Accelerator(cpu=args.cpu)

    retriever = Retriever(
        retrieval_method=args.retrieval_method,
        # arguments for the dense retriever
        query_encoder=args.query_encoder,
        key_encoder=args.key_encoder,
        pooling_method=args.pooling_method,
        dense_metric=args.dense_metric,
        query_max_length=args.query_max_length,
        key_max_length=args.key_max_length,
        tie_encoders=args.tie_encoders,
        truncation_side=args.truncation_side,
        cache_dir=args.model_cache_dir,
        dtype=args.dtype,
        accelerator=accelerator,
        # arguments for the bm25 retriever
        anserini_dir=args.anserini_dir,
        k1=args.k1,
        b=args.b,
    )

    if args.add_instruction:
        instruction = TASK_CONFIG[args.version]["instruction"]["lrlm"]
    else:
        instruction = None

    srlm = SelfRetrievalLM(
        model_name_or_path=args.model_name_or_path,
        retriever=retriever,
        dtype=args.lm_dtype,
        device_map=args.lm_device_map,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        context_window_size=args.context_window_size,
        chunk_size=args.chunk_size,
        key_num=args.key_num,
        chunk_batch_size=args.chunk_batch_size,
        add_key_continuation=args.add_key_continuation,
        retrieval_method=args.retrieval_method,
        order_method=args.order_method,
        integrate_method=args.integrate_method,
        instruction=instruction,
        debug_retrieval=args.debug_retrieval,
        add_sep=args.add_sep,
        accelerator=accelerator,
    )
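    # As suggested by its arguments, the self-retrieval LM splits the long
    # context into chunk_size-token chunks, retrieves key_num of them at a
    # time, and conditions the model on the retrieved chunks rather than the
    # full context (a reading of the config above, not a verified spec).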

    tokenizer = srlm.tokenizer

    logger.info(f"Loading data from {args.eval_data}...")

    if args.retrieval_method == "no" and args.context_max_length != args.context_window_size:
        logger.warning(f"Found retrieval_method is 'no', setting context_max_length to the same as context_window_size ({args.context_window_size})!")
        args.context_max_length = args.context_window_size

    with accelerator.main_process_first():
        dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train", cache_dir=args.dataset_cache_dir)
        dataset = dataset.map(process_lrlm(
            tokenizer,
            context_max_length=args.context_max_length,
            target_length=args.target_length,
            anchor_length=args.anchor_length,
        ), remove_columns=dataset.column_names, batched=True, batch_size=50, num_proc=64)

    data_collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=args.add_position_ids)
    dataloader = DataLoader(
        dataset,
        batch_size=args.lm_batch_size,
        collate_fn=data_collator,
        pin_memory=True,
    )
    dataloader = accelerator.prepare(dataloader)
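    # An assumption about Accelerate's default behavior: prepare() shards the
    # dataloader across processes, so each rank scores a disjoint slice of the
    # evaluation set before the results are combined.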

    perplexity = srlm.compute_perplexity(dataloader)
    metrics = {"perplexity": perplexity}

    if accelerator.process_index == 0:
        # name the log file after the parent directory of the eval file, e.g. "books3"
        dataset_name = os.path.normpath(args.eval_data).split(os.sep)[-2]
        log_path = os.path.join(args.log_path, f"{dataset_name}.log")

        file_logger = FileLogger(makedirs(log_path))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()
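

# Example distributed launch (a sketch; the script name and data path are
# placeholders, flag names follow LRLMArgs above):
#
#   accelerate launch --num_processes 8 eval_lrlm.py \
#       --eval_data llm-embedder:lrlm/books3/test.json \
#       --retrieval_method dense \
#       --context_max_length 32768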