feat: Allow all training options for training a SentenceTransformers EmbeddingRetriever (#4026)

* Add additional options to pass to the SentenceTransformers trainer

* Make options accessible to the EmbeddingRetriever.train

* Update file-converters.yml

* Update transformers-img-to-text.yml

* Update 3550-csv-converter.md

* move type: ignore to correct line

* Moving type ignore again

* Fixing pylint and mypy

* Update haystack/nodes/retriever/_embedding_encoder.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/retriever/_embedding_encoder.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/retriever/_embedding_encoder.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Updated docstring to be less misleading.

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
Sebastian 2023-02-07 08:05:21 +01:00 committed by GitHub
parent bcf3bfdf79
commit a9f13d4641
2 changed files with 74 additions and 20 deletions
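
A minimal usage sketch of the extended train() API introduced by this commit (the model name, document store, and training data below are illustrative, not taken from the diff):

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(embedding_dim=384),
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
)

# Each example is a dict with question, pos_doc, neg_doc, and (for "margin_mse") score.
# Because the DataLoader uses drop_last=True, you need at least batch_size examples.
training_data = [
    {
        "question": "What is the capital of France?",
        "pos_doc": "Paris is the capital and largest city of France.",
        "neg_doc": "Berlin is the capital of Germany.",
        "score": 0.9,
    },
    # ...
]

retriever.train(
    training_data=training_data,
    n_epochs=1,
    batch_size=16,
    train_loss="mnrl",
    num_workers=2,           # new: forwarded to the PyTorch DataLoader
    use_amp=True,            # new: enables Automatic Mixed Precision
    show_progress_bar=True,  # new: extra kwargs pass through to SentenceTransformer.fit
)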

File: haystack/nodes/retriever/_embedding_encoder.py

@@ -2,7 +2,7 @@ import json
import logging
import os
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, Literal
import numpy as np
import requests
@@ -94,6 +94,10 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
n_epochs: int = 1,
num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
+train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+num_workers: int = 0,
+use_amp: bool = False,
+**kwargs,
):
raise NotImplementedError(
"You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -160,9 +164,37 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
learning_rate: float = 2e-5,
n_epochs: int = 1,
num_warmup_steps: Optional[int] = None,
-batch_size: int = 16,
-train_loss: str = "mnrl",
+batch_size: Optional[int] = 16,
+train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+num_workers: int = 0,
+use_amp: bool = False,
+**kwargs,
):
"""
Trains the underlying Sentence Transformer model.
Each training data example is a dictionary with the following keys:
* question: The question string.
* pos_doc: Positive document string (the document containing the answer).
* neg_doc: Negative document string (the document that doesn't contain the answer).
* score: The score margin the answer must fall within.
:param training_data: The training data in a dictionary format.
:param learning_rate: The learning rate of the optimizer.
:param n_epochs: The number of iterations on the whole training data set you want to train for.
:param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
decreased linearly back to zero.
:param batch_size: The batch size to use for the training. The default value is 16.
:param train_loss: Specify the training loss to use to fit the Sentence-Transformers model. Possible options are
"mnrl" (Multiple Negatives Ranking Loss) and "margin_mse".
:param num_workers: The number of subprocesses to use for the Pytorch DataLoader.
:param use_amp: Use Automatic Mixed Precision (AMP).
:param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
for a full list of keyword arguments.
"""
if train_loss not in _TRAINING_LOSSES:
raise ValueError(f"Unrecognized train_loss {train_loss}. Should be one of: {_TRAINING_LOSSES.keys()}")
@@ -187,7 +219,13 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
train_examples.append(InputExample(texts=texts))
logger.info("Training/adapting %s with %s examples", self.embedding_model, len(train_examples))
-train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True) # type: ignore [var-annotated,arg-type]
+train_dataloader = DataLoader(
+train_examples, # type: ignore [var-annotated, arg-type]
+batch_size=batch_size,
+drop_last=True,
+shuffle=True,
+num_workers=num_workers,
+)
train_loss = st_loss.loss(self.embedding_model)
# Tune the model
@@ -196,6 +234,8 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
epochs=n_epochs,
optimizer_params={"lr": learning_rate},
warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
+use_amp=use_amp,
+**kwargs,
)
def save(self, save_dir: Union[Path, str]):
@@ -303,6 +343,10 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
n_epochs: int = 1,
num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
+train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+num_workers: int = 0,
+use_amp: bool = False,
+**kwargs,
):
raise NotImplementedError(
"You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -368,6 +412,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
n_epochs: int = 1,
num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
+train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+num_workers: int = 0,
+use_amp: bool = False,
+**kwargs,
):
raise NotImplementedError(f"Training is not implemented for {self.__class__}")

File: haystack/nodes/retriever/dense.py

@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import List, Dict, Union, Optional, Any
+from typing import List, Dict, Union, Optional, Any, Literal
import logging
from pathlib import Path
@@ -1850,10 +1850,13 @@ class EmbeddingRetriever(DenseRetriever):
n_epochs: int = 1,
num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
-train_loss: str = "mnrl",
+train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+num_workers: int = 0,
+use_amp: bool = False,
+**kwargs,
) -> None:
"""
-Trains/adapts the underlying embedding model.
+Trains/adapts the underlying embedding model. We only support the training of sentence-transformer embedding models.
Each training data example is a dictionary with the following keys:
@@ -1862,21 +1865,21 @@ class EmbeddingRetriever(DenseRetriever):
* neg_doc: the negative document string
* score: the score margin
-:param training_data: The training data
-:type training_data: List[Dict[str, Any]]
-:param learning_rate: The learning rate
-:type learning_rate: float
-:param n_epochs: The number of epochs
-:type n_epochs: int
-:param num_warmup_steps: The number of warmup steps
-:type num_warmup_steps: int
-:param batch_size: The batch size to use for the training, defaults to 16
-:type batch_size: int (optional)
+:param training_data: The training data in a dictionary format.
+:param learning_rate: The learning rate.
+:param n_epochs: The number of epochs that you want to train for.
+:param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+decreased linearly back to zero.
+:param batch_size: The batch size to use for the training. The default value is 16.
:param train_loss: The loss to use for training.
-If you're using sentence-transformers as embedding_model (which are the only ones that currently support training),
+If you're using a sentence-transformer embedding_model (which is the only model that training is supported for),
possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).
-:type train_loss: str (optional)
+:param num_workers: The number of subprocesses to use for the Pytorch DataLoader.
+:param use_amp: Use Automatic Mixed Precision (AMP).
+:param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+for a full list of keyword arguments.
"""
self.embedding_encoder.train(
training_data,
@@ -1885,6 +1888,9 @@ class EmbeddingRetriever(DenseRetriever):
num_warmup_steps=num_warmup_steps,
batch_size=batch_size,
train_loss=train_loss,
+num_workers=num_workers,
+use_amp=use_amp,
+**kwargs,
)
def save(self, save_dir: Union[Path, str]) -> None:
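
Because **kwargs is forwarded verbatim to SentenceTransformer.fit, fit() options that are not surfaced as named parameters can now be passed through. For example (parameter names come from the sentence-transformers fit() signature, values are illustrative; retriever and training_data as in the earlier sketch):

retriever.train(
    training_data=training_data,
    train_loss="margin_mse",  # uses the score margin from each training example
    use_amp=True,
    # forwarded through **kwargs to SentenceTransformer.fit:
    weight_decay=0.01,
    scheduler="warmupcosine",
    checkpoint_path="model_checkpoints/",  # illustrative path
    checkpoint_save_steps=500,
)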