diff --git a/FlagEmbedding/abc/finetune/embedder/AbsArguments.py b/FlagEmbedding/abc/finetune/embedder/AbsArguments.py
index 486dc00..d636542 100644
--- a/FlagEmbedding/abc/finetune/embedder/AbsArguments.py
+++ b/FlagEmbedding/abc/finetune/embedder/AbsArguments.py
@@ -38,6 +38,9 @@ class AbsEmbedderModelArguments:
 
 @dataclass
 class AbsEmbedderDataArguments:
+    """
+    Abstract class for data arguments.
+    """
     train_data: str = field(
         default=None, metadata={
             "help": "One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data.",
diff --git a/FlagEmbedding/abc/finetune/embedder/AbsDataset.py b/FlagEmbedding/abc/finetune/embedder/AbsDataset.py
index 7b5861c..e3126ec 100644
--- a/FlagEmbedding/abc/finetune/embedder/AbsDataset.py
+++ b/FlagEmbedding/abc/finetune/embedder/AbsDataset.py
@@ -21,6 +21,12 @@ logger = logging.getLogger(__name__)
 
 
 class AbsEmbedderTrainDataset(Dataset):
+    """Abstract class for training dataset.
+
+    Args:
+        args (AbsEmbedderDataArguments): Data arguments.
+        tokenizer (PreTrainedTokenizer): Tokenizer to use.
+    """
     def __init__(
         self,
         args: AbsEmbedderDataArguments,
@@ -46,6 +52,17 @@ class AbsEmbedderTrainDataset(Dataset):
         self.dataset = datasets.concatenate_datasets(train_datasets)
 
     def _load_dataset(self, file_path: str):
+        """Load dataset from path.
+
+        Args:
+            file_path (str): Path to load the datasets from.
+
+        Raises:
+            ValueError: `pos_scores` and `neg_scores` not found in the features of training data.
+
+        Returns:
+            datasets.Dataset: Loaded HF dataset.
+        """
         if dist.get_rank() == 0:
             logger.info(f'loading data from {file_path} ...')
 
@@ -63,6 +80,14 @@ class AbsEmbedderTrainDataset(Dataset):
         return temp_dataset
 
     def _shuffle_text(self, text):
+        """Shuffle the input text.
+
+        Args:
+            text (str): Input text.
+
+        Returns:
+            str: Shuffled text.
+        """
         if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
             split_text = []
             chunk_size = len(text)//3 + 1
@@ -126,6 +151,9 @@ class AbsEmbedderTrainDataset(Dataset):
 
 @dataclass
 class AbsEmbedderCollator(DataCollatorWithPadding):
+    """
+    The abstract embedder collator.
+    """
     query_max_len: int = 32
     passage_max_len: int = 128
     sub_batch_size: int = -1
@@ -214,6 +242,16 @@ class AbsEmbedderCollator(DataCollatorWithPadding):
 
 
 class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
+    """Abstract class for training dataset that samples batches from the same dataset.
+
+    Args:
+        args (AbsEmbedderDataArguments): Data arguments.
+        default_batch_size (int): The default batch size for training.
+        seed (int): Random seed.
+        tokenizer (PreTrainedTokenizer): Tokenizer to use.
+        process_index (int, optional): Current process index. Defaults to 0.
+        num_processes (int, optional): Total number of processes. Defaults to 1.
+    """
     def __init__(
         self,
         args: AbsEmbedderDataArguments,
@@ -296,6 +334,14 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
         self.refresh_epoch()
 
     def _load_dataset(self, file_path: str):
+        """Load dataset from the given path.
+
+        Args:
+            file_path (str): The path to load or download from HF hub.
+
+        Returns:
+            datasets.Dataset: The loaded dataset.
+        """
         if dist.get_rank() == 0:
             logger.info(f'loading data from {file_path} ...')
 
@@ -311,6 +357,15 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
 
     @staticmethod
     def _get_file_batch_size(temp_dataset: datasets.Dataset, default_batch_size: int):
+        """Get the appropriate batch size for the dataset.
+
+        Args:
+            temp_dataset (datasets.Dataset): Loaded :data:`datasets.Dataset` object.
+            default_batch_size (int): The default batch size to use if not specified in the dataset.
+
+        Returns:
+            int: The final batch size to use.
+        """
         if 'batch_size' in temp_dataset.column_names:
             return temp_dataset['batch_size'][0]
         if 'type' in temp_dataset.column_names:
@@ -320,6 +375,9 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
         return default_batch_size
 
     def refresh_epoch(self):
+        """
+        Refresh data for the epoch.
+        """
         logger.info(f'-- Rank {self.process_index}: refresh data --')
         self.deterministic_generator.shuffle(self.datasets_inxs)
 
@@ -353,6 +411,15 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
         return queries, passages, teacher_scores, no_in_batch_neg_flag
 
     def _get_train_group_size(self, batch_raw_data):
+        """Get the training group size and data type.
+
+        Args:
+            batch_raw_data (datasets.Dataset): One batch of raw data.
+
+        Returns:
+            int: The training group size.
+            str: The type of data for the task.
+        """
         if 'type' in batch_raw_data:
             data_type = batch_raw_data['type'][0]
             if data_type in ['only_1neg']:
@@ -362,6 +429,16 @@ class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
         return self.args.train_group_size, None
 
     def _create_batch_data(self, batch_raw_data):
+        """Create a complete batch of data with queries, documents and teacher scores.
+
+        Args:
+            batch_raw_data (datasets.Dataset): One batch of raw data.
+
+        Returns:
+            List[str]: Queries with instruction format.
+            List[str]: Documents with instruction format.
+            List[float]: Teacher scores for model distillation.
+        """
         queries, passages, teacher_scores = [], [], []
 
         train_group_size, data_type = self._get_train_group_size(batch_raw_data)
@@ -516,6 +593,9 @@ class AbsEmbedderSameDatasetCollator(DataCollatorWithPadding):
 
 
 class EmbedderTrainerCallbackForDataRefresh(TrainerCallback):
+    """
+    Callback class to refresh the training dataset at the end of each epoch.
+    """
     def __init__(self, train_dataset: AbsEmbedderSameDatasetTrainDataset):
         self.train_dataset = train_dataset
 
diff --git a/FlagEmbedding/abc/finetune/embedder/AbsModeling.py b/FlagEmbedding/abc/finetune/embedder/AbsModeling.py
index d3d2aac..c08806f 100644
--- a/FlagEmbedding/abc/finetune/embedder/AbsModeling.py
+++ b/FlagEmbedding/abc/finetune/embedder/AbsModeling.py
@@ -15,6 +15,9 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class EmbedderOutput(ModelOutput):
+    """
+    Output information returned by the model.
+    """
     q_reps: Optional[Tensor] = None
     p_reps: Optional[Tensor] = None
     loss: Optional[Tensor] = None
@@ -22,6 +25,17 @@ class EmbedderOutput(ModelOutput):
 
 
 class AbsEmbedderModel(ABC, nn.Module):
+    """Abstract class of embedding model for training.
+
+    Args:
+        base_model: The base model to train on.
+        tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
+        negatives_cross_device (bool, optional): If True, will compute cross-device negative loss. Defaults to ``False``.
+        temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
+        sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split into sub-batches.
+            Defaults to ``-1``.
+        kd_loss_type (str, optional): Knowledge distillation type. Defaults to ``"kl_div"``.
+    """
     def __init__(
         self,
         base_model,
@@ -48,21 +62,53 @@ class AbsEmbedderModel(ABC, nn.Module):
 
     @abstractmethod
     def encode(self, features):
+        """Abstract method to encode and get the embedding.
+
+        Args:
+            features (Union[list, dict]): Features fed to the model.
+        """
         pass
 
     @abstractmethod
     def compute_loss(self, scores, target):
+        """Abstract method to compute the loss.
+
+        Args:
+            scores (torch.Tensor): Computed score.
+            target (torch.Tensor): The target value.
+        """
         pass
 
     @abstractmethod
     def compute_score(self, q_reps, p_reps):
+        """Abstract method to compute the score.
+
+        Args:
+            q_reps (torch.Tensor): Queries representations.
+            p_reps (torch.Tensor): Passages representations.
+        """
         pass
 
     @abstractmethod
     def save(self, output_dir: str):
+        """Abstract method to save the model.
+
+        Args:
+            output_dir (str): Directory for saving the model.
+        """
         pass
 
     def get_local_score(self, q_reps, p_reps, all_scores):
+        """Get the local score of queries and passages.
+
+        Args:
+            q_reps (torch.Tensor): Queries representations.
+            p_reps (torch.Tensor): Passages representations.
+            all_scores (torch.Tensor): All the query-passage scores computed.
+
+        Returns:
+            torch.Tensor: Local scores to compute loss.
+        """
         group_size = p_reps.size(0) // q_reps.size(0)
         indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size
         specific_scores = []
@@ -73,6 +119,17 @@ class AbsEmbedderModel(ABC, nn.Module):
         return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1)
 
     def compute_local_score(self, q_reps, p_reps, compute_score_func=None, **kwargs):
+        """Compute the local score of queries and passages.
+
+        Args:
+            q_reps (torch.Tensor): Queries representations.
+            p_reps (torch.Tensor): Passages representations.
+            compute_score_func (function, optional): Function to compute score. Defaults to ``None``, which will use the
+                :meth:`self.compute_score`.
+
+        Returns:
+            torch.Tensor: Local scores to compute loss.
+        """
         if compute_score_func is None:
             all_scores = self.compute_score(q_reps, p_reps)
         else:
@@ -181,6 +238,17 @@ class AbsEmbedderModel(ABC, nn.Module):
         teacher_scores: Union[None, List[float]] = None,
         no_in_batch_neg_flag: bool = False,
     ):
+        """The computation performed at every call.
+
+        Args:
+            queries (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input queries. Defaults to ``None``.
+            passages (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input passages. Defaults to ``None``.
+            teacher_scores (Union[None, List[float]], optional): Teacher scores for distillation. Defaults to ``None``.
+            no_in_batch_neg_flag (bool, optional): If True, use no in-batch negatives and no cross-device negatives. Defaults to ``False``.
+
+        Returns:
+            EmbedderOutput: Output of the forward call of the model.
+        """
         q_reps = self.encode(queries) # (batch_size, dim)
         p_reps = self.encode(passages) # (batch_size * group_size, dim)
 
@@ -210,6 +278,20 @@ class AbsEmbedderModel(ABC, nn.Module):
 
     @staticmethod
     def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None):
+        """Compute the distillation loss.
+
+        Args:
+            kd_loss_type (str): Type of knowledge distillation loss, supports "kl_div" and "m3_kd_loss".
+            teacher_targets (torch.Tensor): Targets from the teacher model.
+            student_scores (torch.Tensor): Scores of the student model.
+            group_size (int, optional): Number of passages per query group. Defaults to ``None``.
+
+        Raises:
+            ValueError: Raised if the ``kd_loss_type`` is invalid.
+
+        Returns:
+            torch.Tensor: A scalar of the computed distillation loss.
+        """
         if kd_loss_type == 'kl_div':
             # teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
             # student_scores: (batch_size, group_size) / (world_size * batch_size, group_size)
@@ -236,6 +318,15 @@ class AbsEmbedderModel(ABC, nn.Module):
             raise ValueError(f"Invalid kd_loss_type: {kd_loss_type}")
 
     def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
+        """Gather a tensor from all processes in a distributed setting.
+
+        Args:
+            t (Optional[torch.Tensor]): The input tensor to be gathered. If ``None``, no gathering is performed.
+
+        Returns:
+            Union[torch.Tensor, None]: A concatenated tensor from all processes if ``t`` is not ``None``,
+                otherwise returns ``None``.
+        """
         if t is None:
             return None
         t = t.contiguous()
diff --git a/FlagEmbedding/abc/finetune/embedder/AbsRunner.py b/FlagEmbedding/abc/finetune/embedder/AbsRunner.py
index 2137c2e..9466a3d 100644
--- a/FlagEmbedding/abc/finetune/embedder/AbsRunner.py
+++ b/FlagEmbedding/abc/finetune/embedder/AbsRunner.py
@@ -22,6 +22,13 @@ logger = logging.getLogger(__name__)
 
 
 class AbsEmbedderRunner(ABC):
+    """Abstract class to run embedding model fine-tuning.
+
+    Args:
+        model_args (AbsEmbedderModelArguments): Model arguments.
+        data_args (AbsEmbedderDataArguments): Data arguments.
+        training_args (AbsEmbedderTrainingArguments): Training arguments.
+    """
     def __init__(
         self,
         model_args: AbsEmbedderModelArguments,
@@ -70,13 +77,28 @@ class AbsEmbedderRunner(ABC):
 
     @abstractmethod
     def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
+        """Abstract method to load the tokenizer and model.
+
+        Returns:
+            Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Loaded tokenizer and model instances.
+        """
         pass
 
     @abstractmethod
     def load_trainer(self) -> AbsEmbedderTrainer:
+        """Abstract method to load the trainer.
+
+        Returns:
+            AbsEmbedderTrainer: The loaded trainer instance.
+        """
         pass
 
     def load_train_dataset(self) -> AbsEmbedderTrainDataset:
+        """Loads the training dataset based on data arguments.
+
+        Returns:
+            AbsEmbedderTrainDataset: The loaded dataset instance.
+        """
         if self.data_args.same_dataset_within_batch:
             train_dataset = AbsEmbedderSameDatasetTrainDataset(
                 args=self.data_args,
@@ -96,6 +118,11 @@ class AbsEmbedderRunner(ABC):
         return train_dataset
 
     def load_data_collator(self) -> AbsEmbedderCollator:
+        """Loads the appropriate data collator.
+
+        Returns:
+            AbsEmbedderCollator: Loaded data collator.
+        """
         if self.data_args.same_dataset_within_batch:
             EmbedCollator = AbsEmbedderSameDatasetCollator
         else:
@@ -113,6 +140,9 @@ class AbsEmbedderRunner(ABC):
         return data_collator
 
     def run(self):
+        """
+        Executes the training process.
+        """
         Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)
 
         # Training
diff --git a/FlagEmbedding/abc/finetune/embedder/AbsTrainer.py b/FlagEmbedding/abc/finetune/embedder/AbsTrainer.py
index 0628e1b..f611f2a 100644
--- a/FlagEmbedding/abc/finetune/embedder/AbsTrainer.py
+++ b/FlagEmbedding/abc/finetune/embedder/AbsTrainer.py
@@ -7,6 +7,9 @@ logger = logging.getLogger(__name__)
 
 
 class AbsEmbedderTrainer(ABC, Trainer):
+    """
+    Abstract class for the embedder trainer.
+    """
     @abstractmethod
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
         pass
@@ -16,6 +19,16 @@ class AbsEmbedderTrainer(ABC, Trainer):
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
 
         Subclass and override for custom behavior.
+
+        Args:
+            model (AbsEmbedderModel): The model being trained.
+            inputs (dict): A dictionary of input tensors to be passed to the model.
+            return_outputs (bool, optional): If ``True``, returns both the loss and the model's outputs. Otherwise,
+                returns only the loss.
+
+        Returns:
+            Union[torch.Tensor, tuple(torch.Tensor, ModelOutput)]: The computed loss. If ``return_outputs`` is ``True``,
+                also returns the model's outputs in a tuple ``(loss, outputs)``.
         """
 
         outputs = model(**inputs)
diff --git a/FlagEmbedding/inference/reranker/decoder_only/base.py b/FlagEmbedding/inference/reranker/decoder_only/base.py
index a48a175..d16b87a 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/base.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/base.py
@@ -175,10 +175,11 @@ class BaseLLMReranker(AbsReranker):
     Args:
         model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
             load a model from HuggingFace Hub with the name.
-        peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
+        peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
         use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight
             performance degradation. Defaults to :data:`False`. Defaults to :data:`False`.
-        use_bf16 (bool, optional): _description_. Defaults to :data:False.
+        use_bf16 (bool, optional): Another half-precision floating-point format. You can use bf16 if the hardware supports it.
+            Defaults to :data:`False`.
         query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
             with :attr:`query_instruction_format`. Defaults to :data:`"A: "`.
         query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.
@@ -248,7 +249,7 @@ class BaseLLMReranker(AbsReranker):
             torch_dtype=torch.bfloat16 if use_bf16 else torch.float32
         )
         if peft_path:
-            self.model = PeftModel.from_pretrained(self.model,peft_path)
+            self.model = PeftModel.from_pretrained(self.model, peft_path)
             self.model = self.model.merge_and_unload()
 
         self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
diff --git a/FlagEmbedding/inference/reranker/decoder_only/layerwise.py b/FlagEmbedding/inference/reranker/decoder_only/layerwise.py
index 487d436..dce7393 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/layerwise.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/layerwise.py
@@ -41,10 +41,11 @@ class LayerWiseLLMReranker(AbsReranker):
     Args:
         model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
             load a model from HuggingFace Hub with the name.
-        peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
+        peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
        use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight
             performance degradation. Defaults to :data:`False`. Defaults to :data:`False`.
-        use_bf16 (bool, optional): _description_. Defaults to :data:False.
+        use_bf16 (bool, optional): Another half-precision floating-point format. You can use bf16 if the hardware supports it.
+            Defaults to :data:`False`.
         query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
             with :attr:`query_instruction_format`. Defaults to :data:`"A: "`.
         query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.
diff --git a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
index 44e59c5..b67892b 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
@@ -95,10 +95,11 @@ class LightweightLLMReranker(AbsReranker):
     Args:
         model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
             load a model from HuggingFace Hub with the name.
-        peft_path (Optional[str], optional): _description_. Defaults to :data:`None`.
+        peft_path (Optional[str], optional): Path to the PEFT config. Defaults to :data:`None`.
         use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight
             performance degradation. Defaults to :data:`False`. Defaults to :data:`False`.
-        use_bf16 (bool, optional): _description_. Defaults to :data:False.
+        use_bf16 (bool, optional): Another half-precision floating-point format. You can use bf16 if the hardware supports it.
+            Defaults to :data:`False`.
         query_instruction_for_rerank (str, optional): Query instruction for retrieval tasks, which will be used with
             with :attr:`query_instruction_format`. Defaults to :data:`"A: "`.
         query_instruction_format (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.
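
Note on the distill_loss docstring added in AbsModeling.py: for the "kl_div" case, the in-code comments give both teacher_targets and student_scores the shape (batch_size, group_size). A loss with that contract can be sketched as below. This is a minimal illustration of the shape convention only, not the repository's actual implementation, and the helper name kl_div_distill_loss is hypothetical.

import torch

def kl_div_distill_loss(teacher_targets: torch.Tensor, student_scores: torch.Tensor) -> torch.Tensor:
    # teacher_targets: (batch_size, group_size), a probability distribution over each
    # query's passage group (e.g. softmax over the teacher's scores).
    # student_scores: (batch_size, group_size), raw similarity scores from the student.
    # Cross-entropy against the soft teacher targets; this equals the KL divergence up to
    # the constant teacher-entropy term, so minimizing it is equivalent.
    log_probs = torch.log_softmax(student_scores, dim=-1)
    return -torch.mean(torch.sum(teacher_targets * log_probs, dim=-1))

# Toy check: 2 queries, each with a group of 4 candidate passages.
teacher = torch.softmax(torch.randn(2, 4), dim=-1)
student = torch.randn(2, 4)
print(kl_div_distill_loss(teacher, student))  # scalar tensor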