feat: Add ONNX & OpenVINO backend support, and torch dtype kwargs in Sentence Transformers Components (#8813)

* initial rough draft

* expose backend instead of extracting from model_kwargs

* explicitly set backend model path

* add reno

* expose backend for ST diversity ranker

* add dtype tests and expose kwargs to ST ranker for backend parameters

* skip dtype tests as torch isn't compiled with CUDA

* add new openvino dependency release, unskip tests

* resolve suggestion

* mock calls, turn integrations into unit tests

* remove unnecessary test dependencies
Author: Ulises M · 2025-02-13 03:04:14 -08:00 · committed by GitHub
parent 71416c81bc
commit bfdad40a80
9 changed files with 299 additions and 6 deletions

haystack/components/embedders/backends/sentence_transformers_backend.py

@@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
@@ -28,8 +28,9 @@ class _SentenceTransformersEmbeddingBackendFactory:
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}{backend}"
if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
@@ -42,6 +43,7 @@ class _SentenceTransformersEmbeddingBackendFactory:
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs,
backend=backend,
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
@@ -62,8 +64,10 @@ class _SentenceTransformersEmbeddingBackend:
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
sentence_transformers_import.check()
self.model = SentenceTransformer(
model_name_or_path=model,
device=device,
@@ -73,6 +77,7 @@ class _SentenceTransformersEmbeddingBackend:
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs,
backend=backend,
)
def embed(self, data: List[str], **kwargs) -> List[List[float]]:
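
The factory memoizes backend instances, and the cache id above now folds in the backend, so a torch-backed and an ONNX-backed instance of the same model never share a cache slot. For illustration, a minimal sketch of that caching behavior (simplified, with hypothetical names; not the real factory):

from typing import Dict


class _BackendFactory:
    _instances: Dict[str, object] = {}

    @classmethod
    def get(cls, model: str, device: str, backend: str = "torch") -> object:
        # the cache key includes the backend, mirroring the diff above
        key = f"{model}{device}{backend}"
        if key not in cls._instances:
            cls._instances[key] = object()  # stand-in for the real embedding backend
        return cls._instances[key]


torch_be = _BackendFactory.get("all-MiniLM-L6-v2", "cpu", "torch")
onnx_be = _BackendFactory.get("all-MiniLM-L6-v2", "cpu", "onnx")
assert torch_be is not onnx_be  # distinct backends yield distinct instances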

haystack/components/embedders/sentence_transformers_document_embedder.py

@@ -57,6 +57,7 @@ class SentenceTransformersDocumentEmbedder:
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
"""
Creates a SentenceTransformersDocumentEmbedder component.
@@ -109,6 +110,10 @@ class SentenceTransformersDocumentEmbedder:
Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
This parameter is provided for fine customization. Be careful not to clash with already set parameters and
avoid passing parameters that change the output type.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
"""
self.model = model
@@ -129,6 +134,7 @@ class SentenceTransformersDocumentEmbedder:
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision
self.backend = backend
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -162,6 +168,7 @@ class SentenceTransformersDocumentEmbedder:
config_kwargs=self.config_kwargs,
precision=self.precision,
encode_kwargs=self.encode_kwargs,
backend=self.backend,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -199,6 +206,7 @@ class SentenceTransformersDocumentEmbedder:
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
)
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
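
With the new parameter, switching the document embedder to ONNX is a one-argument change. A usage sketch, assuming sentence-transformers is installed with its ONNX extra; the file_name kwarg mirrors the tests below and only suppresses a Hugging Face warning:

from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.utils import ComponentDevice

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    device=ComponentDevice.from_str("cpu"),
    model_kwargs={"file_name": "onnx/model.onnx"},
    backend="onnx",
)
embedder.warm_up()  # loads the model through the ONNX backend
result = embedder.run(documents=[Document(content="Haystack is an LLM framework.")])
print(len(result["documents"][0].embedding))  # embedding dimension, e.g. 384 for this model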

haystack/components/embedders/sentence_transformers_text_embedder.py

@@ -51,6 +51,7 @@ class SentenceTransformersTextEmbedder:
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
"""
Create a SentenceTransformersTextEmbedder component.
@@ -99,6 +100,10 @@ class SentenceTransformersTextEmbedder:
Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
This parameter is provided for fine customization. Be careful not to clash with already set parameters and
avoid passing parameters that change the output type.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
"""
self.model = model
@@ -117,6 +122,7 @@ class SentenceTransformersTextEmbedder:
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision
self.backend = backend
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -148,6 +154,7 @@ class SentenceTransformersTextEmbedder:
config_kwargs=self.config_kwargs,
precision=self.precision,
encode_kwargs=self.encode_kwargs,
backend=self.backend,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -185,6 +192,7 @@ class SentenceTransformersTextEmbedder:
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
)
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
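
The text embedder follows the same pattern; a sketch for the OpenVINO backend, assuming the optional OpenVINO dependency of sentence-transformers is installed:

from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice

text_embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    device=ComponentDevice.from_str("cpu"),
    model_kwargs={"file_name": "openvino/openvino_model.xml"},  # silences a HF warning
    backend="openvino",
)
text_embedder.warm_up()
embedding = text_embedder.run(text="What is the capital of France?")["embedding"]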

haystack/components/rankers/sentence_transformers_diversity.py

@@ -3,11 +3,12 @@
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
logger = logging.getLogger(__name__)
@@ -111,7 +112,7 @@ class SentenceTransformersDiversityRanker:
```
""" # noqa: E501
def __init__(
def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self,
model: str = "sentence-transformers/all-MiniLM-L6-v2",
top_k: int = 10,
@@ -126,7 +127,11 @@ class SentenceTransformersDiversityRanker:
embedding_separator: str = "\n",
strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
lambda_threshold: float = 0.5,
): # pylint: disable=too-many-positional-arguments
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
"""
Initialize a SentenceTransformersDiversityRanker.
@@ -152,6 +157,18 @@ class SentenceTransformersDiversityRanker:
"maximum_margin_relevance".
:param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is
"maximum_margin_relevance".
:param model_kwargs:
Additional keyword arguments for `AutoModel.from_pretrained` when loading the model.
Refer to specific model documentation for available kwargs.
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.
:param config_kwargs:
Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
"""
torch_and_sentence_transformers_import.check()
@@ -172,6 +189,10 @@ class SentenceTransformersDiversityRanker:
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
self.lambda_threshold = lambda_threshold or 0.5
self._check_lambda_threshold(self.lambda_threshold, self.strategy)
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.backend = backend
def warm_up(self):
"""
@@ -182,6 +203,10 @@ class SentenceTransformersDiversityRanker:
model_name_or_path=self.model_name_or_path,
device=self.device.to_torch_str(),
use_auth_token=self.token.resolve_value() if self.token else None,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
)
def to_dict(self) -> Dict[str, Any]:
@@ -191,7 +216,7 @@ class SentenceTransformersDiversityRanker:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
serialization_dict = default_to_dict(
self,
model=self.model_name_or_path,
top_k=self.top_k,
@@ -206,7 +231,14 @@ class SentenceTransformersDiversityRanker:
embedding_separator=self.embedding_separator,
strategy=str(self.strategy),
lambda_threshold=self.lambda_threshold,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
return serialization_dict
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
@@ -222,6 +254,8 @@ class SentenceTransformersDiversityRanker:
if init_params.get("device") is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
deserialize_secrets_inplace(init_params, keys=["token"])
if init_params.get("model_kwargs") is not None:
deserialize_hf_model_kwargs(init_params["model_kwargs"])
return default_from_dict(cls, data)
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
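
The diversity ranker gains the same loading knobs. A usage sketch, again assuming the ONNX extra is available; the file_name kwarg only avoids a HF warning:

from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker
from haystack.utils import ComponentDevice

ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    device=ComponentDevice.from_str("cpu"),
    model_kwargs={"file_name": "onnx/model.onnx"},
    backend="onnx",
)
ranker.warm_up()
docs = [Document(content="Solar power"), Document(content="Wind turbines"), Document(content="Solar panels")]
result = ranker.run(query="renewable energy", documents=docs)
print([d.content for d in result["documents"]])  # diversity-ordered documents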

releasenotes/notes/… (new release note)

@@ -0,0 +1,5 @@
---
features:
- |
Sentence Transformers components now support ONNX and OpenVINO backends through the `backend` parameter.
Supported backends are `torch` (default), `onnx`, and `openvino`. Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html) for more information.
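
At the library level, the components simply forward this value: SentenceTransformer itself accepts the same backend names (assuming sentence-transformers >= 3.2 and the matching optional extras). A sketch of the underlying call:

from sentence_transformers import SentenceTransformer

# "torch" is the default; "onnx" and "openvino" need their optional dependencies
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", backend="onnx")
embeddings = model.encode(["ONNX-accelerated inference"])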

test/components/embedders/test_sentence_transformers_document_embedder.py

@@ -82,6 +82,7 @@ class TestSentenceTransformersDocumentEmbedder:
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
"backend": "torch",
},
}
@@ -127,6 +128,7 @@ class TestSentenceTransformersDocumentEmbedder:
"config_kwargs": {"use_memory_efficient_attention": True},
"precision": "int8",
"encode_kwargs": {"task": "clustering"},
"backend": "torch",
},
}
@@ -252,6 +254,7 @@ class TestSentenceTransformersDocumentEmbedder:
model_kwargs=None,
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True},
backend="torch",
)
@patch(
@@ -357,3 +360,82 @@ class TestSentenceTransformersDocumentEmbedder:
normalize_embeddings=False,
precision="float32",
)
@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_model_onnx_backend(self, mocked_factory):
onnx_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={
"file_name": "onnx/model.onnx"
}, # setting the path isn't necessary if the repo contains an "onnx/model.onnx" file, but it avoids a HF warning
backend="onnx",
)
onnx_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs={"file_name": "onnx/model.onnx"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="onnx",
)
@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_model_openvino_backend(self, mocked_factory):
openvino_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={
"file_name": "openvino/openvino_model.xml"
}, # setting the path isn't necessary if the repo contains an "openvino/openvino_model.xml" file, but it avoids a HF warning
backend="openvino",
)
openvino_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs={"file_name": "openvino/openvino_model.xml"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="openvino",
)
@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
)
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
torch_dtype_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cuda:0"),
model_kwargs=model_kwargs,
)
torch_dtype_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cuda:0",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs=model_kwargs,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)
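
The parametrized test above also documents the torch dtype pass-through: torch_dtype strings travel through model_kwargs untouched, and Hugging Face from_pretrained accepts the dtype object form as well. A sketch (warm_up() here requires a CUDA-enabled build of torch):

import torch
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.utils import ComponentDevice

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    device=ComponentDevice.from_str("cuda:0"),
    model_kwargs={"torch_dtype": torch.bfloat16},  # equivalent to the string "bfloat16"
)
embedder.warm_up()  # loads the weights in bfloat16 on cuda:0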

test/components/embedders/test_sentence_transformers_embedding_backend.py

@@ -33,6 +33,7 @@ def test_model_initialization(mock_sentence_transformer):
auth_token=Secret.from_token("fake-api-token"),
trust_remote_code=True,
truncate_dim=256,
backend="torch",
)
mock_sentence_transformer.assert_called_once_with(
model_name_or_path="model",
@@ -43,6 +44,7 @@ def test_model_initialization(mock_sentence_transformer):
model_kwargs=None,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)

test/components/embedders/test_sentence_transformers_text_embedder.py

@@ -73,6 +73,7 @@ class TestSentenceTransformersTextEmbedder:
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
"backend": "torch",
},
}
@@ -113,6 +114,7 @@ class TestSentenceTransformersTextEmbedder:
"config_kwargs": {"use_memory_efficient_attention": False},
"precision": "int8",
"encode_kwargs": {"task": "clustering"},
"backend": "torch",
},
}
@@ -227,6 +229,7 @@ class TestSentenceTransformersTextEmbedder:
model_kwargs=None,
tokenizer_kwargs={"model_max_length": 512},
config_kwargs=None,
backend="torch",
)
@patch(
@@ -314,3 +317,82 @@ class TestSentenceTransformersTextEmbedder:
precision="float32",
task="retrieval.query",
)
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_model_onnx_backend(self, mocked_factory):
onnx_embedder = SentenceTransformersTextEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={
"file_name": "onnx/model.onnx"
}, # setting the path isn't necessary if the repo contains an "onnx/model.onnx" file, but it avoids a HF warning
backend="onnx",
)
onnx_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs={"file_name": "onnx/model.onnx"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="onnx",
)
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_model_openvino_backend(self, mocked_factory):
openvino_embedder = SentenceTransformersTextEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={
"file_name": "openvino/openvino_model.xml"
}, # setting the path isn't necessary if the repo contains an "openvino/openvino_model.xml" file, but it avoids a HF warning
backend="openvino",
)
openvino_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs={"file_name": "openvino/openvino_model.xml"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="openvino",
)
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
torch_dtype_embedder = SentenceTransformersTextEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cuda:0"),
model_kwargs=model_kwargs,
)
torch_dtype_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cuda:0",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs=model_kwargs,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)

test/components/rankers/test_sentence_transformers_diversity.py

@@ -291,6 +291,10 @@ class TestSentenceTransformersDiversityRanker:
model_name_or_path="mock_model_name",
device=ComponentDevice.resolve_device(None).to_torch_str(),
use_auth_token=None,
model_kwargs=None,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)
assert ranker.model == mock_model_instance
@@ -721,3 +725,66 @@ class TestSentenceTransformersDiversityRanker:
"Wind turbine technology",
]
assert [doc.content for doc in results["documents"]] == expected
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
def test_model_onnx_backend(self, mocked_sentence_transformer):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={"file_name": "onnx/model.onnx"},
backend="onnx",
)
ranker.warm_up()
mocked_sentence_transformer.assert_called_once_with(
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
use_auth_token=None,
model_kwargs={"file_name": "onnx/model.onnx"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="onnx",
)
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
def test_model_openvino_backend(self, mocked_sentence_transformer):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={"file_name": "openvino/openvino_model.xml"},
backend="openvino",
)
ranker.warm_up()
mocked_sentence_transformer.assert_called_once_with(
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
use_auth_token=None,
model_kwargs={"file_name": "openvino/openvino_model.xml"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="openvino",
)
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "float16"}, {"torch_dtype": "bfloat16"}])
def test_dtype_on_gpu(self, mocked_sentence_transformer, model_kwargs):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cuda:0"),
model_kwargs=model_kwargs,
)
ranker.warm_up()
mocked_sentence_transformer.assert_called_once_with(
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cuda:0",
use_auth_token=None,
model_kwargs=model_kwargs,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)
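
Because torch dtypes are not JSON-serializable, the ranker's new to_dict()/from_dict() paths route model_kwargs through serialize_hf_model_kwargs and deserialize_hf_model_kwargs, as the diff above shows. A round-trip sketch, assuming torch and sentence-transformers are installed:

import torch
from haystack.components.rankers import SentenceTransformersDiversityRanker

ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"torch_dtype": torch.float16},
)
data = ranker.to_dict()
# the dtype is stored as a plain string in the serialized dict
assert data["init_parameters"]["model_kwargs"]["torch_dtype"] == "torch.float16"
restored = SentenceTransformersDiversityRanker.from_dict(data)
assert restored.model_kwargs["torch_dtype"] == torch.float16  # dtype object restored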