Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-24 21:48:52 +00:00)
feat: SentenceTransformersTextEmbedder supports config_kwargs (#8432)
* add config_kwargs
* disable PLR0913 for a specific function
* add a release note
* refer to AutoConfig in config_kwargs docstring

Co-authored-by: David S. Batista <dsbatista@gmail.com>
Co-authored-by: Julian Risch <julianrisch@gmx.de>
parent b81abc0c85
commit b40f0c8b5d
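In short: `SentenceTransformersTextEmbedder.__init__` gains a `config_kwargs` parameter that is forwarded to `AutoConfig.from_pretrained` when the model is loaded. A minimal usage sketch (the example kwarg is borrowed from this PR's tests; any attribute the model's config understands works the same way):

```python
from haystack.components.embedders import SentenceTransformersTextEmbedder

# config_kwargs is passed through to AutoConfig.from_pretrained when the
# underlying SentenceTransformer model is loaded.
embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    config_kwargs={"use_memory_efficient_attention": False},  # example from the PR's tests
)
embedder.warm_up()  # the model (and its config) is loaded here
result = embedder.run(text="I love pizza!")
print(len(result["embedding"]))  # all-mpnet-base-v2 produces 768-dim vectors
```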
@@ -34,7 +34,7 @@ class SentenceTransformersTextEmbedder:
     ```
     """
 
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         model: str = "sentence-transformers/all-mpnet-base-v2",
         device: Optional[ComponentDevice] = None,
@@ -48,6 +48,7 @@ class SentenceTransformersTextEmbedder:
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
     ):
         """
@@ -86,6 +87,8 @@ class SentenceTransformersTextEmbedder:
         :param tokenizer_kwargs:
             Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
             Refer to specific model documentation for available kwargs.
+        :param config_kwargs:
+            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
         :param precision:
             The precision to use for the embeddings.
             All non-float32 precisions are quantized embeddings.
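For context, the new parameter ultimately feeds a call of roughly this shape inside sentence-transformers/transformers; a sketch of the downstream behavior, not the component's literal internals:

```python
from transformers import AutoConfig

# Kwargs matching attributes on the model's config override them; unmatched
# kwargs are ignored by transformers rather than raising.
config = AutoConfig.from_pretrained(
    "sentence-transformers/all-mpnet-base-v2",
    use_memory_efficient_attention=False,  # example kwarg from the PR's tests
)
```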
@@ -105,6 +108,7 @@ class SentenceTransformersTextEmbedder:
         self.truncate_dim = truncate_dim
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
+        self.config_kwargs = config_kwargs
         self.embedding_backend = None
         self.precision = precision
 
@@ -135,6 +139,7 @@ class SentenceTransformersTextEmbedder:
             truncate_dim=self.truncate_dim,
             model_kwargs=self.model_kwargs,
             tokenizer_kwargs=self.tokenizer_kwargs,
+            config_kwargs=self.config_kwargs,
             precision=self.precision,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
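With `config_kwargs` included in `to_dict`, the new field participates in (de)serialization. A quick round-trip sketch mirroring the tests further down:

```python
from haystack.components.embedders import SentenceTransformersTextEmbedder

embedder = SentenceTransformersTextEmbedder(
    config_kwargs={"use_memory_efficient_attention": False},
)
data = embedder.to_dict()
assert data["init_parameters"]["config_kwargs"] == {"use_memory_efficient_attention": False}

restored = SentenceTransformersTextEmbedder.from_dict(data)
assert restored.config_kwargs == {"use_memory_efficient_attention": False}
```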
@@ -172,6 +177,7 @@ class SentenceTransformersTextEmbedder:
             truncate_dim=self.truncate_dim,
             model_kwargs=self.model_kwargs,
             tokenizer_kwargs=self.tokenizer_kwargs,
+            config_kwargs=self.config_kwargs,
         )
         if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
             self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
 
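The context lines above also document a related behavior: a `model_max_length` in `tokenizer_kwargs` caps the backend's `max_seq_length` once `warm_up()` has run. Sketch:

```python
from haystack.components.embedders import SentenceTransformersTextEmbedder

embedder = SentenceTransformersTextEmbedder(tokenizer_kwargs={"model_max_length": 512})
embedder.warm_up()  # loads the backend, then applies the cap shown in the diff
assert embedder.embedding_backend.model.max_seq_length == 512
```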
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    SentenceTransformersTextEmbedder now supports config_kwargs for additional parameters when loading the model configuration
@@ -70,6 +70,7 @@ class TestSentenceTransformersTextEmbedder:
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "config_kwargs": None,
                 "precision": "float32",
             },
         }
@@ -88,6 +89,7 @@ class TestSentenceTransformersTextEmbedder:
             truncate_dim=256,
             model_kwargs={"torch_dtype": torch.float32},
             tokenizer_kwargs={"model_max_length": 512},
+            config_kwargs={"use_memory_efficient_attention": False},
             precision="int8",
         )
         data = component.to_dict()
@@ -106,6 +108,7 @@ class TestSentenceTransformersTextEmbedder:
                 "truncate_dim": 256,
                 "model_kwargs": {"torch_dtype": "torch.float32"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "config_kwargs": {"use_memory_efficient_attention": False},
                 "precision": "int8",
             },
         }
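Note the asymmetry in the expected dict: the `torch.float32` passed to `model_kwargs` is serialized to the string `"torch.float32"` (torch dtypes are not JSON-serializable), while the plain-bool `config_kwargs` round-trips unchanged:

```python
import torch

# Why the serialized dict differs from the input above: torch dtypes are
# stored by their string form, plain Python values as-is.
assert str(torch.float32) == "torch.float32"
```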
@@ -131,6 +134,7 @@ class TestSentenceTransformersTextEmbedder:
                 "truncate_dim": None,
                 "model_kwargs": {"torch_dtype": "torch.float32"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "config_kwargs": {"use_memory_efficient_attention": False},
                 "precision": "float32",
             },
         }
@@ -147,6 +151,7 @@ class TestSentenceTransformersTextEmbedder:
         assert component.truncate_dim is None
         assert component.model_kwargs == {"torch_dtype": torch.float32}
         assert component.tokenizer_kwargs == {"model_max_length": 512}
+        assert component.config_kwargs == {"use_memory_efficient_attention": False}
         assert component.precision == "float32"
 
     def test_from_dict_no_default_parameters(self):
@@ -218,6 +223,7 @@ class TestSentenceTransformersTextEmbedder:
             truncate_dim=None,
             model_kwargs=None,
             tokenizer_kwargs={"model_max_length": 512},
+            config_kwargs=None,
         )
 
     @patch(