mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-15 01:23:59 +00:00
feat: Add ONNX & OpenVINO backend support, and torch dtype kwargs in Sentence Transformers Components (#8813)
* initial rough draft * expose backend instead of extracting from model_kwargs * explicitly set backend model path * add reno * expose backend for ST diversity backend * add dtype tests and expose kwargs to ST ranker for backend parameters * skip dtype tests as torch isn't compiled with cuda * add new openvino dependency release, unskip tests * resolve suggestion * mock calls, turn integrations into unit tests * remove unnecessary test dependencies
This commit is contained in:
parent
71416c81bc
commit
bfdad40a80
@ -2,7 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils.auth import Secret
|
||||
@ -28,8 +28,9 @@ class _SentenceTransformersEmbeddingBackendFactory:
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
config_kwargs: Optional[Dict[str, Any]] = None,
|
||||
backend: Literal["torch", "onnx", "openvino"] = "torch",
|
||||
):
|
||||
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"
|
||||
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}{backend}"
|
||||
|
||||
if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
|
||||
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
|
||||
@ -42,6 +43,7 @@ class _SentenceTransformersEmbeddingBackendFactory:
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
config_kwargs=config_kwargs,
|
||||
backend=backend,
|
||||
)
|
||||
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
|
||||
return embedding_backend
|
||||
@ -62,8 +64,10 @@ class _SentenceTransformersEmbeddingBackend:
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
config_kwargs: Optional[Dict[str, Any]] = None,
|
||||
backend: Literal["torch", "onnx", "openvino"] = "torch",
|
||||
):
|
||||
sentence_transformers_import.check()
|
||||
|
||||
self.model = SentenceTransformer(
|
||||
model_name_or_path=model,
|
||||
device=device,
|
||||
@ -73,6 +77,7 @@ class _SentenceTransformersEmbeddingBackend:
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
config_kwargs=config_kwargs,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
def embed(self, data: List[str], **kwargs) -> List[List[float]]:
|
||||
|
||||
@ -57,6 +57,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
config_kwargs: Optional[Dict[str, Any]] = None,
|
||||
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
|
||||
encode_kwargs: Optional[Dict[str, Any]] = None,
|
||||
backend: Literal["torch", "onnx", "openvino"] = "torch",
|
||||
):
|
||||
"""
|
||||
Creates a SentenceTransformersDocumentEmbedder component.
|
||||
@ -109,6 +110,10 @@ class SentenceTransformersDocumentEmbedder:
|
||||
Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
|
||||
This parameter is provided for fine customization. Be careful not to clash with already set parameters and
|
||||
avoid passing parameters that change the output type.
|
||||
:param backend:
|
||||
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
|
||||
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
|
||||
for more information on acceleration and quantization options.
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
@ -129,6 +134,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
self.encode_kwargs = encode_kwargs
|
||||
self.embedding_backend = None
|
||||
self.precision = precision
|
||||
self.backend = backend
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -162,6 +168,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
config_kwargs=self.config_kwargs,
|
||||
precision=self.precision,
|
||||
encode_kwargs=self.encode_kwargs,
|
||||
backend=self.backend,
|
||||
)
|
||||
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
|
||||
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
|
||||
@ -199,6 +206,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
config_kwargs=self.config_kwargs,
|
||||
backend=self.backend,
|
||||
)
|
||||
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
|
||||
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
|
||||
|
||||
@ -51,6 +51,7 @@ class SentenceTransformersTextEmbedder:
|
||||
config_kwargs: Optional[Dict[str, Any]] = None,
|
||||
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
|
||||
encode_kwargs: Optional[Dict[str, Any]] = None,
|
||||
backend: Literal["torch", "onnx", "openvino"] = "torch",
|
||||
):
|
||||
"""
|
||||
Create a SentenceTransformersTextEmbedder component.
|
||||
@ -99,6 +100,10 @@ class SentenceTransformersTextEmbedder:
|
||||
Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
|
||||
This parameter is provided for fine customization. Be careful not to clash with already set parameters and
|
||||
avoid passing parameters that change the output type.
|
||||
:param backend:
|
||||
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
|
||||
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
|
||||
for more information on acceleration and quantization options.
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
@ -117,6 +122,7 @@ class SentenceTransformersTextEmbedder:
|
||||
self.encode_kwargs = encode_kwargs
|
||||
self.embedding_backend = None
|
||||
self.precision = precision
|
||||
self.backend = backend
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -148,6 +154,7 @@ class SentenceTransformersTextEmbedder:
|
||||
config_kwargs=self.config_kwargs,
|
||||
precision=self.precision,
|
||||
encode_kwargs=self.encode_kwargs,
|
||||
backend=self.backend,
|
||||
)
|
||||
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
|
||||
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
|
||||
@ -185,6 +192,7 @@ class SentenceTransformersTextEmbedder:
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
config_kwargs=self.config_kwargs,
|
||||
backend=self.backend,
|
||||
)
|
||||
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
|
||||
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
|
||||
|
||||
@ -3,11 +3,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -111,7 +112,7 @@ class SentenceTransformersDiversityRanker:
|
||||
```
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
|
||||
self,
|
||||
model: str = "sentence-transformers/all-MiniLM-L6-v2",
|
||||
top_k: int = 10,
|
||||
@ -126,7 +127,11 @@ class SentenceTransformersDiversityRanker:
|
||||
embedding_separator: str = "\n",
|
||||
strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
|
||||
lambda_threshold: float = 0.5,
|
||||
): # pylint: disable=too-many-positional-arguments
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
config_kwargs: Optional[Dict[str, Any]] = None,
|
||||
backend: Literal["torch", "onnx", "openvino"] = "torch",
|
||||
):
|
||||
"""
|
||||
Initialize a SentenceTransformersDiversityRanker.
|
||||
|
||||
@ -152,6 +157,18 @@ class SentenceTransformersDiversityRanker:
|
||||
"maximum_margin_relevance".
|
||||
:param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is
|
||||
"maximum_margin_relevance".
|
||||
:param model_kwargs:
|
||||
Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
|
||||
when loading the model. Refer to specific model documentation for available kwargs.
|
||||
:param tokenizer_kwargs:
|
||||
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
|
||||
Refer to specific model documentation for available kwargs.
|
||||
:param config_kwargs:
|
||||
Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
|
||||
:param backend:
|
||||
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
|
||||
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
|
||||
for more information on acceleration and quantization options.
|
||||
"""
|
||||
torch_and_sentence_transformers_import.check()
|
||||
|
||||
@ -172,6 +189,10 @@ class SentenceTransformersDiversityRanker:
|
||||
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
|
||||
self.lambda_threshold = lambda_threshold or 0.5
|
||||
self._check_lambda_threshold(self.lambda_threshold, self.strategy)
|
||||
self.model_kwargs = model_kwargs
|
||||
self.tokenizer_kwargs = tokenizer_kwargs
|
||||
self.config_kwargs = config_kwargs
|
||||
self.backend = backend
|
||||
|
||||
def warm_up(self):
|
||||
"""
|
||||
@ -182,6 +203,10 @@ class SentenceTransformersDiversityRanker:
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
device=self.device.to_torch_str(),
|
||||
use_auth_token=self.token.resolve_value() if self.token else None,
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
config_kwargs=self.config_kwargs,
|
||||
backend=self.backend,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
@ -191,7 +216,7 @@ class SentenceTransformersDiversityRanker:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
return default_to_dict(
|
||||
serialization_dict = default_to_dict(
|
||||
self,
|
||||
model=self.model_name_or_path,
|
||||
top_k=self.top_k,
|
||||
@ -206,7 +231,14 @@ class SentenceTransformersDiversityRanker:
|
||||
embedding_separator=self.embedding_separator,
|
||||
strategy=str(self.strategy),
|
||||
lambda_threshold=self.lambda_threshold,
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
config_kwargs=self.config_kwargs,
|
||||
backend=self.backend,
|
||||
)
|
||||
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
|
||||
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
|
||||
return serialization_dict
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
|
||||
@ -222,6 +254,8 @@ class SentenceTransformersDiversityRanker:
|
||||
if init_params.get("device") is not None:
|
||||
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||
deserialize_secrets_inplace(init_params, keys=["token"])
|
||||
if init_params.get("model_kwargs") is not None:
|
||||
deserialize_hf_model_kwargs(init_params["model_kwargs"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Sentence Transformers components now support ONNX and OpenVINO backends through the "backend" parameter.
|
||||
Supported backends are torch (default), onnx, and openvino. Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html) for more information.
|
||||
@ -82,6 +82,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"encode_kwargs": None,
|
||||
"config_kwargs": None,
|
||||
"precision": "float32",
|
||||
"backend": "torch",
|
||||
},
|
||||
}
|
||||
|
||||
@ -127,6 +128,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"config_kwargs": {"use_memory_efficient_attention": True},
|
||||
"precision": "int8",
|
||||
"encode_kwargs": {"task": "clustering"},
|
||||
"backend": "torch",
|
||||
},
|
||||
}
|
||||
|
||||
@ -252,6 +254,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
model_kwargs=None,
|
||||
tokenizer_kwargs={"model_max_length": 512},
|
||||
config_kwargs={"use_memory_efficient_attention": True},
|
||||
backend="torch",
|
||||
)
|
||||
|
||||
@patch(
|
||||
@ -357,3 +360,82 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
normalize_embeddings=False,
|
||||
precision="float32",
|
||||
)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
)
|
||||
def test_model_onnx_backend(self, mocked_factory):
|
||||
onnx_embedder = SentenceTransformersDocumentEmbedder(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
model_kwargs={
|
||||
"file_name": "onnx/model.onnx"
|
||||
}, # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent a HF warning
|
||||
backend="onnx",
|
||||
)
|
||||
onnx_embedder.warm_up()
|
||||
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
auth_token=None,
|
||||
trust_remote_code=False,
|
||||
truncate_dim=None,
|
||||
model_kwargs={"file_name": "onnx/model.onnx"},
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="onnx",
|
||||
)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
)
|
||||
def test_model_openvino_backend(self, mocked_factory):
|
||||
openvino_embedder = SentenceTransformersDocumentEmbedder(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
model_kwargs={
|
||||
"file_name": "openvino/openvino_model.xml"
|
||||
}, # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is to prevent a HF warning
|
||||
backend="openvino",
|
||||
)
|
||||
openvino_embedder.warm_up()
|
||||
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
auth_token=None,
|
||||
trust_remote_code=False,
|
||||
truncate_dim=None,
|
||||
model_kwargs={"file_name": "openvino/openvino_model.xml"},
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="openvino",
|
||||
)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
)
|
||||
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
|
||||
def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
|
||||
torch_dtype_embedder = SentenceTransformersDocumentEmbedder(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
torch_dtype_embedder.warm_up()
|
||||
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cuda:0",
|
||||
auth_token=None,
|
||||
trust_remote_code=False,
|
||||
truncate_dim=None,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="torch",
|
||||
)
|
||||
|
||||
@ -33,6 +33,7 @@ def test_model_initialization(mock_sentence_transformer):
|
||||
auth_token=Secret.from_token("fake-api-token"),
|
||||
trust_remote_code=True,
|
||||
truncate_dim=256,
|
||||
backend="torch",
|
||||
)
|
||||
mock_sentence_transformer.assert_called_once_with(
|
||||
model_name_or_path="model",
|
||||
@ -43,6 +44,7 @@ def test_model_initialization(mock_sentence_transformer):
|
||||
model_kwargs=None,
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="torch",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -73,6 +73,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
"encode_kwargs": None,
|
||||
"config_kwargs": None,
|
||||
"precision": "float32",
|
||||
"backend": "torch",
|
||||
},
|
||||
}
|
||||
|
||||
@ -113,6 +114,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
"config_kwargs": {"use_memory_efficient_attention": False},
|
||||
"precision": "int8",
|
||||
"encode_kwargs": {"task": "clustering"},
|
||||
"backend": "torch",
|
||||
},
|
||||
}
|
||||
|
||||
@ -227,6 +229,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
model_kwargs=None,
|
||||
tokenizer_kwargs={"model_max_length": 512},
|
||||
config_kwargs=None,
|
||||
backend="torch",
|
||||
)
|
||||
|
||||
@patch(
|
||||
@ -314,3 +317,82 @@ class TestSentenceTransformersTextEmbedder:
|
||||
precision="float32",
|
||||
task="retrieval.query",
|
||||
)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
)
|
||||
def test_model_onnx_backend(self, mocked_factory):
|
||||
onnx_embedder = SentenceTransformersTextEmbedder(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
model_kwargs={
|
||||
"file_name": "onnx/model.onnx"
|
||||
}, # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent a HF warning
|
||||
backend="onnx",
|
||||
)
|
||||
onnx_embedder.warm_up()
|
||||
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
auth_token=None,
|
||||
trust_remote_code=False,
|
||||
truncate_dim=None,
|
||||
model_kwargs={"file_name": "onnx/model.onnx"},
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="onnx",
|
||||
)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
)
|
||||
def test_model_openvino_backend(self, mocked_factory):
|
||||
openvino_embedder = SentenceTransformersTextEmbedder(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
model_kwargs={
|
||||
"file_name": "openvino/openvino_model.xml"
|
||||
}, # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is to prevent a HF warning
|
||||
backend="openvino",
|
||||
)
|
||||
openvino_embedder.warm_up()
|
||||
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
auth_token=None,
|
||||
trust_remote_code=False,
|
||||
truncate_dim=None,
|
||||
model_kwargs={"file_name": "openvino/openvino_model.xml"},
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="openvino",
|
||||
)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
)
|
||||
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
|
||||
def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
|
||||
torch_dtype_embedder = SentenceTransformersTextEmbedder(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
torch_dtype_embedder.warm_up()
|
||||
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cuda:0",
|
||||
auth_token=None,
|
||||
trust_remote_code=False,
|
||||
truncate_dim=None,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="torch",
|
||||
)
|
||||
|
||||
@ -291,6 +291,10 @@ class TestSentenceTransformersDiversityRanker:
|
||||
model_name_or_path="mock_model_name",
|
||||
device=ComponentDevice.resolve_device(None).to_torch_str(),
|
||||
use_auth_token=None,
|
||||
model_kwargs=None,
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="torch",
|
||||
)
|
||||
assert ranker.model == mock_model_instance
|
||||
|
||||
@ -721,3 +725,66 @@ class TestSentenceTransformersDiversityRanker:
|
||||
"Wind turbine technology",
|
||||
]
|
||||
assert [doc.content for doc in results["documents"]] == expected
|
||||
|
||||
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
|
||||
def test_model_onnx_backend(self, mocked_sentence_transformer):
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
model_kwargs={"file_name": "onnx/model.onnx"},
|
||||
backend="onnx",
|
||||
)
|
||||
ranker.warm_up()
|
||||
|
||||
mocked_sentence_transformer.assert_called_once_with(
|
||||
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
use_auth_token=None,
|
||||
model_kwargs={"file_name": "onnx/model.onnx"},
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="onnx",
|
||||
)
|
||||
|
||||
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
|
||||
def test_model_openvino_backend(self, mocked_sentence_transformer):
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
model_kwargs={"file_name": "openvino/openvino_model.xml"},
|
||||
backend="openvino",
|
||||
)
|
||||
ranker.warm_up()
|
||||
|
||||
mocked_sentence_transformer.assert_called_once_with(
|
||||
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
use_auth_token=None,
|
||||
model_kwargs={"file_name": "openvino/openvino_model.xml"},
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="openvino",
|
||||
)
|
||||
|
||||
@patch("haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer")
|
||||
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "float16"}, {"torch_dtype": "bfloat16"}])
|
||||
def test_dtype_on_gpu(self, mocked_sentence_transformer, model_kwargs):
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
ranker.warm_up()
|
||||
|
||||
mocked_sentence_transformer.assert_called_once_with(
|
||||
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cuda:0",
|
||||
use_auth_token=None,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="torch",
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user