feat: add trust_remote_code parameter to SentenceTransformersSimilarityRanker (#9546)

This commit is contained in:
Stefano Fiorucci 2025-06-24 11:39:59 +02:00 committed by GitHub
parent 556dcc9e46
commit d14f5dca0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 2 deletions

View File

@ -52,6 +52,7 @@ class SentenceTransformersSimilarityRanker:
embedding_separator: str = "\n",
scale_score: bool = True,
score_threshold: Optional[float] = None,
trust_remote_code: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
@ -84,6 +85,9 @@ class SentenceTransformersSimilarityRanker:
If `False`, disables scaling of the raw logit predictions.
:param score_threshold:
Use it to return documents with a score above this threshold only.
:param trust_remote_code:
If `False`, allows only Hugging Face verified model architectures.
If `True`, allows custom models and scripts.
:param model_kwargs:
Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
when loading the model. Refer to specific model documentation for available kwargs.
@ -119,6 +123,7 @@ class SentenceTransformersSimilarityRanker:
self.embedding_separator = embedding_separator
self.scale_score = scale_score
self.score_threshold = score_threshold
self.trust_remote_code = trust_remote_code
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
@ -140,6 +145,7 @@ class SentenceTransformersSimilarityRanker:
model_name_or_path=self.model,
device=self.device.to_torch_str(),
token=self.token.resolve_value() if self.token else None,
trust_remote_code=self.trust_remote_code,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
@ -165,6 +171,7 @@ class SentenceTransformersSimilarityRanker:
embedding_separator=self.embedding_separator,
scale_score=self.scale_score,
score_threshold=self.score_threshold,
trust_remote_code=self.trust_remote_code,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
Added a `trust_remote_code` parameter to the `SentenceTransformersSimilarityRanker` component.
When set to True, this enables execution of custom models and scripts hosted on the Hugging Face Hub.

View File

@ -19,7 +19,29 @@ class TestSentenceTransformersSimilarityRanker:
SentenceTransformersSimilarityRanker(top_k=-1)
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
def test_init_onnx_backend(self, mocked_cross_encoder):
def test_init_warm_up_torch_backend(self, mocked_cross_encoder):
ranker = SentenceTransformersSimilarityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
device=ComponentDevice.from_str("cpu"),
backend="torch",
trust_remote_code=True,
)
ranker.warm_up()
mocked_cross_encoder.assert_called_once_with(
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
token=None,
trust_remote_code=True,
model_kwargs=None,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
def test_init_warm_up_onnx_backend(self, mocked_cross_encoder):
onnx_ranker = SentenceTransformersSimilarityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
@ -32,6 +54,7 @@ class TestSentenceTransformersSimilarityRanker:
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
token=None,
trust_remote_code=False,
model_kwargs=None,
tokenizer_kwargs=None,
config_kwargs=None,
@ -39,7 +62,7 @@ class TestSentenceTransformersSimilarityRanker:
)
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
def test_init_openvino_backend(self, mocked_cross_encoder):
def test_init_warm_up_openvino_backend(self, mocked_cross_encoder):
openvino_ranker = SentenceTransformersSimilarityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
token=None,
@ -52,6 +75,7 @@ class TestSentenceTransformersSimilarityRanker:
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
device="cpu",
token=None,
trust_remote_code=False,
model_kwargs=None,
tokenizer_kwargs=None,
config_kwargs=None,
@ -74,6 +98,7 @@ class TestSentenceTransformersSimilarityRanker:
"embedding_separator": "\n",
"scale_score": True,
"score_threshold": None,
"trust_remote_code": False,
"model_kwargs": None,
"tokenizer_kwargs": None,
"config_kwargs": None,
@ -92,6 +117,7 @@ class TestSentenceTransformersSimilarityRanker:
document_prefix="document_instruction: ",
scale_score=False,
score_threshold=0.01,
trust_remote_code=True,
model_kwargs={"torch_dtype": torch.float16},
tokenizer_kwargs={"model_max_length": 512},
batch_size=32,
@ -110,6 +136,7 @@ class TestSentenceTransformersSimilarityRanker:
"embedding_separator": "\n",
"scale_score": False,
"score_threshold": 0.01,
"trust_remote_code": True,
"model_kwargs": {"torch_dtype": "torch.float16"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": None,
@ -141,6 +168,7 @@ class TestSentenceTransformersSimilarityRanker:
"embedding_separator": "\n",
"scale_score": True,
"score_threshold": None,
"trust_remote_code": False,
"model_kwargs": {
"load_in_4bit": True,
"bnb_4bit_use_double_quant": True,
@ -168,6 +196,7 @@ class TestSentenceTransformersSimilarityRanker:
"embedding_separator": "\n",
"scale_score": False,
"score_threshold": 0.01,
"trust_remote_code": False,
"model_kwargs": {"torch_dtype": "torch.float16"},
"tokenizer_kwargs": None,
"config_kwargs": None,
@ -187,6 +216,7 @@ class TestSentenceTransformersSimilarityRanker:
assert component.embedding_separator == "\n"
assert not component.scale_score
assert component.score_threshold == 0.01
assert component.trust_remote_code is False
assert component.model_kwargs == {"torch_dtype": torch.float16}
assert component.tokenizer_kwargs is None
assert component.config_kwargs is None
@ -209,6 +239,7 @@ class TestSentenceTransformersSimilarityRanker:
assert component.embedding_separator == "\n"
assert component.scale_score
assert component.score_threshold is None
assert component.trust_remote_code is False
assert component.model_kwargs is None
assert component.tokenizer_kwargs is None
assert component.config_kwargs is None