diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py
index 11c314d41..564d9f834 100644
--- a/haystack/nodes/retriever/_embedding_encoder.py
+++ b/haystack/nodes/retriever/_embedding_encoder.py
@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, Literal
 
 import numpy as np
 import requests
@@ -94,6 +94,10 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(
             "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -160,9 +164,37 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
         learning_rate: float = 2e-5,
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
-        batch_size: int = 16,
-        train_loss: str = "mnrl",
+        batch_size: Optional[int] = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
+        """
+        Trains the underlying Sentence Transformer model.
+
+        Each training data example is a dictionary with the following keys:
+
+        * question: The question string.
+        * pos_doc: Positive document string (the document containing the answer).
+        * neg_doc: Negative document string (the document that doesn't contain the answer).
+        * score: The score margin the answer must fall within.
+
+        :param training_data: The training data in a dictionary format.
+        :param learning_rate: The learning rate of the optimizer.
+        :param n_epochs: The number of iterations on the whole training data set you want to train for.
+        :param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+            increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+            decreased linearly back to zero.
+        :param batch_size: The batch size to use for the training. The default value is 16.
+        :param train_loss: Specify the training loss to use to fit the Sentence-Transformers model. Possible options are
+            "mnrl" (Multiple Negatives Ranking Loss) and "margin_mse".
+        :param num_workers: The number of subprocesses to use for the PyTorch DataLoader.
+        :param use_amp: Use Automatic Mixed Precision (AMP).
+        :param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+            reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+            for a full list of keyword arguments.
+        """
         if train_loss not in _TRAINING_LOSSES:
             raise ValueError(f"Unrecognized train_loss {train_loss}. Should be one of: {_TRAINING_LOSSES.keys()}")
 
@@ -187,7 +219,13 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
             train_examples.append(InputExample(texts=texts))
 
         logger.info("Training/adapting %s with %s examples", self.embedding_model, len(train_examples))
-        train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)  # type: ignore [var-annotated,arg-type]
+        train_dataloader = DataLoader(
+            train_examples,  # type: ignore [var-annotated, arg-type]
+            batch_size=batch_size,
+            drop_last=True,
+            shuffle=True,
+            num_workers=num_workers,
+        )
         train_loss = st_loss.loss(self.embedding_model)
 
         # Tune the model
@@ -196,6 +234,8 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
             epochs=n_epochs,
             optimizer_params={"lr": learning_rate},
             warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
+            use_amp=use_amp,
+            **kwargs,
         )
 
     def save(self, save_dir: Union[Path, str]):
@@ -303,6 +343,10 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(
             "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -368,6 +412,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(f"Training is not implemented for {self.__class__}")
 
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index 376115128..fc992ce2c 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Dict, Union, Optional, Any
+from typing import List, Dict, Union, Optional, Any, Literal
 import logging
 from pathlib import Path
 
@@ -1850,10 +1850,13 @@ class EmbeddingRetriever(DenseRetriever):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
-        train_loss: str = "mnrl",
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ) -> None:
         """
-        Trains/adapts the underlying embedding model.
+        Trains/adapts the underlying embedding model. We only support the training of sentence-transformers embedding models.
 
         Each training data example is a dictionary with the following keys:
 
@@ -1862,21 +1865,21 @@ class EmbeddingRetriever(DenseRetriever):
         * neg_doc: the negative document string
         * score: the score margin
 
-
-        :param training_data: The training data
-        :type training_data: List[Dict[str, Any]]
-        :param learning_rate: The learning rate
-        :type learning_rate: float
-        :param n_epochs: The number of epochs
-        :type n_epochs: int
-        :param num_warmup_steps: The number of warmup steps
-        :type num_warmup_steps: int
-        :param batch_size: The batch size to use for the training, defaults to 16
-        :type batch_size: int (optional)
+        :param training_data: The training data in a dictionary format.
+        :param learning_rate: The learning rate.
+        :param n_epochs: The number of epochs you want to train for.
+        :param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+            increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+            decreased linearly back to zero.
+        :param batch_size: The batch size to use for the training. The default value is 16.
         :param train_loss: The loss to use for training.
-            If you're using sentence-transformers as embedding_model (which are the only ones that currently support training),
+            If you're using a sentence-transformers embedding_model (which is the only model type that supports training),
            possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).
-        :type train_loss: str (optional)
+        :param num_workers: The number of subprocesses to use for the PyTorch DataLoader.
+        :param use_amp: Use Automatic Mixed Precision (AMP).
+        :param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+            reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+            for a full list of keyword arguments.
         """
         self.embedding_encoder.train(
             training_data,
@@ -1885,6 +1888,9 @@ class EmbeddingRetriever(DenseRetriever):
             num_warmup_steps=num_warmup_steps,
             batch_size=batch_size,
             train_loss=train_loss,
+            num_workers=num_workers,
+            use_amp=use_amp,
+            **kwargs,
         )
 
     def save(self, save_dir: Union[Path, str]) -> None:
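
For context (not part of the diff): a minimal sketch of how the extended `EmbeddingRetriever.train()` signature could be called once this change lands. The retriever setup (document store, model name) and the `weight_decay` pass-through kwarg are illustrative assumptions, not something this PR prescribes.

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

# Illustrative setup -- only sentence-transformers embedding models support train().
retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(embedding_dim=384),
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    model_format="sentence_transformers",
)

# Each example uses the dictionary format described in the new docstring.
training_data = [
    {
        "question": "What does the EmbeddingRetriever do?",
        "pos_doc": "The EmbeddingRetriever encodes queries and documents into dense vectors.",
        "neg_doc": "The sky is blue on a clear day.",
        "score": 0.9,
    },
    # ... more examples
]

retriever.train(
    training_data=training_data,
    n_epochs=1,
    batch_size=16,
    train_loss="margin_mse",  # or "mnrl" (the default)
    num_workers=2,            # new: forwarded to the PyTorch DataLoader
    use_amp=True,             # new: forwarded to SentenceTransformer.fit
    weight_decay=0.01,        # assumed example of a **kwargs entry passed through to SentenceTransformer.fit
)
```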