feat: Allow all training options for training a SentenceTransformers EmbeddingRetriever (#4026)
* Add additional options to pass to the SentenceTransformers trainer
* Make options accessible to the EmbeddingRetriever.train
* Update file-converters.yml
* Update transformers-img-to-text.yml
* Update 3550-csv-converter.md
* move type: ignore to correct line
* Moving type ignore again
* Fixing pylint and mypy
* Update haystack/nodes/retriever/_embedding_encoder.py (Co-authored-by: bogdankostic <bogdankostic@web.de>)
* Update haystack/nodes/retriever/_embedding_encoder.py (Co-authored-by: bogdankostic <bogdankostic@web.de>)
* Update haystack/nodes/retriever/_embedding_encoder.py (Co-authored-by: bogdankostic <bogdankostic@web.de>)
* Updated docstring to be less misleading.

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
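To make the new surface area concrete, here is a minimal, hypothetical usage sketch of `EmbeddingRetriever.train` with the options this PR exposes. The document store, model name, and example data are assumptions for illustration only; just the training parameters mirror the diff below.

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

# Assumed setup: training is only supported for sentence-transformers embedding models.
retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(embedding_dim=384),
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
)

# Each training example is a dict with question / pos_doc / neg_doc (and a score margin).
training_data = [
    {
        "question": "What is the capital of France?",
        "pos_doc": "Paris is the capital and largest city of France.",
        "neg_doc": "Berlin is the capital of Germany.",
        "score": 0.9,
    },
]

retriever.train(
    training_data=training_data,
    n_epochs=1,
    batch_size=16,
    train_loss="margin_mse",
    num_workers=2,   # new: forwarded to the PyTorch DataLoader
    use_amp=True,    # new: enables Automatic Mixed Precision
)
```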
Parent: bcf3bfdf79
Commit: a9f13d4641
@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, Literal
 
 import numpy as np
 import requests
@@ -94,6 +94,10 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(
             "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -160,9 +164,37 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
         learning_rate: float = 2e-5,
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
-        batch_size: int = 16,
-        train_loss: str = "mnrl",
+        batch_size: Optional[int] = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
+        """
+        Trains the underlying Sentence Transformer model.
+
+        Each training data example is a dictionary with the following keys:
+
+        * question: The question string.
+        * pos_doc: Positive document string (the document containing the answer).
+        * neg_doc: Negative document string (the document that doesn't contain the answer).
+        * score: The score margin the answer must fall within.
+
+        :param training_data: The training data in a dictionary format.
+        :param learning_rate: The learning rate of the optimizer.
+        :param n_epochs: The number of iterations on the whole training data set you want to train for.
+        :param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+            increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+            decreased linearly back to zero.
+        :param batch_size: The batch size to use for the training. The default value is 16.
+        :param train_loss: Specify the training loss to use to fit the Sentence-Transformers model. Possible options are
+            "mnrl" (Multiple Negatives Ranking Loss) and "margin_mse".
+        :param num_workers: The number of subprocesses to use for the Pytorch DataLoader.
+        :param use_amp: Use Automatic Mixed Precision (AMP).
+        :param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+            reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+            for a full list of keyword arguments.
+        """
 
         if train_loss not in _TRAINING_LOSSES:
             raise ValueError(f"Unrecognized train_loss {train_loss}. Should be one of: {_TRAINING_LOSSES.keys()}")
@@ -187,7 +219,13 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
             train_examples.append(InputExample(texts=texts))
 
         logger.info("Training/adapting %s with %s examples", self.embedding_model, len(train_examples))
-        train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)  # type: ignore [var-annotated,arg-type]
+        train_dataloader = DataLoader(
+            train_examples,  # type: ignore [var-annotated, arg-type]
+            batch_size=batch_size,
+            drop_last=True,
+            shuffle=True,
+            num_workers=num_workers,
+        )
         train_loss = st_loss.loss(self.embedding_model)
 
         # Tune the model
@@ -196,6 +234,8 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
             epochs=n_epochs,
             optimizer_params={"lr": learning_rate},
             warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
+            use_amp=use_amp,
+            **kwargs,
         )
 
     def save(self, save_dir: Union[Path, str]):
@@ -303,6 +343,10 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(
             "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -368,6 +412,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(f"Training is not implemented for {self.__class__}")
 
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Dict, Union, Optional, Any
+from typing import List, Dict, Union, Optional, Any, Literal
 
 import logging
 from pathlib import Path
@@ -1850,10 +1850,13 @@ class EmbeddingRetriever(DenseRetriever):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
-        train_loss: str = "mnrl",
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ) -> None:
         """
-        Trains/adapts the underlying embedding model.
+        Trains/adapts the underlying embedding model. We only support the training of sentence-transformer embedding models.
 
         Each training data example is a dictionary with the following keys:
 
@@ -1862,21 +1865,21 @@ class EmbeddingRetriever(DenseRetriever):
         * neg_doc: the negative document string
         * score: the score margin
 
-
-        :param training_data: The training data
-        :type training_data: List[Dict[str, Any]]
-        :param learning_rate: The learning rate
-        :type learning_rate: float
-        :param n_epochs: The number of epochs
-        :type n_epochs: int
-        :param num_warmup_steps: The number of warmup steps
-        :type num_warmup_steps: int
-        :param batch_size: The batch size to use for the training, defaults to 16
-        :type batch_size: int (optional)
+        :param training_data: The training data in a dictionary format.
+        :param learning_rate: The learning rate.
+        :param n_epochs: The number of epochs that you want to train for.
+        :param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+            increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+            decreased linearly back to zero.
+        :param batch_size: The batch size to use for the training. The default value is 16.
         :param train_loss: The loss to use for training.
-            If you're using sentence-transformers as embedding_model (which are the only ones that currently support training),
+            If you're using a sentence-transformer embedding_model (which is the only model that training is supported for),
             possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).
-        :type train_loss: str (optional)
+        :param num_workers: The number of subprocesses to use for the Pytorch DataLoader.
+        :param use_amp: Use Automatic Mixed Precision (AMP).
+        :param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+            reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+            for a full list of keyword arguments.
         """
         self.embedding_encoder.train(
             training_data,
@@ -1885,6 +1888,9 @@ class EmbeddingRetriever(DenseRetriever):
             num_warmup_steps=num_warmup_steps,
             batch_size=batch_size,
             train_loss=train_loss,
+            num_workers=num_workers,
+            use_amp=use_amp,
+            **kwargs,
         )
 
     def save(self, save_dir: Union[Path, str]) -> None:
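Because the extra keyword arguments now pass straight through to `SentenceTransformer.fit`, fit-level options that `train` does not expose directly become reachable. Below is a hypothetical sketch reusing the retriever and training_data from the sketch above; the forwarded parameter names (`scheduler`, `weight_decay`, `checkpoint_path`, `checkpoint_save_steps`) come from the sentence-transformers `fit` API, not from this diff.

```python
# Hypothetical: anything not consumed by train() itself is forwarded via **kwargs
# to SentenceTransformer.fit by the sentence-transformers embedding encoder.
retriever.train(
    training_data=training_data,
    n_epochs=3,
    num_warmup_steps=100,
    train_loss="mnrl",
    num_workers=4,
    use_amp=True,
    # Forwarded unchanged to SentenceTransformer.fit:
    scheduler="warmupcosine",
    weight_decay=0.01,
    checkpoint_path="checkpoints/retriever",
    checkpoint_save_steps=500,
)
```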