diff --git a/haystack/components/embedders/backends/sentence_transformers_backend.py b/haystack/components/embedders/backends/sentence_transformers_backend.py
index c0fdfac09..c66be99a9 100644
--- a/haystack/components/embedders/backends/sentence_transformers_backend.py
+++ b/haystack/components/embedders/backends/sentence_transformers_backend.py
@@ -15,12 +15,17 @@ class _SentenceTransformersEmbeddingBackendFactory:
     _instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}
 
     @staticmethod
-    def get_embedding_backend(model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
+    def get_embedding_backend(
+        model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None, trust_remote_code: bool = False
+    ):
-        embedding_backend_id = f"{model}{device}{auth_token}"
+        # Key the cache on trust_remote_code too, so backends loaded with and without it are never shared.
+        embedding_backend_id = f"{model}{device}{auth_token}{trust_remote_code}"
 
         if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
             return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
-        embedding_backend = _SentenceTransformersEmbeddingBackend(model=model, device=device, auth_token=auth_token)
+        embedding_backend = _SentenceTransformersEmbeddingBackend(
+            model=model, device=device, auth_token=auth_token, trust_remote_code=trust_remote_code
+        )
         _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
         return embedding_backend
 
@@ -30,10 +35,19 @@ class _SentenceTransformersEmbeddingBackend:
     Class to manage Sentence Transformers embeddings.
""" - def __init__(self, model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None): + def __init__( + self, + model: str, + device: Optional[str] = None, + auth_token: Optional[Secret] = None, + trust_remote_code: bool = False, + ): sentence_transformers_import.check() self.model = SentenceTransformer( - model_name_or_path=model, device=device, use_auth_token=auth_token.resolve_value() if auth_token else None + model_name_or_path=model, + device=device, + use_auth_token=auth_token.resolve_value() if auth_token else None, + trust_remote_code=trust_remote_code, ) def embed(self, data: List[str], **kwargs) -> List[List[float]]: diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index 9097d888d..afa425f96 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -39,6 +39,7 @@ class SentenceTransformersDocumentEmbedder: normalize_embeddings: bool = False, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + trust_remote_code: bool = False, ): """ Create a SentenceTransformersDocumentEmbedder component. @@ -65,6 +66,9 @@ class SentenceTransformersDocumentEmbedder: List of meta fields that will be embedded along with the Document text. :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + :param trust_remote_code: + If `False`, only Hugging Face verified model architectures are allowed. + If `True`, custom models and scripts are allowed. 
""" self.model = model @@ -77,6 +81,7 @@ class SentenceTransformersDocumentEmbedder: self.normalize_embeddings = normalize_embeddings self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator + self.trust_remote_code = trust_remote_code def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -103,6 +108,7 @@ class SentenceTransformersDocumentEmbedder: normalize_embeddings=self.normalize_embeddings, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + trust_remote_code=self.trust_remote_code, ) @classmethod @@ -127,7 +133,10 @@ class SentenceTransformersDocumentEmbedder: """ if not hasattr(self, "embedding_backend"): self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model=self.model, device=self.device.to_torch_str(), auth_token=self.token + model=self.model, + device=self.device.to_torch_str(), + auth_token=self.token, + trust_remote_code=self.trust_remote_code, ) @component.output_types(documents=List[Document]) diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index 0c1a3bc48..a9ea7f31e 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -37,6 +37,7 @@ class SentenceTransformersTextEmbedder: batch_size: int = 32, progress_bar: bool = True, normalize_embeddings: bool = False, + trust_remote_code: bool = False, ): """ Create a SentenceTransformersTextEmbedder component. @@ -59,6 +60,9 @@ class SentenceTransformersTextEmbedder: If True shows a progress bar when running. :param normalize_embeddings: If True returned vectors will have length 1. + :param trust_remote_code: + If `False`, only Hugging Face verified model architectures are allowed. + If `True`, custom models and scripts are allowed. 
""" self.model = model @@ -69,6 +73,7 @@ class SentenceTransformersTextEmbedder: self.batch_size = batch_size self.progress_bar = progress_bar self.normalize_embeddings = normalize_embeddings + self.trust_remote_code = trust_remote_code def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -93,6 +98,7 @@ class SentenceTransformersTextEmbedder: batch_size=self.batch_size, progress_bar=self.progress_bar, normalize_embeddings=self.normalize_embeddings, + trust_remote_code=self.trust_remote_code, ) @classmethod @@ -117,7 +123,10 @@ class SentenceTransformersTextEmbedder: """ if not hasattr(self, "embedding_backend"): self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model=self.model, device=self.device.to_torch_str(), auth_token=self.token + model=self.model, + device=self.device.to_torch_str(), + auth_token=self.token, + trust_remote_code=self.trust_remote_code, ) @component.output_types(embedding=List[float]) diff --git a/releasenotes/notes/add-trust-remote-code-feature-c133b2b245d2ea7a.yaml b/releasenotes/notes/add-trust-remote-code-feature-c133b2b245d2ea7a.yaml new file mode 100644 index 000000000..a4cf31426 --- /dev/null +++ b/releasenotes/notes/add-trust-remote-code-feature-c133b2b245d2ea7a.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add trust_remote_code parameter to SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder for allowing custom models and scripts. 
diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py index be617e227..789943da8 100644 --- a/test/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_document_embedder.py @@ -1,10 +1,11 @@ -from unittest.mock import patch, MagicMock -import pytest +from unittest.mock import MagicMock, patch + import numpy as np -from haystack.utils import Secret, ComponentDevice +import pytest from haystack import Document from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder +from haystack.utils import ComponentDevice, Secret class TestSentenceTransformersDocumentEmbedder: @@ -20,6 +21,7 @@ class TestSentenceTransformersDocumentEmbedder: assert embedder.normalize_embeddings is False assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" + assert embedder.trust_remote_code is False def test_init_with_parameters(self): embedder = SentenceTransformersDocumentEmbedder( @@ -33,6 +35,7 @@ class TestSentenceTransformersDocumentEmbedder: normalize_embeddings=True, meta_fields_to_embed=["test_field"], embedding_separator=" | ", + trust_remote_code=True, ) assert embedder.model == "model" assert embedder.device == ComponentDevice.from_str("cuda:0") @@ -44,6 +47,7 @@ class TestSentenceTransformersDocumentEmbedder: assert embedder.normalize_embeddings is True assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " + assert embedder.trust_remote_code def test_to_dict(self): component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu")) @@ -61,6 +65,7 @@ class TestSentenceTransformersDocumentEmbedder: "normalize_embeddings": False, "embedding_separator": "\n", "meta_fields_to_embed": [], + "trust_remote_code": False, }, } @@ -76,6 
+81,7 @@ class TestSentenceTransformersDocumentEmbedder: normalize_embeddings=True, meta_fields_to_embed=["meta_field"], embedding_separator=" - ", + trust_remote_code=True, ) data = component.to_dict() @@ -91,6 +97,7 @@ class TestSentenceTransformersDocumentEmbedder: "progress_bar": False, "normalize_embeddings": True, "embedding_separator": " - ", + "trust_remote_code": True, "meta_fields_to_embed": ["meta_field"], }, } @@ -107,6 +114,7 @@ class TestSentenceTransformersDocumentEmbedder: "normalize_embeddings": True, "embedding_separator": " - ", "meta_fields_to_embed": ["meta_field"], + "trust_remote_code": True, } component = SentenceTransformersDocumentEmbedder.from_dict( { @@ -123,6 +131,7 @@ class TestSentenceTransformersDocumentEmbedder: assert component.progress_bar is False assert component.normalize_embeddings is True assert component.embedding_separator == " - " + assert component.trust_remote_code assert component.meta_fields_to_embed == ["meta_field"] @patch( @@ -134,7 +143,9 @@ class TestSentenceTransformersDocumentEmbedder: ) mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() - mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None) + mocked_factory.get_embedding_backend.assert_called_once_with( + model="model", device="cpu", auth_token=None, trust_remote_code=False + ) @patch( "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory" diff --git a/test/components/embedders/test_sentence_transformers_embedding_backend.py b/test/components/embedders/test_sentence_transformers_embedding_backend.py index 726b1b23a..cbdc812ac 100644 --- a/test/components/embedders/test_sentence_transformers_embedding_backend.py +++ b/test/components/embedders/test_sentence_transformers_embedding_backend.py @@ -1,5 +1,7 @@ from unittest.mock import patch + import pytest + from 
haystack.components.embedders.backends.sentence_transformers_backend import ( _SentenceTransformersEmbeddingBackendFactory, ) @@ -23,10 +25,10 @@ def test_factory_behavior(mock_sentence_transformer): @patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer") def test_model_initialization(mock_sentence_transformer): _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model="model", device="cpu", auth_token=Secret.from_token("fake-api-token") + model="model", device="cpu", auth_token=Secret.from_token("fake-api-token"), trust_remote_code=True ) mock_sentence_transformer.assert_called_once_with( - model_name_or_path="model", device="cpu", use_auth_token="fake-api-token" + model_name_or_path="model", device="cpu", use_auth_token="fake-api-token", trust_remote_code=True ) diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py index 18383e1fd..3885bf6cf 100644 --- a/test/components/embedders/test_sentence_transformers_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_text_embedder.py @@ -1,10 +1,10 @@ -from unittest.mock import patch, MagicMock -import pytest -from haystack.utils import Secret, ComponentDevice +from unittest.mock import MagicMock, patch import numpy as np +import pytest from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder +from haystack.utils import ComponentDevice, Secret class TestSentenceTransformersTextEmbedder: @@ -18,6 +18,7 @@ class TestSentenceTransformersTextEmbedder: assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.normalize_embeddings is False + assert embedder.trust_remote_code is False def test_init_with_parameters(self): embedder = SentenceTransformersTextEmbedder( @@ -29,6 +30,7 @@ class TestSentenceTransformersTextEmbedder: batch_size=64, progress_bar=False, 
normalize_embeddings=True, + trust_remote_code=True, ) assert embedder.model == "model" assert embedder.device == ComponentDevice.from_str("cuda:0") @@ -38,6 +40,7 @@ class TestSentenceTransformersTextEmbedder: assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.normalize_embeddings is True + assert embedder.trust_remote_code def test_to_dict(self): component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu")) @@ -53,6 +56,7 @@ class TestSentenceTransformersTextEmbedder: "batch_size": 32, "progress_bar": True, "normalize_embeddings": False, + "trust_remote_code": False, }, } @@ -66,6 +70,7 @@ class TestSentenceTransformersTextEmbedder: batch_size=64, progress_bar=False, normalize_embeddings=True, + trust_remote_code=True, ) data = component.to_dict() assert data == { @@ -79,6 +84,7 @@ class TestSentenceTransformersTextEmbedder: "batch_size": 64, "progress_bar": False, "normalize_embeddings": True, + "trust_remote_code": True, }, } @@ -99,6 +105,7 @@ class TestSentenceTransformersTextEmbedder: "batch_size": 32, "progress_bar": True, "normalize_embeddings": False, + "trust_remote_code": False, }, } component = SentenceTransformersTextEmbedder.from_dict(data) @@ -110,6 +117,7 @@ class TestSentenceTransformersTextEmbedder: assert component.batch_size == 32 assert component.progress_bar is True assert component.normalize_embeddings is False + assert component.trust_remote_code is False @patch( "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory" @@ -118,7 +126,9 @@ class TestSentenceTransformersTextEmbedder: embedder = SentenceTransformersTextEmbedder(model="model", token=None, device=ComponentDevice.from_str("cpu")) mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() - mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None) + 
mocked_factory.get_embedding_backend.assert_called_once_with( + model="model", device="cpu", auth_token=None, trust_remote_code=False + ) @patch( "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"