feat: Allow all training options for training a SentenceTransformers EmbeddingRetriever (#4026)

* Add additional options to pass to the SentenceTransformers trainer

* Make options accessible to the EmbeddingRetriever.train

* Update file-converters.yml

* Update transformers-img-to-text.yml

* Update 3550-csv-converter.md

* Move `type: ignore` to the correct line

* Move the `type: ignore` comment again

* Fix pylint and mypy issues

* Update haystack/nodes/retriever/_embedding_encoder.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/retriever/_embedding_encoder.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/retriever/_embedding_encoder.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Updated docstring to be less misleading.

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
Authored by Sebastian on 2023-02-07 08:05:21 +01:00, committed by GitHub
parent bcf3bfdf79
commit a9f13d4641
2 changed files with 74 additions and 20 deletions
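
As a quick orientation, here is a minimal usage sketch of the `train` call this commit enables. The model name, the single training example, and the extra `weight_decay` keyword are illustrative assumptions, not taken from the diff; only a retriever backed by a sentence-transformers model supports training.

    from haystack.nodes import EmbeddingRetriever

    # Illustrative setup: an EmbeddingRetriever backed by a sentence-transformers model.
    retriever = EmbeddingRetriever(
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
    )

    # One training example per dictionary; "score" is only needed by the "margin_mse" loss.
    training_data = [
        {
            "question": "What is the capital of France?",
            "pos_doc": "Paris is the capital and largest city of France.",
            "neg_doc": "Berlin is the capital of Germany.",
            "score": 0.9,
        },
        # ... typically many more examples
    ]

    retriever.train(
        training_data=training_data,
        n_epochs=1,
        batch_size=16,
        train_loss="margin_mse",
        num_workers=2,       # new: forwarded to the PyTorch DataLoader
        use_amp=True,        # new: Automatic Mixed Precision
        weight_decay=0.01,   # new: any extra kwarg is forwarded to SentenceTransformer.fit
    )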

Changed file: haystack/nodes/retriever/_embedding_encoder.py

@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, Literal

 import numpy as np
 import requests
@@ -94,6 +94,10 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(
             "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -160,9 +164,37 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
         learning_rate: float = 2e-5,
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
-        batch_size: int = 16,
-        train_loss: str = "mnrl",
+        batch_size: Optional[int] = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
+        """
+        Trains the underlying Sentence Transformer model.
+
+        Each training data example is a dictionary with the following keys:
+
+        * question: The question string.
+        * pos_doc: Positive document string (the document containing the answer).
+        * neg_doc: Negative document string (the document that doesn't contain the answer).
+        * score: The score margin the answer must fall within.
+
+        :param training_data: The training data in a dictionary format.
+        :param learning_rate: The learning rate of the optimizer.
+        :param n_epochs: The number of iterations on the whole training data set you want to train for.
+        :param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+            increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+            decreased linearly back to zero.
+        :param batch_size: The batch size to use for the training. The default value is 16.
+        :param train_loss: Specify the training loss to use to fit the Sentence-Transformers model. Possible options are
+            "mnrl" (Multiple Negatives Ranking Loss) and "margin_mse".
+        :param num_workers: The number of subprocesses to use for the Pytorch DataLoader.
+        :param use_amp: Use Automatic Mixed Precision (AMP).
+        :param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+            reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+            for a full list of keyword arguments.
+        """
         if train_loss not in _TRAINING_LOSSES:
             raise ValueError(f"Unrecognized train_loss {train_loss}. Should be one of: {_TRAINING_LOSSES.keys()}")
@@ -187,7 +219,13 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
             train_examples.append(InputExample(texts=texts))

         logger.info("Training/adapting %s with %s examples", self.embedding_model, len(train_examples))
-        train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)  # type: ignore [var-annotated,arg-type]
+        train_dataloader = DataLoader(
+            train_examples,  # type: ignore [var-annotated, arg-type]
+            batch_size=batch_size,
+            drop_last=True,
+            shuffle=True,
+            num_workers=num_workers,
+        )
         train_loss = st_loss.loss(self.embedding_model)

         # Tune the model
@@ -196,6 +234,8 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
             epochs=n_epochs,
             optimizer_params={"lr": learning_rate},
             warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
+            use_amp=use_amp,
+            **kwargs,
         )

     def save(self, save_dir: Union[Path, str]):
@@ -303,6 +343,10 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(
             "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -368,6 +412,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(f"Training is not implemented for {self.__class__}")
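
The `_TRAINING_LOSSES` registry and the `st_loss` helper checked above are imported from elsewhere in Haystack and are not shown in this diff. As a rough sketch of how such a mapping could pair the `train_loss` names with sentence-transformers loss classes (the namedtuple layout and field names here are assumptions, not the actual Haystack definition):

    from collections import namedtuple
    from sentence_transformers import losses

    # Hypothetical registry mapping the user-facing train_loss name to a
    # sentence-transformers loss class and the InputExample fields it needs.
    SentenceTransformerTrainingLoss = namedtuple(
        "SentenceTransformerTrainingLoss", ["loss", "required_attrs"]
    )

    _TRAINING_LOSSES = {
        "mnrl": SentenceTransformerTrainingLoss(
            loss=losses.MultipleNegativesRankingLoss,
            required_attrs={"question", "pos_doc"},
        ),
        "margin_mse": SentenceTransformerTrainingLoss(
            loss=losses.MarginMSELoss,
            required_attrs={"question", "pos_doc", "neg_doc", "score"},
        ),
    }

With a structure like this, `st_loss.loss(self.embedding_model)` in the hunk above instantiates the selected loss with the underlying `SentenceTransformer` model.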

Changed file: haystack/nodes/retriever/dense.py (EmbeddingRetriever)

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Dict, Union, Optional, Any
+from typing import List, Dict, Union, Optional, Any, Literal
 import logging
 from pathlib import Path
@@ -1850,10 +1850,13 @@ class EmbeddingRetriever(DenseRetriever):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
-        train_loss: str = "mnrl",
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ) -> None:
         """
-        Trains/adapts the underlying embedding model.
+        Trains/adapts the underlying embedding model. We only support the training of sentence-transformer embedding models.

         Each training data example is a dictionary with the following keys:
@@ -1862,21 +1865,21 @@ class EmbeddingRetriever(DenseRetriever):
         * neg_doc: the negative document string
         * score: the score margin

-        :param training_data: The training data
-        :type training_data: List[Dict[str, Any]]
-        :param learning_rate: The learning rate
-        :type learning_rate: float
-        :param n_epochs: The number of epochs
-        :type n_epochs: int
-        :param num_warmup_steps: The number of warmup steps
-        :type num_warmup_steps: int
-        :param batch_size: The batch size to use for the training, defaults to 16
-        :type batch_size: int (optional)
-        :param train_loss: The loss to use for training.
-                           If you're using sentence-transformers as embedding_model (which are the only ones that currently support training),
-                           possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).
-        :type train_loss: str (optional)
+        :param training_data: The training data in a dictionary format.
+        :param learning_rate: The learning rate.
+        :param n_epochs: The number of epochs that you want to train for.
+        :param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+            increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+            decreased linearly back to zero.
+        :param batch_size: The batch size to use for the training. The default value is 16.
+        :param train_loss: The loss to use for training.
+            If you're using a sentence-transformer embedding_model (which is the only model that training is supported for),
+            possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).
+        :param num_workers: The number of subprocesses to use for the Pytorch DataLoader.
+        :param use_amp: Use Automatic Mixed Precision (AMP).
+        :param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+            reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+            for a full list of keyword arguments.
         """
         self.embedding_encoder.train(
             training_data,
@@ -1885,6 +1888,9 @@ class EmbeddingRetriever(DenseRetriever):
             num_warmup_steps=num_warmup_steps,
             batch_size=batch_size,
             train_loss=train_loss,
+            num_workers=num_workers,
+            use_amp=use_amp,
+            **kwargs,
         )

     def save(self, save_dir: Union[Path, str]) -> None:
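
Because `**kwargs` now flows from `EmbeddingRetriever.train` through the encoder into `SentenceTransformer.fit`, training options that previously had no hook in Haystack become reachable. Continuing the sketch from the top of this page (the evaluator data is made up for illustration; `evaluator` and `evaluation_steps` are standard `SentenceTransformer.fit` arguments):

    from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

    # Tiny illustrative dev set for periodic evaluation during fit().
    dev_evaluator = EmbeddingSimilarityEvaluator(
        sentences1=["What is the capital of France?"],
        sentences2=["Paris is the capital and largest city of France."],
        scores=[1.0],
    )

    retriever.train(
        training_data=training_data,
        n_epochs=2,
        train_loss="mnrl",
        use_amp=True,
        # Forwarded verbatim to SentenceTransformer.fit:
        evaluator=dev_evaluator,
        evaluation_steps=500,
    )

    # Persist the adapted model as before.
    retriever.save("trained_retriever")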