mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 21:28:00 +00:00
feat: Add trust_remote_code init param to SentenceTransformer embedders (#7356)
* Add trust_remote_code init param to SentenceTransformer embedders * Add release note * Go with no kwargs solution * Update haystack/components/embedders/sentence_transformers_document_embedder.py Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * Pydoc fix --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
parent
5b4f9f1cda
commit
2aae8472e7
@ -15,12 +15,16 @@ class _SentenceTransformersEmbeddingBackendFactory:
|
||||
_instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}
|
||||
|
||||
@staticmethod
|
||||
def get_embedding_backend(model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
|
||||
def get_embedding_backend(
|
||||
model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None, trust_remote_code: bool = False
|
||||
):
|
||||
embedding_backend_id = f"{model}{device}{auth_token}"
|
||||
|
||||
if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
|
||||
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
|
||||
embedding_backend = _SentenceTransformersEmbeddingBackend(model=model, device=device, auth_token=auth_token)
|
||||
embedding_backend = _SentenceTransformersEmbeddingBackend(
|
||||
model=model, device=device, auth_token=auth_token, trust_remote_code=trust_remote_code
|
||||
)
|
||||
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
|
||||
return embedding_backend
|
||||
|
||||
@ -30,10 +34,19 @@ class _SentenceTransformersEmbeddingBackend:
|
||||
Class to manage Sentence Transformers embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
device: Optional[str] = None,
|
||||
auth_token: Optional[Secret] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
sentence_transformers_import.check()
|
||||
self.model = SentenceTransformer(
|
||||
model_name_or_path=model, device=device, use_auth_token=auth_token.resolve_value() if auth_token else None
|
||||
model_name_or_path=model,
|
||||
device=device,
|
||||
use_auth_token=auth_token.resolve_value() if auth_token else None,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
def embed(self, data: List[str], **kwargs) -> List[List[float]]:
|
||||
|
||||
@ -39,6 +39,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
normalize_embeddings: bool = False,
|
||||
meta_fields_to_embed: Optional[List[str]] = None,
|
||||
embedding_separator: str = "\n",
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a SentenceTransformersDocumentEmbedder component.
|
||||
@ -65,6 +66,9 @@ class SentenceTransformersDocumentEmbedder:
|
||||
List of meta fields that will be embedded along with the Document text.
|
||||
:param embedding_separator:
|
||||
Separator used to concatenate the meta fields to the Document text.
|
||||
:param trust_remote_code:
|
||||
If `False`, only Hugging Face verified model architectures are allowed.
|
||||
If `True`, custom models and scripts are allowed.
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
@ -77,6 +81,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
self.normalize_embeddings = normalize_embeddings
|
||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||
self.embedding_separator = embedding_separator
|
||||
self.trust_remote_code = trust_remote_code
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -103,6 +108,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
normalize_embeddings=self.normalize_embeddings,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
embedding_separator=self.embedding_separator,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -127,7 +133,10 @@ class SentenceTransformersDocumentEmbedder:
|
||||
"""
|
||||
if not hasattr(self, "embedding_backend"):
|
||||
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
|
||||
model=self.model,
|
||||
device=self.device.to_torch_str(),
|
||||
auth_token=self.token,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
|
||||
@ -37,6 +37,7 @@ class SentenceTransformersTextEmbedder:
|
||||
batch_size: int = 32,
|
||||
progress_bar: bool = True,
|
||||
normalize_embeddings: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a SentenceTransformersTextEmbedder component.
|
||||
@ -59,6 +60,9 @@ class SentenceTransformersTextEmbedder:
|
||||
If True shows a progress bar when running.
|
||||
:param normalize_embeddings:
|
||||
If True returned vectors will have length 1.
|
||||
:param trust_remote_code:
|
||||
If `False`, only Hugging Face verified model architectures are allowed.
|
||||
If `True`, custom models and scripts are allowed.
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
@ -69,6 +73,7 @@ class SentenceTransformersTextEmbedder:
|
||||
self.batch_size = batch_size
|
||||
self.progress_bar = progress_bar
|
||||
self.normalize_embeddings = normalize_embeddings
|
||||
self.trust_remote_code = trust_remote_code
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -93,6 +98,7 @@ class SentenceTransformersTextEmbedder:
|
||||
batch_size=self.batch_size,
|
||||
progress_bar=self.progress_bar,
|
||||
normalize_embeddings=self.normalize_embeddings,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -117,7 +123,10 @@ class SentenceTransformersTextEmbedder:
|
||||
"""
|
||||
if not hasattr(self, "embedding_backend"):
|
||||
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
|
||||
model=self.model,
|
||||
device=self.device.to_torch_str(),
|
||||
auth_token=self.token,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
@component.output_types(embedding=List[float])
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add trust_remote_code parameter to SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder for allowing custom models and scripts.
|
||||
@ -1,10 +1,11 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
from haystack.utils import Secret, ComponentDevice
|
||||
import pytest
|
||||
|
||||
from haystack import Document
|
||||
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
|
||||
from haystack.utils import ComponentDevice, Secret
|
||||
|
||||
|
||||
class TestSentenceTransformersDocumentEmbedder:
|
||||
@ -20,6 +21,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
assert embedder.normalize_embeddings is False
|
||||
assert embedder.meta_fields_to_embed == []
|
||||
assert embedder.embedding_separator == "\n"
|
||||
assert embedder.trust_remote_code is False
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
embedder = SentenceTransformersDocumentEmbedder(
|
||||
@ -33,6 +35,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
normalize_embeddings=True,
|
||||
meta_fields_to_embed=["test_field"],
|
||||
embedding_separator=" | ",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
assert embedder.model == "model"
|
||||
assert embedder.device == ComponentDevice.from_str("cuda:0")
|
||||
@ -44,6 +47,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
assert embedder.normalize_embeddings is True
|
||||
assert embedder.meta_fields_to_embed == ["test_field"]
|
||||
assert embedder.embedding_separator == " | "
|
||||
assert embedder.trust_remote_code
|
||||
|
||||
def test_to_dict(self):
|
||||
component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
|
||||
@ -61,6 +65,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"normalize_embeddings": False,
|
||||
"embedding_separator": "\n",
|
||||
"meta_fields_to_embed": [],
|
||||
"trust_remote_code": False,
|
||||
},
|
||||
}
|
||||
|
||||
@ -76,6 +81,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
normalize_embeddings=True,
|
||||
meta_fields_to_embed=["meta_field"],
|
||||
embedding_separator=" - ",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
data = component.to_dict()
|
||||
|
||||
@ -91,6 +97,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"progress_bar": False,
|
||||
"normalize_embeddings": True,
|
||||
"embedding_separator": " - ",
|
||||
"trust_remote_code": True,
|
||||
"meta_fields_to_embed": ["meta_field"],
|
||||
},
|
||||
}
|
||||
@ -107,6 +114,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"normalize_embeddings": True,
|
||||
"embedding_separator": " - ",
|
||||
"meta_fields_to_embed": ["meta_field"],
|
||||
"trust_remote_code": True,
|
||||
}
|
||||
component = SentenceTransformersDocumentEmbedder.from_dict(
|
||||
{
|
||||
@ -123,6 +131,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
assert component.progress_bar is False
|
||||
assert component.normalize_embeddings is True
|
||||
assert component.embedding_separator == " - "
|
||||
assert component.trust_remote_code
|
||||
assert component.meta_fields_to_embed == ["meta_field"]
|
||||
|
||||
@patch(
|
||||
@ -134,7 +143,9 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
)
|
||||
mocked_factory.get_embedding_backend.assert_not_called()
|
||||
embedder.warm_up()
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model="model", device="cpu", auth_token=None, trust_remote_code=False
|
||||
)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.components.embedders.backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
@ -23,10 +25,10 @@ def test_factory_behavior(mock_sentence_transformer):
|
||||
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||
def test_model_initialization(mock_sentence_transformer):
|
||||
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token")
|
||||
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token"), trust_remote_code=True
|
||||
)
|
||||
mock_sentence_transformer.assert_called_once_with(
|
||||
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token"
|
||||
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token", trust_remote_code=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
from haystack.utils import Secret, ComponentDevice
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
|
||||
from haystack.utils import ComponentDevice, Secret
|
||||
|
||||
|
||||
class TestSentenceTransformersTextEmbedder:
|
||||
@ -18,6 +18,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
assert embedder.batch_size == 32
|
||||
assert embedder.progress_bar is True
|
||||
assert embedder.normalize_embeddings is False
|
||||
assert embedder.trust_remote_code is False
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
embedder = SentenceTransformersTextEmbedder(
|
||||
@ -29,6 +30,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
normalize_embeddings=True,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
assert embedder.model == "model"
|
||||
assert embedder.device == ComponentDevice.from_str("cuda:0")
|
||||
@ -38,6 +40,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
assert embedder.batch_size == 64
|
||||
assert embedder.progress_bar is False
|
||||
assert embedder.normalize_embeddings is True
|
||||
assert embedder.trust_remote_code
|
||||
|
||||
def test_to_dict(self):
|
||||
component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
|
||||
@ -53,6 +56,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"normalize_embeddings": False,
|
||||
"trust_remote_code": False,
|
||||
},
|
||||
}
|
||||
|
||||
@ -66,6 +70,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
normalize_embeddings=True,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
@ -79,6 +84,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
"batch_size": 64,
|
||||
"progress_bar": False,
|
||||
"normalize_embeddings": True,
|
||||
"trust_remote_code": True,
|
||||
},
|
||||
}
|
||||
|
||||
@ -99,6 +105,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"normalize_embeddings": False,
|
||||
"trust_remote_code": False,
|
||||
},
|
||||
}
|
||||
component = SentenceTransformersTextEmbedder.from_dict(data)
|
||||
@ -110,6 +117,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
assert component.batch_size == 32
|
||||
assert component.progress_bar is True
|
||||
assert component.normalize_embeddings is False
|
||||
assert component.trust_remote_code is False
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
@ -118,7 +126,9 @@ class TestSentenceTransformersTextEmbedder:
|
||||
embedder = SentenceTransformersTextEmbedder(model="model", token=None, device=ComponentDevice.from_str("cpu"))
|
||||
mocked_factory.get_embedding_backend.assert_not_called()
|
||||
embedder.warm_up()
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(
|
||||
model="model", device="cpu", auth_token=None, trust_remote_code=False
|
||||
)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user