feat: Add trust_remote_code init param to SentenceTransformer embedders (#7356)

* Add trust_remote_code init param to SentenceTransformer embedders

* Add release note

* Go with no kwargs solution

* Update haystack/components/embedders/sentence_transformers_document_embedder.py

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* Pydoc fix

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2024-03-14 11:14:04 +01:00 committed by GitHub
parent 5b4f9f1cda
commit 2aae8472e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 74 additions and 16 deletions

View File

@ -15,12 +15,16 @@ class _SentenceTransformersEmbeddingBackendFactory:
_instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}
@staticmethod
def get_embedding_backend(
    model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None, trust_remote_code: bool = False
):
    """
    Return the embedding backend for the given configuration, creating and caching it on first use.

    Backends are cached in `_SentenceTransformersEmbeddingBackendFactory._instances`, keyed by every
    argument that influences how the backend is built.

    :param model:
        Model identifier forwarded to the backend.
    :param device:
        Device string forwarded to the backend.
    :param auth_token:
        Optional secret forwarded to the backend.
    :param trust_remote_code:
        If `False`, only Hugging Face verified model architectures are allowed.
        If `True`, custom models and scripts are allowed.
    :returns:
        The cached or newly created `_SentenceTransformersEmbeddingBackend`.
    """
    # `trust_remote_code` must be part of the cache key: otherwise a backend created with one
    # setting would be handed back to a caller that asked for the other, and the caller's
    # explicit choice would be silently ignored.
    embedding_backend_id = f"{model}{device}{auth_token}{trust_remote_code}"

    instances = _SentenceTransformersEmbeddingBackendFactory._instances
    if embedding_backend_id in instances:
        return instances[embedding_backend_id]

    embedding_backend = _SentenceTransformersEmbeddingBackend(
        model=model, device=device, auth_token=auth_token, trust_remote_code=trust_remote_code
    )
    instances[embedding_backend_id] = embedding_backend
    return embedding_backend
@ -30,10 +34,19 @@ class _SentenceTransformersEmbeddingBackend:
Class to manage Sentence Transformers embeddings.
"""
def __init__(
    self,
    model: str,
    device: Optional[str] = None,
    auth_token: Optional[Secret] = None,
    trust_remote_code: bool = False,
):
    """
    Build the underlying `SentenceTransformer` model and store it on `self.model`.

    :param model:
        Passed to `SentenceTransformer` as `model_name_or_path`.
    :param device:
        Passed to `SentenceTransformer` as `device`.
    :param auth_token:
        Optional secret; when set, its resolved value is passed as `use_auth_token`,
        otherwise `None` is passed.
    :param trust_remote_code:
        Passed through unchanged to `SentenceTransformer`.
    """
    # NOTE(review): presumably raises when the optional sentence-transformers
    # dependency is not installed - confirm against the lazy-import helper.
    sentence_transformers_import.check()

    # Resolve the secret only when one was supplied.
    resolved_token = None
    if auth_token:
        resolved_token = auth_token.resolve_value()

    self.model = SentenceTransformer(
        model_name_or_path=model,
        device=device,
        use_auth_token=resolved_token,
        trust_remote_code=trust_remote_code,
    )
def embed(self, data: List[str], **kwargs) -> List[List[float]]:

View File

@ -39,6 +39,7 @@ class SentenceTransformersDocumentEmbedder:
normalize_embeddings: bool = False,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
trust_remote_code: bool = False,
):
"""
Create a SentenceTransformersDocumentEmbedder component.
@ -65,6 +66,9 @@ class SentenceTransformersDocumentEmbedder:
List of meta fields that will be embedded along with the Document text.
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
:param trust_remote_code:
If `False`, only Hugging Face verified model architectures are allowed.
If `True`, custom models and scripts are allowed.
"""
self.model = model
@ -77,6 +81,7 @@ class SentenceTransformersDocumentEmbedder:
self.normalize_embeddings = normalize_embeddings
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.trust_remote_code = trust_remote_code
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -103,6 +108,7 @@ class SentenceTransformersDocumentEmbedder:
normalize_embeddings=self.normalize_embeddings,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
trust_remote_code=self.trust_remote_code,
)
@classmethod
@ -127,7 +133,10 @@ class SentenceTransformersDocumentEmbedder:
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
)
@component.output_types(documents=List[Document])

View File

@ -37,6 +37,7 @@ class SentenceTransformersTextEmbedder:
batch_size: int = 32,
progress_bar: bool = True,
normalize_embeddings: bool = False,
trust_remote_code: bool = False,
):
"""
Create a SentenceTransformersTextEmbedder component.
@ -59,6 +60,9 @@ class SentenceTransformersTextEmbedder:
If True shows a progress bar when running.
:param normalize_embeddings:
If True returned vectors will have length 1.
:param trust_remote_code:
If `False`, only Hugging Face verified model architectures are allowed.
If `True`, custom models and scripts are allowed.
"""
self.model = model
@ -69,6 +73,7 @@ class SentenceTransformersTextEmbedder:
self.batch_size = batch_size
self.progress_bar = progress_bar
self.normalize_embeddings = normalize_embeddings
self.trust_remote_code = trust_remote_code
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -93,6 +98,7 @@ class SentenceTransformersTextEmbedder:
batch_size=self.batch_size,
progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
trust_remote_code=self.trust_remote_code,
)
@classmethod
@ -117,7 +123,10 @@ class SentenceTransformersTextEmbedder:
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
)
@component.output_types(embedding=List[float])

View File

@ -0,0 +1,4 @@
---
features:
- |
Add a `trust_remote_code` parameter to `SentenceTransformersDocumentEmbedder` and `SentenceTransformersTextEmbedder` to allow loading custom models and scripts.

View File

@ -1,10 +1,11 @@
from unittest.mock import patch, MagicMock
import pytest
from unittest.mock import MagicMock, patch
import numpy as np
from haystack.utils import Secret, ComponentDevice
import pytest
from haystack import Document
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.utils import ComponentDevice, Secret
class TestSentenceTransformersDocumentEmbedder:
@ -20,6 +21,7 @@ class TestSentenceTransformersDocumentEmbedder:
assert embedder.normalize_embeddings is False
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.trust_remote_code is False
def test_init_with_parameters(self):
embedder = SentenceTransformersDocumentEmbedder(
@ -33,6 +35,7 @@ class TestSentenceTransformersDocumentEmbedder:
normalize_embeddings=True,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
trust_remote_code=True,
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
@ -44,6 +47,7 @@ class TestSentenceTransformersDocumentEmbedder:
assert embedder.normalize_embeddings is True
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
assert embedder.trust_remote_code
def test_to_dict(self):
component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
@ -61,6 +65,7 @@ class TestSentenceTransformersDocumentEmbedder:
"normalize_embeddings": False,
"embedding_separator": "\n",
"meta_fields_to_embed": [],
"trust_remote_code": False,
},
}
@ -76,6 +81,7 @@ class TestSentenceTransformersDocumentEmbedder:
normalize_embeddings=True,
meta_fields_to_embed=["meta_field"],
embedding_separator=" - ",
trust_remote_code=True,
)
data = component.to_dict()
@ -91,6 +97,7 @@ class TestSentenceTransformersDocumentEmbedder:
"progress_bar": False,
"normalize_embeddings": True,
"embedding_separator": " - ",
"trust_remote_code": True,
"meta_fields_to_embed": ["meta_field"],
},
}
@ -107,6 +114,7 @@ class TestSentenceTransformersDocumentEmbedder:
"normalize_embeddings": True,
"embedding_separator": " - ",
"meta_fields_to_embed": ["meta_field"],
"trust_remote_code": True,
}
component = SentenceTransformersDocumentEmbedder.from_dict(
{
@ -123,6 +131,7 @@ class TestSentenceTransformersDocumentEmbedder:
assert component.progress_bar is False
assert component.normalize_embeddings is True
assert component.embedding_separator == " - "
assert component.trust_remote_code
assert component.meta_fields_to_embed == ["meta_field"]
@patch(
@ -134,7 +143,9 @@ class TestSentenceTransformersDocumentEmbedder:
)
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False
)
@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"

View File

@ -1,5 +1,7 @@
from unittest.mock import patch
import pytest
from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@ -23,10 +25,10 @@ def test_factory_behavior(mock_sentence_transformer):
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token")
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token"), trust_remote_code=True
)
mock_sentence_transformer.assert_called_once_with(
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token"
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token", trust_remote_code=True
)

View File

@ -1,10 +1,10 @@
from unittest.mock import patch, MagicMock
import pytest
from haystack.utils import Secret, ComponentDevice
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice, Secret
class TestSentenceTransformersTextEmbedder:
@ -18,6 +18,7 @@ class TestSentenceTransformersTextEmbedder:
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.normalize_embeddings is False
assert embedder.trust_remote_code is False
def test_init_with_parameters(self):
embedder = SentenceTransformersTextEmbedder(
@ -29,6 +30,7 @@ class TestSentenceTransformersTextEmbedder:
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
@ -38,6 +40,7 @@ class TestSentenceTransformersTextEmbedder:
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.normalize_embeddings is True
assert embedder.trust_remote_code
def test_to_dict(self):
component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
@ -53,6 +56,7 @@ class TestSentenceTransformersTextEmbedder:
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
},
}
@ -66,6 +70,7 @@ class TestSentenceTransformersTextEmbedder:
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
)
data = component.to_dict()
assert data == {
@ -79,6 +84,7 @@ class TestSentenceTransformersTextEmbedder:
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
"trust_remote_code": True,
},
}
@ -99,6 +105,7 @@ class TestSentenceTransformersTextEmbedder:
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
@ -110,6 +117,7 @@ class TestSentenceTransformersTextEmbedder:
assert component.batch_size == 32
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
@ -118,7 +126,9 @@ class TestSentenceTransformersTextEmbedder:
embedder = SentenceTransformersTextEmbedder(model="model", token=None, device=ComponentDevice.from_str("cpu"))
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False
)
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"