mirror of https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-12-30 08:42:24 +00:00

embedder ft abc

This commit is contained in:
parent d4c2a1431c
commit 9ff1fa98fa
@@ -38,6 +38,9 @@ class AbsEmbedderModelArguments:

@dataclass
class AbsEmbedderDataArguments:
    """
    Abstract class for data arguments.
    """
    train_data: str = field(
        default=None, metadata={
            "help": "One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data.",
@@ -21,6 +21,12 @@ logger = logging.getLogger(__name__)


class AbsEmbedderTrainDataset(Dataset):
    """Abstract class for training dataset.

    Args:
        args (AbsEmbedderDataArguments): Data arguments.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
    """
    def __init__(
        self,
        args: AbsEmbedderDataArguments,
@@ -46,6 +52,17 @@ class AbsEmbedderTrainDataset(Dataset):
        self.dataset = datasets.concatenate_datasets(train_datasets)

    def _load_dataset(self, file_path: str):
        """Load dataset from path.

        Args:
            file_path (str): Path to load the datasets from.

        Raises:
            ValueError: `pos_scores` and `neg_scores` not found in the features of training data.

        Returns:
            datasets.Dataset: Loaded HF dataset.
        """
        if dist.get_rank() == 0:
            logger.info(f'loading data from {file_path} ...')
@@ -63,6 +80,14 @@ class AbsEmbedderTrainDataset(Dataset):
        return temp_dataset

    def _shuffle_text(self, text):
        """Shuffle the input text.

        Args:
            text (str): Input text.

        Returns:
            str: Shuffled text.
        """
        if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
            split_text = []
            chunk_size = len(text)//3 + 1
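            # The hunk ends above; the lines below are a hedged sketch of the likely
            # continuation (chunk the text, shuffle the chunks, rejoin), not a verbatim
            # quote of the repository code.
            for i in range(0, len(text), chunk_size):
                split_text.append(text[i:i + chunk_size])
            random.shuffle(split_text)
            return " ".join(split_text)
        else:
            return text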
@@ -126,6 +151,9 @@ class AbsEmbedderTrainDataset(Dataset):

@dataclass
class AbsEmbedderCollator(DataCollatorWithPadding):
    """
    The abstract embedder collator.
    """
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1
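Since the collator is a dataclass extending `DataCollatorWithPadding`, constructing one presumably looks like the following sketch; the tokenizer checkpoint is an arbitrary assumption, not taken from this diff:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")  # hypothetical choice

data_collator = AbsEmbedderCollator(
    tokenizer=tokenizer,   # inherited field from DataCollatorWithPadding
    query_max_len=32,      # truncation length for queries
    passage_max_len=128,   # truncation length for passages
    sub_batch_size=-1,     # negative value: do not split batches into sub-batches
)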
@@ -214,6 +242,16 @@ class AbsEmbedderCollator(DataCollatorWithPadding):


class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
    """Abstract class for training dataset that samples batches from the same dataset.

    Args:
        args (AbsEmbedderDataArguments): Data arguments.
        default_batch_size (int): The default batch size for training.
        seed (int): Random seed.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
        process_index (int, optional): Current process index. Defaults to 0.
        num_processes (int, optional): Total number of processes. Defaults to 1.
    """
    def __init__(
        self,
        args: AbsEmbedderDataArguments,
@@ -296,6 +334,14 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
        self.refresh_epoch()

    def _load_dataset(self, file_path: str):
        """Load dataset from the given path.

        Args:
            file_path (str): The path to load or download from the HF hub.

        Returns:
            datasets.Dataset: The loaded dataset.
        """
        if dist.get_rank() == 0:
            logger.info(f'loading data from {file_path} ...')
@@ -311,6 +357,15 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):

    @staticmethod
    def _get_file_batch_size(temp_dataset: datasets.Dataset, default_batch_size: int):
        """Get the appropriate batch size for the dataset.

        Args:
            temp_dataset (datasets.Dataset): Loaded :data:`datasets.Dataset` object.
            default_batch_size (int): The default batch size to use if not specified in the dataset.

        Returns:
            int: The final batch size to use.
        """
        if 'batch_size' in temp_dataset.column_names:
            return temp_dataset['batch_size'][0]
        if 'type' in temp_dataset.column_names:
@@ -320,6 +375,9 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
        return default_batch_size
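As the first branch shows, a per-file batch size can be set directly in the data: if a `batch_size` column exists, its first value wins. A small illustration with toy, hypothetical values (assuming the class is importable from the module shown above):

import datasets

toy = datasets.Dataset.from_dict({
    "query": ["q1", "q2"],
    "pos": [["p1"], ["p2"]],
    "neg": [["n1"], ["n2"]],
    "batch_size": [16, 16],  # same value repeated on every row of the file
})

# The 'batch_size' column overrides the default of 32.
assert AbsEmbedderSameDatasetTrainDataset._get_file_batch_size(toy, default_batch_size=32) == 16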

    def refresh_epoch(self):
        """
        Refresh data for epoch.
        """
        logger.info(f'-- Rank {self.process_index}: refresh data --')
        self.deterministic_generator.shuffle(self.datasets_inxs)
@@ -353,6 +411,15 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
        return queries, passages, teacher_scores, no_in_batch_neg_flag

    def _get_train_group_size(self, batch_raw_data):
        """Get the training group size and data type.

        Args:
            batch_raw_data (datasets.Dataset): One batch of raw data.

        Returns:
            int: The training group size.
            str: The type of data for the task.
        """
        if 'type' in batch_raw_data:
            data_type = batch_raw_data['type'][0]
            if data_type in ['only_1neg']:
@@ -362,6 +429,16 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
        return self.args.train_group_size, None

    def _create_batch_data(self, batch_raw_data):
        """Create a complete batch of data with queries, documents and teacher scores.

        Args:
            batch_raw_data (datasets.Dataset): One batch of raw data.

        Returns:
            List[str]: Queries with instruction format.
            List[str]: Documents with instruction format.
            List[float]: Teacher scores for model distillation.
        """
        queries, passages, teacher_scores = [], [], []

        train_group_size, data_type = self._get_train_group_size(batch_raw_data)
@@ -516,6 +593,9 @@ class AbsEmbedderSameDatasetCollator(DataCollatorWithPadding):


class EmbedderTrainerCallbackForDataRefresh(TrainerCallback):
    """
    Callback class to inspect the state of the training loop and take decisions.
    """
    def __init__(self, train_dataset: AbsEmbedderSameDatasetTrainDataset):
        self.train_dataset = train_dataset
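
    # Hedged sketch (assumed, not a verbatim quote): given the class name and the
    # refresh_epoch() method above, the callback presumably re-shuffles the training
    # data at each epoch boundary via the standard TrainerCallback hook.
    def on_epoch_end(self, args, state, control, **kwargs):
        self.train_dataset.refresh_epoch()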
@@ -15,6 +15,9 @@ logger = logging.getLogger(__name__)

@dataclass
class EmbedderOutput(ModelOutput):
    """
    Output information returned by the model.
    """
    q_reps: Optional[Tensor] = None
    p_reps: Optional[Tensor] = None
    loss: Optional[Tensor] = None
@@ -22,6 +25,17 @@ class EmbedderOutput(ModelOutput):


class AbsEmbedderModel(ABC, nn.Module):
    """Abstract class of embedding model for training.

    Args:
        base_model: The base model to train on.
        tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
        negatives_cross_device (bool, optional): If True, will compute cross-device negative loss. Defaults to ``False``.
        temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
        sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split into sub-batches.
            Defaults to ``-1``.
        kd_loss_type (str, optional): Knowledge distillation type. Defaults to ``"kl_div"``.
    """
    def __init__(
        self,
        base_model,
@@ -48,21 +62,53 @@ class AbsEmbedderModel(ABC, nn.Module):

    @abstractmethod
    def encode(self, features):
        """Abstract method to encode and get the embedding.

        Args:
            features (Union[list, dict]): Features fed to the model.
        """
        pass

    @abstractmethod
    def compute_loss(self, scores, target):
        """Abstract method to compute the loss.

        Args:
            scores (torch.Tensor): Computed score.
            target (torch.Tensor): The target value.
        """
        pass

    @abstractmethod
    def compute_score(self, q_reps, p_reps):
        """Abstract method to compute the score.

        Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations.
        """
        pass

    @abstractmethod
    def save(self, output_dir: str):
        """Abstract method to save the model.

        Args:
            output_dir (str): Directory for saving the model.
        """
        pass
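To make the contract concrete, here is a minimal, hypothetical subclass. It assumes the constructor stores ``base_model`` as ``self.model`` and ``temperature`` as ``self.temperature``, and that ``features`` is a dict of tokenizer outputs; it is an illustrative sketch only, not the repository's actual bi-encoder implementation.

import torch
import torch.nn.functional as F

class ToyEmbedderModel(AbsEmbedderModel):
    def encode(self, features):
        # Mean-pool the last hidden state and L2-normalize (one common pooling choice).
        out = self.model(**features, return_dict=True)
        emb = out.last_hidden_state.mean(dim=1)
        return F.normalize(emb, dim=-1)

    def compute_score(self, q_reps, p_reps):
        # Dense-retrieval score: temperature-scaled dot product.
        return torch.matmul(q_reps, p_reps.transpose(0, 1)) / self.temperature

    def compute_loss(self, scores, target):
        # Contrastive loss over the candidate passages.
        return F.cross_entropy(scores, target)

    def save(self, output_dir: str):
        self.model.save_pretrained(output_dir)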

    def get_local_score(self, q_reps, p_reps, all_scores):
        """Get the local score of queries and passages.

        Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations.
            all_scores (torch.Tensor): All the query-passage scores computed.

        Returns:
            torch.Tensor: Local scores to compute loss.
        """
        group_size = p_reps.size(0) // q_reps.size(0)
        indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size
        specific_scores = []
@@ -73,6 +119,17 @@ class AbsEmbedderModel(ABC, nn.Module):
        return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1)
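A shape walk-through may help: with 2 queries and a group size of 3 (one positive plus two negatives per query), `all_scores` is `(2, 6)` and the local score keeps, for each query, only the columns of its own passage group. The stacking line below mirrors what the elided loop body presumably does; this is a hedged toy illustration, not the upstream code:

import torch

q_reps = torch.randn(2, 8)        # 2 queries
p_reps = torch.randn(6, 8)        # 2 * group_size passages, group_size = 3
all_scores = q_reps @ p_reps.T    # shape (2, 6): every query scored against every passage

group_size = p_reps.size(0) // q_reps.size(0)           # 3
indices = torch.arange(0, q_reps.size(0)) * group_size  # tensor([0, 3])

# Query 0 keeps passages 0..2, query 1 keeps passages 3..5.
local_scores = torch.stack(
    [all_scores[torch.arange(q_reps.size(0)), indices + i] for i in range(group_size)],
    dim=1,
).view(q_reps.size(0), -1)        # shape (2, 3)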

    def compute_local_score(self, q_reps, p_reps, compute_score_func=None, **kwargs):
        """Compute the local score of queries and passages.

        Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations.
            compute_score_func (function, optional): Function to compute score. Defaults to ``None``, which will use
                :meth:`self.compute_score`.

        Returns:
            torch.Tensor: Local scores to compute loss.
        """
        if compute_score_func is None:
            all_scores = self.compute_score(q_reps, p_reps)
        else:
@@ -181,6 +238,17 @@ class AbsEmbedderModel(ABC, nn.Module):
        teacher_scores: Union[None, List[float]] = None,
        no_in_batch_neg_flag: bool = False,
    ):
        """The computation performed at every call.

        Args:
            queries (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input queries. Defaults to ``None``.
            passages (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input passages. Defaults to ``None``.
            teacher_scores (Union[None, List[float]], optional): Teacher scores for distillation. Defaults to ``None``.
            no_in_batch_neg_flag (bool, optional): If True, use no in-batch negatives and no cross-device negatives. Defaults to ``False``.

        Returns:
            EmbedderOutput: Output of the forward call of the model.
        """
        q_reps = self.encode(queries) # (batch_size, dim)
        p_reps = self.encode(passages) # (batch_size * group_size, dim)
@@ -210,6 +278,20 @@ class AbsEmbedderModel(ABC, nn.Module):

    @staticmethod
    def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None):
        """Compute the distillation loss.

        Args:
            kd_loss_type (str): Type of knowledge distillation loss, supports "kl_div" and "m3_kd_loss".
            teacher_targets (torch.Tensor): Targets from the teacher model.
            student_scores (torch.Tensor): Scores of the student model.
            group_size (int, optional): Number of passages per query group. Defaults to ``None``.

        Raises:
            ValueError: Invalid ``kd_loss_type``.

        Returns:
            torch.Tensor: A scalar of computed distillation loss.
        """
        if kd_loss_type == 'kl_div':
            # teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
            # student_scores: (batch_size, group_size) / (world_size * batch_size, group_size)
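            # The hunk ends above; the lines below are a hedged sketch of the likely
            # kl_div branch (soft-label cross-entropy between the teacher distribution
            # and the student's log-softmax scores), not a verbatim quote.
            loss = -torch.mean(
                torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1)
            )
            return loss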
@@ -236,6 +318,15 @@ class AbsEmbedderModel(ABC, nn.Module):
            raise ValueError(f"Invalid kd_loss_type: {kd_loss_type}")

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """Gather a tensor from all processes in a distributed setting.

        Args:
            t (Optional[torch.Tensor]): The input tensor to be gathered. If `None`, no gathering is performed.

        Returns:
            Union[torch.Tensor, None]: A concatenated tensor from all processes if ``t`` is not ``None``,
                otherwise returns ``None``.
        """
        if t is None:
            return None
        t = t.contiguous()
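        # The hunk stops here; the usual completion of such a helper is sketched below
        # (assumed, not a verbatim quote): all_gather into per-rank buffers, re-insert the
        # local tensor so it keeps its autograd graph, then concatenate. Attribute names
        # such as self.world_size and self.process_rank are assumptions.
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        all_tensors[self.process_rank] = t
        return torch.cat(all_tensors, dim=0)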
@@ -22,6 +22,13 @@ logger = logging.getLogger(__name__)


class AbsEmbedderRunner(ABC):
    """Abstract class to run embedding model fine-tuning.

    Args:
        model_args (AbsEmbedderModelArguments): Model arguments.
        data_args (AbsEmbedderDataArguments): Data arguments.
        training_args (AbsEmbedderTrainingArguments): Training arguments.
    """
    def __init__(
        self,
        model_args: AbsEmbedderModelArguments,
@@ -70,13 +77,28 @@ class AbsEmbedderRunner(ABC):

    @abstractmethod
    def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
        """Abstract method to load the tokenizer and model.

        Returns:
            Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Loaded tokenizer and model instances.
        """
        pass

    @abstractmethod
    def load_trainer(self) -> AbsEmbedderTrainer:
        """Abstract method to load the trainer.

        Returns:
            AbsEmbedderTrainer: The loaded trainer instance.
        """
        pass

    def load_train_dataset(self) -> AbsEmbedderTrainDataset:
        """Loads the training dataset based on data arguments.

        Returns:
            AbsEmbedderTrainDataset: The loaded dataset instance.
        """
        if self.data_args.same_dataset_within_batch:
            train_dataset = AbsEmbedderSameDatasetTrainDataset(
                args=self.data_args,
@@ -96,6 +118,11 @@ class AbsEmbedderRunner(ABC):
        return train_dataset

    def load_data_collator(self) -> AbsEmbedderCollator:
        """Loads the appropriate data collator.

        Returns:
            AbsEmbedderCollator: Loaded data collator.
        """
        if self.data_args.same_dataset_within_batch:
            EmbedCollator = AbsEmbedderSameDatasetCollator
        else:
@@ -113,6 +140,9 @@ class AbsEmbedderRunner(ABC):
        return data_collator

    def run(self):
        """
        Executes the training process.
        """
        Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)

        # Training
@@ -7,6 +7,9 @@ logger = logging.getLogger(__name__)


class AbsEmbedderTrainer(ABC, Trainer):
    """
    Abstract class for the trainer of embedder.
    """
    @abstractmethod
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        pass
@@ -16,6 +19,16 @@ class AbsEmbedderTrainer(ABC, Trainer):
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.

        Args:
            model (AbsEmbedderModel): The model being trained.
            inputs (dict): A dictionary of input tensors to be passed to the model.
            return_outputs (bool, optional): If ``True``, returns both the loss and the model's outputs. Otherwise,
                returns only the loss.

        Returns:
            Union[torch.Tensor, tuple(torch.Tensor, ModelOutput)]: The computed loss. If ``return_outputs`` is ``True``,
                also returns the model's outputs in a tuple ``(loss, outputs)``.
        """

        outputs = model(**inputs)
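        # The hunk cuts off after the forward call; given the docstring above, the method
        # presumably finishes along these lines (a hedged sketch, not a verbatim quote):
        # the model returns an EmbedderOutput whose .loss field holds the training loss.
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss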
@@ -175,10 +175,11 @@ class BaseLLMReranker(AbsReranker):
    Args:
        model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
            load a model from HuggingFace Hub with the name.
-        peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
+        peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
        use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
            degradation. Defaults to :data:`False`. Defaults to :data:`False`.
-        use_bf16 (bool, optional): _description_. Defaults to :data:False.
+        use_bf16 (bool, optional): Another type of half-precision floating-point, you can use bf16 if the hardware supports.
+            Defaults to :data:False.
        query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
            with :attr:`query_instruction_format`. Defaults to :data:`"A: "`.
        query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.
@@ -248,7 +249,7 @@ class BaseLLMReranker(AbsReranker):
            torch_dtype=torch.bfloat16 if use_bf16 else torch.float32
        )
        if peft_path:
-            self.model = PeftModel.from_pretrained(self.model,peft_path)
+            self.model = PeftModel.from_pretrained(self.model, peft_path)
            self.model = self.model.merge_and_unload()

        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
@@ -41,10 +41,11 @@ class LayerWiseLLMReranker(AbsReranker):
    Args:
        model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
            load a model from HuggingFace Hub with the name.
-        peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
+        peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
        use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
            degradation. Defaults to :data:`False`. Defaults to :data:`False`.
-        use_bf16 (bool, optional): _description_. Defaults to :data:False.
+        use_bf16 (bool, optional): Another type of half-precision floating-point, you can use bf16 if the hardware supports.
+            Defaults to :data:False.
        query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
            with :attr:`query_instruction_format`. Defaults to :data:`"A: "`.
        query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.
@@ -95,10 +95,11 @@ class LightweightLLMReranker(AbsReranker):
    Args:
        model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
            load a model from HuggingFace Hub with the name.
-        peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
+        peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
        use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
            degradation. Defaults to :data:`False`. Defaults to :data:`False`.
-        use_bf16 (bool, optional): _description_. Defaults to :data:False.
+        use_bf16 (bool, optional): Another type of half-precision floating-point, you can use bf16 if the hardware supports.
+            Defaults to :data:False.
        query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
            with :attr:`query_instruction_format`. Defaults to :data:`"A: "`.
        query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.
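For context, these reranker classes are typically driven through a `compute_score` call on query-passage pairs. A hedged usage sketch follows; the wrapper class name `FlagLLMReranker` and the checkpoint name are assumptions, not taken from this diff:

from FlagEmbedding import FlagLLMReranker

# Hypothetical checkpoint; any decoder-only reranker checkpoint plays the same role.
reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_fp16=True)

score = reranker.compute_score(['what is a panda?',
                                'The giant panda is a bear species endemic to China.'])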