embedder ft abc

ZiyiXia 2024-11-05 08:42:51 +00:00
parent d4c2a1431c
commit 9ff1fa98fa
8 changed files with 227 additions and 7 deletions

View File

@ -38,6 +38,9 @@ class AbsEmbedderModelArguments:
@dataclass
class AbsEmbedderDataArguments:
"""
Abstract class for data arguments.
"""
train_data: str = field(
default=None, metadata={
"help": "One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data.",

View File

@ -21,6 +21,12 @@ logger = logging.getLogger(__name__)
class AbsEmbedderTrainDataset(Dataset):
"""Abstract class for training dataset.
Args:
args (AbsEmbedderDataArguments): Data arguments.
tokenizer (PreTrainedTokenizer): Tokenizer to use.
"""
def __init__(
self,
args: AbsEmbedderDataArguments,
@ -46,6 +52,17 @@ class AbsEmbedderTrainDataset(Dataset):
self.dataset = datasets.concatenate_datasets(train_datasets)
def _load_dataset(self, file_path: str):
"""Load dataset from path.
Args:
file_path (str): Path to load the datasets from.
Raises:
ValueError: `pos_scores` and `neg_scores` not found in the features of training data
Returns:
datasets.Dataset: Loaded HF dataset.
"""
if dist.get_rank() == 0:
logger.info(f'loading data from {file_path} ...')
@ -63,6 +80,14 @@ class AbsEmbedderTrainDataset(Dataset):
return temp_dataset
def _shuffle_text(self, text):
"""shuffle the input text.
Args:
text (str): Input text.
Returns:
str: Shuffled text.
"""
if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
split_text = []
chunk_size = len(text)//3 + 1
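The hunk above is truncated; the following is a self-contained sketch of the chunk-and-shuffle idea it hints at (the exact splitting and joining in the original code may differ):

import random

def shuffle_text_sketch(text: str, shuffle_ratio: float) -> str:
    # With probability `shuffle_ratio`, split long texts into roughly three
    # chunks and shuffle their order; otherwise return the text unchanged.
    if shuffle_ratio > 0 and len(text) > 100 and random.random() < shuffle_ratio:
        chunk_size = len(text) // 3 + 1
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        random.shuffle(chunks)
        return " ".join(chunks)
    return text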
@ -126,6 +151,9 @@ class AbsEmbedderTrainDataset(Dataset):
@dataclass
class AbsEmbedderCollator(DataCollatorWithPadding):
"""
The abstract embedder collator.
"""
query_max_len: int = 32
passage_max_len: int = 128
sub_batch_size: int = -1
@ -214,6 +242,16 @@ class AbsEmbedderCollator(DataCollatorWithPadding):
class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
"""Abstract class for training dataset that samples batches from same dataset.
Args:
args (AbsEmbedderDataArguments): Data arguments.
default_batch_size (int): The default batch size for training.
seed (int): Random seed.
tokenizer (PreTrainedTokenizer): Tokenizer to use.
process_index (int, optional): Current process index. Defaults to 0.
num_processes (int, optional): Total number of processes. Defaults to 1.
"""
def __init__(
self,
args: AbsEmbedderDataArguments,
@ -296,6 +334,14 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
self.refresh_epoch()
def _load_dataset(self, file_path: str):
"""Load datset from given path.
Args:
file_path (str): The path to load or download from HF hub.
Returns:
datasets.Dataset: The loaded dataset.
"""
if dist.get_rank() == 0:
logger.info(f'loading data from {file_path} ...')
@ -311,6 +357,15 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
@staticmethod
def _get_file_batch_size(temp_dataset: datasets.Dataset, default_batch_size: int):
"""Get the appropriate batch size for the dataset.
Args:
temp_dataset (datasets.Dataset): Loaded :data:`datasets.Dataset` object.
default_batch_size (int): The default batch size to use if not specified in the dataset.
Returns:
int: The final batch size to use.
"""
if 'batch_size' in temp_dataset.column_names:
return temp_dataset['batch_size'][0]
if 'type' in temp_dataset.column_names:
@ -320,6 +375,9 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
return default_batch_size
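A small illustration of the override described above, using a hypothetical in-memory dataset that carries its own `batch_size` column (toy values only):

import datasets

toy = datasets.Dataset.from_dict({
    "query": ["q1", "q2"],
    "pos": [["p1"], ["p2"]],
    "neg": [["n1"], ["n2"]],
    "batch_size": [16, 16],
})
# With a `batch_size` column present, the per-file value (16) wins over the default.
assert toy["batch_size"][0] == 16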
def refresh_epoch(self):
"""
Refresh data for epoch.
"""
logger.info(f'-- Rank {self.process_index}: refresh data --')
self.deterministic_generator.shuffle(self.datasets_inxs)
@ -353,6 +411,15 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
return queries, passages, teacher_scores, no_in_batch_neg_flag
def _get_train_group_size(self, batch_raw_data):
"""Get the training group size and data type.
Args:
batch_raw_data (datasets.Dataset): One batch of raw data.
Returns:
int: The training group size.
str: The type of data for the task.
"""
if 'type' in batch_raw_data:
data_type = batch_raw_data['type'][0]
if data_type in ['only_1neg']:
@ -362,6 +429,16 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
return self.args.train_group_size, None
def _create_batch_data(self, batch_raw_data):
"""Create a comple batch of data with queries, documents and teacher scores.
Args:
batch_raw_data (datasets.Dataset): One batch of raw data.
Returns:
List[str]: Queries with instruction format.
List[str]: Documents with instruction format.
List[float]: Teacher scores for model distillation.
"""
queries, passages, teacher_scores = [], [], []
train_group_size, data_type = self._get_train_group_size(batch_raw_data)
@ -516,6 +593,9 @@ class AbsEmbedderSameDatasetCollator(DataCollatorWithPadding):
class EmbedderTrainerCallbackForDataRefresh(TrainerCallback):
"""
Callback class to inspect the state of the training loop and refresh the training data each epoch.
"""
def __init__(self, train_dataset: AbsEmbedderSameDatasetTrainDataset):
self.train_dataset = train_dataset
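The hunk only shows the constructor; a minimal sketch of the rest, assuming the callback simply calls refresh_epoch() when an epoch ends:

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class DataRefreshCallbackSketch(TrainerCallback):
    # Sketch only: re-shuffle the same-dataset sampler once per epoch.
    def __init__(self, train_dataset):
        self.train_dataset = train_dataset

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        self.train_dataset.refresh_epoch()
        return control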

View File

@ -15,6 +15,9 @@ logger = logging.getLogger(__name__)
@dataclass
class EmbedderOutput(ModelOutput):
"""
Output information returned by the model.
"""
q_reps: Optional[Tensor] = None
p_reps: Optional[Tensor] = None
loss: Optional[Tensor] = None
@ -22,6 +25,17 @@ class EmbedderOutput(ModelOutput):
class AbsEmbedderModel(ABC, nn.Module):
"""Abstract class of embedding model for training.
Args:
base_model: The base model to train on.
tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
Defaults to ``-1``.
kd_loss_type (str, optional): Knowledge distillation type. Defaults to ``"kl_div"``.
"""
def __init__(
self,
base_model,
@ -48,21 +62,53 @@ class AbsEmbedderModel(ABC, nn.Module):
@abstractmethod
def encode(self, features):
"""Abstract method encode and get the embedding.
Args:
features (Union[list, dict]): Features fed to the model.
"""
pass
@abstractmethod
def compute_loss(self, scores, target):
"""Abstract method compute the loss.
Args:
scores (torch.Tensor): Computed score.
target (torch.Tensor): The target value.
"""
pass
@abstractmethod
def compute_score(self, q_reps, p_reps):
"""Abstract method to compute the score.
Args:
q_reps (torch.Tensor): Queries representations.
p_reps (torch.Tensor): Passages representations.
"""
pass
@abstractmethod
def save(self, output_dir: str):
"""Abstract method to save the model.
Args:
output_dir (str): Directory for saving the model.
"""
pass
def get_local_score(self, q_reps, p_reps, all_scores):
"""Get the local score of queries and passages.
Args:
q_reps (torch.Tensor): Queries representations.
p_reps (torch.Tensor): Passages representations.
all_scores (torch.Tensor): All the query-passage scores computed.
Returns:
torch.Tensor: Local scores to compute loss.
"""
group_size = p_reps.size(0) // q_reps.size(0)
indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size
specific_scores = []
@ -73,6 +119,17 @@ class AbsEmbedderModel(ABC, nn.Module):
return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1)
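A small worked example of the indexing above, assuming 2 queries and a group size of 3, so `all_scores` is 2 x 6 over the in-batch passages (toy numbers only):

import torch

all_scores = torch.arange(12, dtype=torch.float32).view(2, 6)  # all_scores[i, j] = score of query i vs passage j
group_size = 3
indices = torch.arange(0, 2) * group_size                      # tensor([0, 3])
# Each query keeps only the scores against its own group of passages.
local = torch.stack(
    [all_scores[torch.arange(2), indices + i] for i in range(group_size)], dim=1
)
print(local)  # tensor([[ 0.,  1.,  2.], [ 9., 10., 11.]])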
def compute_local_score(self, q_reps, p_reps, compute_score_func=None, **kwargs):
"""Compute the local score of queries and passages.
Args:
q_reps (torch.Tensor): Queries representations.
p_reps (torch.Tensor): Passages representations.
compute_score_func (function, optional): Function to compute score. Defaults to ``None``, which will use
:meth:`self.compute_score`.
Returns:
torch.Tensor: Local scores to compute loss.
"""
if compute_score_func is None:
all_scores = self.compute_score(q_reps, p_reps)
else:
@ -181,6 +238,17 @@ class AbsEmbedderModel(ABC, nn.Module):
teacher_scores: Union[None, List[float]] = None,
no_in_batch_neg_flag: bool = False,
):
"""The computation performed at every call.
Args:
queries (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input queries. Defaults to ``None``.
passages (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input passages. Defaults to ``None``.
teacher_scores (Union[None, List[float]], optional): Teacher scores for distillation. Defaults to ``None``.
no_in_batch_neg_flag (bool, optional): If True, use no in-batch negatives and no cross-device negatives. Defaults to ``False``.
Returns:
EmbedderOutput: Output of the forward call of model.
"""
q_reps = self.encode(queries) # (batch_size, dim)
p_reps = self.encode(passages) # (batch_size * group_size, dim)
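A hedged sketch of the in-batch-negative flow this forward pass typically implements; the branching for no_in_batch_neg_flag, cross-device negatives and distillation is elided in the hunk above, and the shapes here are toy values:

import torch

q_reps = torch.randn(2, 4)        # (batch_size, dim)
p_reps = torch.randn(2 * 3, 4)    # (batch_size * group_size, dim)

temperature = 1.0
scores = q_reps @ p_reps.T / temperature        # (batch_size, batch_size * group_size)
target = torch.arange(scores.size(0)) * 3       # index of each query's own positive passage
loss = torch.nn.functional.cross_entropy(scores, target)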
@ -210,6 +278,20 @@ class AbsEmbedderModel(ABC, nn.Module):
@staticmethod
def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None):
"""Compute the distillation loss.
Args:
kd_loss_type (str): Type of knowledge distillation loss, supports "kl_div" and "m3_kd_loss".
teacher_targets (torch.Tensor): Targets from the teacher model.
student_scores (torch.Tensor): Score of student model.
group_size (int, optional): Number of passages per query (the group size). Defaults to ``None``.
Raises:
ValueError: Invalid kd_loss_type
Returns:
torch.Tensor: A scalar of computed distillation loss.
"""
if kd_loss_type == 'kl_div':
# teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
# student_scores: (batch_size, group_size) / (world_size * batch_size, group_size)
@ -236,6 +318,15 @@ class AbsEmbedderModel(ABC, nn.Module):
raise ValueError(f"Invalid kd_loss_type: {kd_loss_type}")
def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
"""Gather a tensor from all processes in a distributed setting.
Args:
t (Optional[torch.Tensor]): The input tensor to be gathered. If `None`, no gathering is performed.
Returns:
Union[torch.Tensor, None]: A concatenated tensor from all processes if ``t`` is not ``None``,
otherwise returns ``None``.
"""
if t is None:
return None
t = t.contiguous()
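The rest of the gather is cut off above; a common pattern for it, shown only as a sketch, is torch.distributed.all_gather into a list followed by concatenation, re-inserting the local tensor so gradients still flow on this rank:

import torch
import torch.distributed as dist

def dist_gather_tensor_sketch(t: torch.Tensor) -> torch.Tensor:
    t = t.contiguous()
    all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(all_tensors, t)
    all_tensors[dist.get_rank()] = t  # keep the local, differentiable copy
    return torch.cat(all_tensors, dim=0)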

View File

@ -22,6 +22,13 @@ logger = logging.getLogger(__name__)
class AbsEmbedderRunner(ABC):
"""Abstract class to run embedding model fine-tuning.
Args:
model_args (AbsEmbedderModelArguments): Model arguments.
data_args (AbsEmbedderDataArguments): Data arguments.
training_args (AbsEmbedderTrainingArguments): Training arguments.
"""
def __init__(
self,
model_args: AbsEmbedderModelArguments,
@ -70,13 +77,28 @@ class AbsEmbedderRunner(ABC):
@abstractmethod
def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
"""Abstract method to load the tokenizer and model.
Returns:
Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Loaded tokenizer and model instances.
"""
pass
@abstractmethod
def load_trainer(self) -> AbsEmbedderTrainer:
"""Abstract method to load the trainer.
Returns:
AbsEmbedderTrainer: The loaded trainer instance.
"""
pass
def load_train_dataset(self) -> AbsEmbedderTrainDataset:
"""Loads the training dataset based on data arguments.
Returns:
AbsEmbedderTrainDataset: The loaded dataset instance.
"""
if self.data_args.same_dataset_within_batch:
train_dataset = AbsEmbedderSameDatasetTrainDataset(
args=self.data_args,
@ -96,6 +118,11 @@ class AbsEmbedderRunner(ABC):
return train_dataset
def load_data_collator(self) -> AbsEmbedderCollator:
"""Loads the appropriate data collator.
Returns:
AbsEmbedderCollator: Loaded data collator.
"""
if self.data_args.same_dataset_within_batch:
EmbedCollator = AbsEmbedderSameDatasetCollator
else:
@ -113,6 +140,9 @@ class AbsEmbedderRunner(ABC):
return data_collator
def run(self):
"""
Executes the training process.
"""
Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)
# Training
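A hedged sketch of how such a runner is usually driven, assuming HfArgumentParser over the three argument dataclasses; the concrete runner class name is illustrative, not something defined in this diff:

from transformers import HfArgumentParser

# Assumes the argument dataclasses and a concrete AbsEmbedderRunner subclass are importable.
parser = HfArgumentParser((AbsEmbedderModelArguments, AbsEmbedderDataArguments, AbsEmbedderTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

runner = MyEmbedderRunner(model_args=model_args, data_args=data_args, training_args=training_args)
runner.run()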

View File

@ -7,6 +7,9 @@ logger = logging.getLogger(__name__)
class AbsEmbedderTrainer(ABC, Trainer):
"""
Abstract class for the embedder trainer.
"""
@abstractmethod
def _save(self, output_dir: Optional[str] = None, state_dict=None):
pass
@ -16,6 +19,16 @@ class AbsEmbedderTrainer(ABC, Trainer):
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
Args:
model (AbsEmbedderModel): The model being trained.
inputs (dict): A dictionary of input tensors to be passed to the model.
return_outputs (bool, optional): If ``True``, returns both the loss and the model's outputs. Otherwise,
returns only the loss.
Returns:
Union[torch.Tensor, Tuple[torch.Tensor, ModelOutput]]: The computed loss. If ``return_outputs`` is ``True``,
also returns the model's outputs in a tuple ``(loss, outputs)``.
"""
outputs = model(**inputs)
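The hunk stops at the forward call; based on the docstring above, the remainder of compute_loss most likely looks like this sketch:

def compute_loss_sketch(self, model, inputs, return_outputs=False, **kwargs):
    # Forward the batch, read the loss from the EmbedderOutput, and optionally
    # hand back the full outputs alongside it.
    outputs = model(**inputs)
    loss = outputs.loss
    return (loss, outputs) if return_outputs else loss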

View File

@ -175,10 +175,11 @@ class BaseLLMReranker(AbsReranker):
Args:
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
load a model from HuggingFace Hub with the name.
peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
degradation. Defaults to :data:`False`.
use_bf16 (bool, optional): _description_. Defaults to :data:False.
use_bf16 (bool, optional): Another half-precision floating-point format; use bf16 if your hardware supports it.
Defaults to :data:`False`.
query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
:attr:`query_instruction_format`. Defaults to :data:`"A: "`.
query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.
@ -248,7 +249,7 @@ class BaseLLMReranker(AbsReranker):
torch_dtype=torch.bfloat16 if use_bf16 else torch.float32
)
if peft_path:
self.model = PeftModel.from_pretrained(self.model,peft_path)
self.model = PeftModel.from_pretrained(self.model, peft_path)
self.model = self.model.merge_and_unload()
self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
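The cached `yes_loc` is the token id of "Yes"; a hedged sketch of how an LLM reranker commonly turns it into a relevance score (illustrative only, not the exact code in this file):

import torch

def yes_logit_scores_sketch(logits: torch.Tensor, yes_loc: int) -> torch.Tensor:
    # logits: (batch_size, seq_len, vocab_size) from the causal LM; the score is
    # the logit of the 'Yes' token at the last position.
    return logits[:, -1, yes_loc]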

View File

@ -41,10 +41,11 @@ class LayerWiseLLMReranker(AbsReranker):
Args:
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
load a model from HuggingFace Hub with the name.
peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
degradation. Defaults to :data:`False`.
use_bf16 (bool, optional): _description_. Defaults to :data:False.
use_bf16 (bool, optional): Another half-precision floating-point format; use bf16 if your hardware supports it.
Defaults to :data:`False`.
query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
:attr:`query_instruction_format`. Defaults to :data:`"A: "`.
query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.

View File

@ -95,10 +95,11 @@ class LightweightLLMReranker(AbsReranker):
Args:
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
load a model from HuggingFace Hub with the name.
peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
degradation. Defaults to :data:`False`.
use_bf16 (bool, optional): _description_. Defaults to :data:False.
use_bf16 (bool, optional): Another half-precision floating-point format; use bf16 if your hardware supports it.
Defaults to :data:`False`.
query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
:attr:`query_instruction_format`. Defaults to :data:`"A: "`.
query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.