Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-28 09:35:42 +00:00)
feat: Allow all training options for training a SentenceTransformers EmbeddingRetriever (#4026)
* Add additional options to pass to the SentenceTransformers trainer
* Make options accessible to EmbeddingRetriever.train
* Update file-converters.yml
* Update transformers-img-to-text.yml
* Update 3550-csv-converter.md
* Move type: ignore to the correct line
* Move type ignore again
* Fix pylint and mypy
* Update haystack/nodes/retriever/_embedding_encoder.py (x3, applying review suggestions)
* Update docstring to be less misleading

Co-authored-by: bogdankostic <bogdankostic@web.de>
This commit is contained in:
parent bcf3bfdf79
commit a9f13d4641
haystack/nodes/retriever/_embedding_encoder.py

@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, Literal
 
 import numpy as np
 import requests
@@ -94,6 +94,10 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(
             "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -160,9 +164,37 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
         learning_rate: float = 2e-5,
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
-        batch_size: int = 16,
-        train_loss: str = "mnrl",
+        batch_size: Optional[int] = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
+        """
+        Trains the underlying Sentence Transformer model.
+
+        Each training data example is a dictionary with the following keys:
+
+        * question: The question string.
+        * pos_doc: Positive document string (the document containing the answer).
+        * neg_doc: Negative document string (the document that doesn't contain the answer).
+        * score: The score margin the answer must fall within.
+
+        :param training_data: The training data in a dictionary format.
+        :param learning_rate: The learning rate of the optimizer.
+        :param n_epochs: The number of iterations on the whole training data set you want to train for.
+        :param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+            increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+            decreased linearly back to zero.
+        :param batch_size: The batch size to use for the training. The default value is 16.
+        :param train_loss: Specify the training loss to use to fit the Sentence-Transformers model. Possible options are
+            "mnrl" (Multiple Negatives Ranking Loss) and "margin_mse".
+        :param num_workers: The number of subprocesses to use for the PyTorch DataLoader.
+        :param use_amp: Use Automatic Mixed Precision (AMP).
+        :param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+            reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+            for a full list of keyword arguments.
+        """
 
         if train_loss not in _TRAINING_LOSSES:
             raise ValueError(f"Unrecognized train_loss {train_loss}. Should be one of: {_TRAINING_LOSSES.keys()}")
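For reference, a minimal sketch of the training-data format this new docstring describes (the question, documents, and score value below are invented for illustration; the score serves as the label when "margin_mse" is the chosen loss):

    training_data = [
        {
            "question": "What is the capital of France?",
            "pos_doc": "Paris is the capital and most populous city of France.",
            "neg_doc": "Berlin is the capital of Germany.",
            "score": 0.95,  # invented value; the margin label for "margin_mse"
        },
    ]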
@@ -187,7 +219,13 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
             train_examples.append(InputExample(texts=texts))
 
         logger.info("Training/adapting %s with %s examples", self.embedding_model, len(train_examples))
-        train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)  # type: ignore [var-annotated,arg-type]
+        train_dataloader = DataLoader(
+            train_examples,  # type: ignore [var-annotated, arg-type]
+            batch_size=batch_size,
+            drop_last=True,
+            shuffle=True,
+            num_workers=num_workers,
+        )
         train_loss = st_loss.loss(self.embedding_model)
 
         # Tune the model
@@ -196,6 +234,8 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
             epochs=n_epochs,
             optimizer_params={"lr": learning_rate},
             warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
+            use_amp=use_amp,
+            **kwargs,
         )
 
     def save(self, save_dir: Union[Path, str]):
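Because `**kwargs` is forwarded verbatim to `SentenceTransformer.fit`, the remaining fit options become reachable through `train` without further Haystack changes. A minimal sketch, assuming `retriever` is an EmbeddingRetriever backed by a sentence-transformers model and `training_data` is prepared as above (`scheduler`, `weight_decay`, and `max_grad_norm` are existing `SentenceTransformer.fit` parameters; the values are illustrative):

    retriever.train(
        training_data,
        n_epochs=2,
        use_amp=True,   # new explicit option, forwarded to fit()
        num_workers=2,  # new explicit option, forwarded to the DataLoader
        # everything below travels through **kwargs into SentenceTransformer.fit:
        scheduler="warmupcosine",
        weight_decay=0.01,
        max_grad_norm=1.0,
    )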
@@ -303,6 +343,10 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(
             "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
@@ -368,6 +412,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ):
         raise NotImplementedError(f"Training is not implemented for {self.__class__}")
 
haystack/nodes/retriever/dense.py (second file in this commit)

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Dict, Union, Optional, Any
+from typing import List, Dict, Union, Optional, Any, Literal
 
 import logging
 from pathlib import Path
@@ -1850,10 +1850,13 @@ class EmbeddingRetriever(DenseRetriever):
         n_epochs: int = 1,
         num_warmup_steps: Optional[int] = None,
         batch_size: int = 16,
-        train_loss: str = "mnrl",
+        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
+        num_workers: int = 0,
+        use_amp: bool = False,
+        **kwargs,
     ) -> None:
         """
-        Trains/adapts the underlying embedding model.
+        Trains/adapts the underlying embedding model. We only support the training of sentence-transformer embedding models.
 
         Each training data example is a dictionary with the following keys:
@@ -1862,21 +1865,21 @@ class EmbeddingRetriever(DenseRetriever):
         * neg_doc: the negative document string
         * score: the score margin
 
-        :param training_data: The training data
-        :type training_data: List[Dict[str, Any]]
-        :param learning_rate: The learning rate
-        :type learning_rate: float
-        :param n_epochs: The number of epochs
-        :type n_epochs: int
-        :param num_warmup_steps: The number of warmup steps
-        :type num_warmup_steps: int
-        :param batch_size: The batch size to use for the training, defaults to 16
-        :type batch_size: int (optional)
+        :param training_data: The training data in a dictionary format.
+        :param learning_rate: The learning rate.
+        :param n_epochs: The number of epochs you want to train for.
+        :param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
+            increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
+            decreased linearly back to zero.
+        :param batch_size: The batch size to use for the training. The default value is 16.
         :param train_loss: The loss to use for training.
-            If you're using sentence-transformers as embedding_model (which are the only ones that currently support training),
+            If you're using a sentence-transformer embedding_model (which is the only model that training is supported for),
             possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).
-        :type train_loss: str (optional)
+        :param num_workers: The number of subprocesses to use for the PyTorch DataLoader.
+        :param use_amp: Use Automatic Mixed Precision (AMP).
+        :param kwargs: Additional training keyword arguments to pass to the `SentenceTransformer.fit` function. Please
+            reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
+            for a full list of keyword arguments.
         """
         self.embedding_encoder.train(
             training_data,
@@ -1885,6 +1888,9 @@ class EmbeddingRetriever(DenseRetriever):
             num_warmup_steps=num_warmup_steps,
             batch_size=batch_size,
             train_loss=train_loss,
+            num_workers=num_workers,
+            use_amp=use_amp,
+            **kwargs,
         )
 
     def save(self, save_dir: Union[Path, str]) -> None:
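Putting the pieces together, an end-to-end sketch of the expanded retriever API (the document store, model name, and training example are assumptions for illustration; `save` is the method visible in the surrounding context):

    from haystack.document_stores import InMemoryDocumentStore
    from haystack.nodes import EmbeddingRetriever

    retriever = EmbeddingRetriever(
        document_store=InMemoryDocumentStore(embedding_dim=384),
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        model_format="sentence_transformers",
    )
    retriever.train(
        training_data=[
            {
                "question": "What is the capital of France?",
                "pos_doc": "Paris is the capital of France.",
                "neg_doc": "Berlin is the capital of Germany.",
                "score": 0.9,
            }
            # one example shown; a real dataset needs at least batch_size
            # examples, since the DataLoader drops incomplete batches
        ],
        train_loss="margin_mse",
        n_epochs=1,
        batch_size=16,
        num_workers=2,
        use_amp=False,
    )
    retriever.save("trained_retriever")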