feat: Add ONNX & OpenVINO backend support, and torch dtype kwargs in Sentence Transformers Components (#8813)

* initial rough draft

* expose backend instead of extracting from model_kwargs

* explicitly set backend model path

* add reno

* expose backend for ST diversity backend

* add dtype tests and expose kwargs to ST ranker for backend parameters

* skip dtype tests as torch isn't compiled with CUDA

* add new openvino dependency release, unskip tests

* resolve suggestion

* mock calls, turn integrations into unit tests

* remove unnecessary test dependencies
This commit is contained in:
Ulises M 2025-02-13 03:04:14 -08:00 committed by GitHub
parent 71416c81bc
commit bfdad40a80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 299 additions and 6 deletions

View File

@ -2,7 +2,7 @@
# #
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Literal, Optional
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret from haystack.utils.auth import Secret
@ -28,8 +28,9 @@ class _SentenceTransformersEmbeddingBackendFactory:
model_kwargs: Optional[Dict[str, Any]] = None, model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
): ):
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}" embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}{backend}"
if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances: if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
@ -42,6 +43,7 @@ class _SentenceTransformersEmbeddingBackendFactory:
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs, tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs, config_kwargs=config_kwargs,
backend=backend,
) )
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend return embedding_backend
@ -62,8 +64,10 @@ class _SentenceTransformersEmbeddingBackend:
model_kwargs: Optional[Dict[str, Any]] = None, model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
): ):
sentence_transformers_import.check() sentence_transformers_import.check()
self.model = SentenceTransformer( self.model = SentenceTransformer(
model_name_or_path=model, model_name_or_path=model,
device=device, device=device,
@ -73,6 +77,7 @@ class _SentenceTransformersEmbeddingBackend:
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs, tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs, config_kwargs=config_kwargs,
backend=backend,
) )
def embed(self, data: List[str], **kwargs) -> List[List[float]]: def embed(self, data: List[str], **kwargs) -> List[List[float]]:

View File

@ -57,6 +57,7 @@ class SentenceTransformersDocumentEmbedder:
config_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None, encode_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
): ):
""" """
Creates a SentenceTransformersDocumentEmbedder component. Creates a SentenceTransformersDocumentEmbedder component.
@ -109,6 +110,10 @@ class SentenceTransformersDocumentEmbedder:
Additional keyword arguments for `SentenceTransformer.encode` when embedding documents. Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
This parameter is provided for fine customization. Be careful not to clash with already set parameters and This parameter is provided for fine customization. Be careful not to clash with already set parameters and
avoid passing parameters that change the output type. avoid passing parameters that change the output type.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
""" """
self.model = model self.model = model
@ -129,6 +134,7 @@ class SentenceTransformersDocumentEmbedder:
self.encode_kwargs = encode_kwargs self.encode_kwargs = encode_kwargs
self.embedding_backend = None self.embedding_backend = None
self.precision = precision self.precision = precision
self.backend = backend
def _get_telemetry_data(self) -> Dict[str, Any]: def _get_telemetry_data(self) -> Dict[str, Any]:
""" """
@ -162,6 +168,7 @@ class SentenceTransformersDocumentEmbedder:
config_kwargs=self.config_kwargs, config_kwargs=self.config_kwargs,
precision=self.precision, precision=self.precision,
encode_kwargs=self.encode_kwargs, encode_kwargs=self.encode_kwargs,
backend=self.backend,
) )
if serialization_dict["init_parameters"].get("model_kwargs") is not None: if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"]) serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@ -199,6 +206,7 @@ class SentenceTransformersDocumentEmbedder:
model_kwargs=self.model_kwargs, model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs, tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs, config_kwargs=self.config_kwargs,
backend=self.backend,
) )
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"): if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"] self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

View File

@ -51,6 +51,7 @@ class SentenceTransformersTextEmbedder:
config_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None, encode_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
): ):
""" """
Create a SentenceTransformersTextEmbedder component. Create a SentenceTransformersTextEmbedder component.
@ -99,6 +100,10 @@ class SentenceTransformersTextEmbedder:
Additional keyword arguments for `SentenceTransformer.encode` when embedding texts. Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
This parameter is provided for fine customization. Be careful not to clash with already set parameters and This parameter is provided for fine customization. Be careful not to clash with already set parameters and
avoid passing parameters that change the output type. avoid passing parameters that change the output type.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
""" """
self.model = model self.model = model
@ -117,6 +122,7 @@ class SentenceTransformersTextEmbedder:
self.encode_kwargs = encode_kwargs self.encode_kwargs = encode_kwargs
self.embedding_backend = None self.embedding_backend = None
self.precision = precision self.precision = precision
self.backend = backend
def _get_telemetry_data(self) -> Dict[str, Any]: def _get_telemetry_data(self) -> Dict[str, Any]:
""" """
@ -148,6 +154,7 @@ class SentenceTransformersTextEmbedder:
config_kwargs=self.config_kwargs, config_kwargs=self.config_kwargs,
precision=self.precision, precision=self.precision,
encode_kwargs=self.encode_kwargs, encode_kwargs=self.encode_kwargs,
backend=self.backend,
) )
if serialization_dict["init_parameters"].get("model_kwargs") is not None: if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"]) serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@ -185,6 +192,7 @@ class SentenceTransformersTextEmbedder:
model_kwargs=self.model_kwargs, model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs, tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs, config_kwargs=self.config_kwargs,
backend=self.backend,
) )
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"): if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"] self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

View File

@ -3,11 +3,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -111,7 +112,7 @@ class SentenceTransformersDiversityRanker:
``` ```
""" # noqa: E501 """ # noqa: E501
def __init__( def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self, self,
model: str = "sentence-transformers/all-MiniLM-L6-v2", model: str = "sentence-transformers/all-MiniLM-L6-v2",
top_k: int = 10, top_k: int = 10,
@ -126,7 +127,11 @@ class SentenceTransformersDiversityRanker:
embedding_separator: str = "\n", embedding_separator: str = "\n",
strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order", strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
lambda_threshold: float = 0.5, lambda_threshold: float = 0.5,
): # pylint: disable=too-many-positional-arguments model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
""" """
Initialize a SentenceTransformersDiversityRanker. Initialize a SentenceTransformersDiversityRanker.
@ -152,6 +157,18 @@ class SentenceTransformersDiversityRanker:
"maximum_margin_relevance". "maximum_margin_relevance".
:param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is :param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is
"maximum_margin_relevance". "maximum_margin_relevance".
:param model_kwargs:
Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
when loading the model. Refer to specific model documentation for available kwargs.
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.
:param config_kwargs:
Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
""" """
torch_and_sentence_transformers_import.check() torch_and_sentence_transformers_import.check()
@ -172,6 +189,10 @@ class SentenceTransformersDiversityRanker:
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
self.lambda_threshold = lambda_threshold or 0.5 self.lambda_threshold = lambda_threshold or 0.5
self._check_lambda_threshold(self.lambda_threshold, self.strategy) self._check_lambda_threshold(self.lambda_threshold, self.strategy)
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.backend = backend
def warm_up(self): def warm_up(self):
""" """
@ -182,6 +203,10 @@ class SentenceTransformersDiversityRanker:
model_name_or_path=self.model_name_or_path, model_name_or_path=self.model_name_or_path,
device=self.device.to_torch_str(), device=self.device.to_torch_str(),
use_auth_token=self.token.resolve_value() if self.token else None, use_auth_token=self.token.resolve_value() if self.token else None,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
) )
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
@ -191,7 +216,7 @@ class SentenceTransformersDiversityRanker:
:returns: :returns:
Dictionary with serialized data. Dictionary with serialized data.
""" """
return default_to_dict( serialization_dict = default_to_dict(
self, self,
model=self.model_name_or_path, model=self.model_name_or_path,
top_k=self.top_k, top_k=self.top_k,
@ -206,7 +231,14 @@ class SentenceTransformersDiversityRanker:
embedding_separator=self.embedding_separator, embedding_separator=self.embedding_separator,
strategy=str(self.strategy), strategy=str(self.strategy),
lambda_threshold=self.lambda_threshold, lambda_threshold=self.lambda_threshold,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
) )
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
return serialization_dict
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker": def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
@ -222,6 +254,8 @@ class SentenceTransformersDiversityRanker:
if init_params.get("device") is not None: if init_params.get("device") is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"]) init_params["device"] = ComponentDevice.from_dict(init_params["device"])
deserialize_secrets_inplace(init_params, keys=["token"]) deserialize_secrets_inplace(init_params, keys=["token"])
if init_params.get("model_kwargs") is not None:
deserialize_hf_model_kwargs(init_params["model_kwargs"])
return default_from_dict(cls, data) return default_from_dict(cls, data)
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:

View File

@ -0,0 +1,5 @@
---
features:
- |
Sentence Transformers components now support ONNX and OpenVINO backends through the "backend" parameter.
Supported backends are "torch" (default), "onnx", and "openvino". Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html) for more information.

View File

@ -82,6 +82,7 @@ class TestSentenceTransformersDocumentEmbedder:
"encode_kwargs": None, "encode_kwargs": None,
"config_kwargs": None, "config_kwargs": None,
"precision": "float32", "precision": "float32",
"backend": "torch",
}, },
} }
@ -127,6 +128,7 @@ class TestSentenceTransformersDocumentEmbedder:
"config_kwargs": {"use_memory_efficient_attention": True}, "config_kwargs": {"use_memory_efficient_attention": True},
"precision": "int8", "precision": "int8",
"encode_kwargs": {"task": "clustering"}, "encode_kwargs": {"task": "clustering"},
"backend": "torch",
}, },
} }
@ -252,6 +254,7 @@ class TestSentenceTransformersDocumentEmbedder:
model_kwargs=None, model_kwargs=None,
tokenizer_kwargs={"model_max_length": 512}, tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True}, config_kwargs={"use_memory_efficient_attention": True},
backend="torch",
) )
@patch( @patch(
@ -357,3 +360,82 @@ class TestSentenceTransformersDocumentEmbedder:
normalize_embeddings=False, normalize_embeddings=False,
precision="float32", precision="float32",
) )
@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_model_onnx_backend(self, mocked_factory):
onnx_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={
"file_name": "onnx/model.onnx"
}, # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent a HF warning
backend="onnx",
)
onnx_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs={"file_name": "onnx/model.onnx"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="onnx",
)
@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_model_openvino_backend(self, mocked_factory):
openvino_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={
"file_name": "openvino/openvino_model.xml"
}, # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is to prevent a HF warning
backend="openvino",
)
openvino_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs={"file_name": "openvino/openvino_model.xml"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="openvino",
)
@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
)
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
torch_dtype_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cuda:0"),
model_kwargs=model_kwargs,
)
torch_dtype_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cuda:0",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs=model_kwargs,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)

View File

@ -33,6 +33,7 @@ def test_model_initialization(mock_sentence_transformer):
auth_token=Secret.from_token("fake-api-token"), auth_token=Secret.from_token("fake-api-token"),
trust_remote_code=True, trust_remote_code=True,
truncate_dim=256, truncate_dim=256,
backend="torch",
) )
mock_sentence_transformer.assert_called_once_with( mock_sentence_transformer.assert_called_once_with(
model_name_or_path="model", model_name_or_path="model",
@ -43,6 +44,7 @@ def test_model_initialization(mock_sentence_transformer):
model_kwargs=None, model_kwargs=None,
tokenizer_kwargs=None, tokenizer_kwargs=None,
config_kwargs=None, config_kwargs=None,
backend="torch",
) )

View File

@ -73,6 +73,7 @@ class TestSentenceTransformersTextEmbedder:
"encode_kwargs": None, "encode_kwargs": None,
"config_kwargs": None, "config_kwargs": None,
"precision": "float32", "precision": "float32",
"backend": "torch",
}, },
} }
@ -113,6 +114,7 @@ class TestSentenceTransformersTextEmbedder:
"config_kwargs": {"use_memory_efficient_attention": False}, "config_kwargs": {"use_memory_efficient_attention": False},
"precision": "int8", "precision": "int8",
"encode_kwargs": {"task": "clustering"}, "encode_kwargs": {"task": "clustering"},
"backend": "torch",
}, },
} }
@ -227,6 +229,7 @@ class TestSentenceTransformersTextEmbedder:
model_kwargs=None, model_kwargs=None,
tokenizer_kwargs={"model_max_length": 512}, tokenizer_kwargs={"model_max_length": 512},
config_kwargs=None, config_kwargs=None,
backend="torch",
) )
@patch( @patch(
@ -314,3 +317,82 @@ class TestSentenceTransformersTextEmbedder:
precision="float32", precision="float32",
task="retrieval.query", task="retrieval.query",
) )
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_model_onnx_backend(self, mocked_factory):
onnx_embedder = SentenceTransformersTextEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={
"file_name": "onnx/model.onnx"
}, # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent a HF warning
backend="onnx",
)
onnx_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs={"file_name": "onnx/model.onnx"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="onnx",
)
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_model_openvino_backend(self, mocked_factory):
openvino_embedder = SentenceTransformersTextEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={
"file_name": "openvino/openvino_model.xml"
}, # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is to prevent a HF warning
backend="openvino",
)
openvino_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs={"file_name": "openvino/openvino_model.xml"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="openvino",
)
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
torch_dtype_embedder = SentenceTransformersTextEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cuda:0"),
model_kwargs=model_kwargs,
)
torch_dtype_embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="sentence-transformers/all-MiniLM-L6-v2",
device="cuda:0",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs=model_kwargs,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)

View File

@ -291,6 +291,10 @@ class TestSentenceTransformersDiversityRanker:
model_name_or_path="mock_model_name", model_name_or_path="mock_model_name",
device=ComponentDevice.resolve_device(None).to_torch_str(), device=ComponentDevice.resolve_device(None).to_torch_str(),
use_auth_token=None, use_auth_token=None,
model_kwargs=None,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
) )
assert ranker.model == mock_model_instance assert ranker.model == mock_model_instance
@ -721,3 +725,66 @@ class TestSentenceTransformersDiversityRanker:
"Wind turbine technology", "Wind turbine technology",
] ]
assert [doc.content for doc in results["documents"]] == expected assert [doc.content for doc in results["documents"]] == expected
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
def test_model_onnx_backend(self, mocked_sentence_transformer):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={"file_name": "onnx/model.onnx"},
backend="onnx",
)
ranker.warm_up()
mocked_sentence_transformer.assert_called_once_with(
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
use_auth_token=None,
model_kwargs={"file_name": "onnx/model.onnx"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="onnx",
)
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
def test_model_openvino_backend(self, mocked_sentence_transformer):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
model_kwargs={"file_name": "openvino/openvino_model.xml"},
backend="openvino",
)
ranker.warm_up()
mocked_sentence_transformer.assert_called_once_with(
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
use_auth_token=None,
model_kwargs={"file_name": "openvino/openvino_model.xml"},
tokenizer_kwargs=None,
config_kwargs=None,
backend="openvino",
)
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "float16"}, {"torch_dtype": "bfloat16"}])
def test_dtype_on_gpu(self, mocked_sentence_transformer, model_kwargs):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cuda:0"),
model_kwargs=model_kwargs,
)
ranker.warm_up()
mocked_sentence_transformer.assert_called_once_with(
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cuda:0",
use_auth_token=None,
model_kwargs=model_kwargs,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)