mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-14 17:13:03 +00:00
feat: add trust_remote_code parameter to SentenceTransformersSimilarityRanker (#9546)
This commit is contained in:
parent
556dcc9e46
commit
d14f5dca0e
@ -52,6 +52,7 @@ class SentenceTransformersSimilarityRanker:
|
||||
embedding_separator: str = "\n",
|
||||
scale_score: bool = True,
|
||||
score_threshold: Optional[float] = None,
|
||||
trust_remote_code: bool = False,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
config_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@ -84,6 +85,9 @@ class SentenceTransformersSimilarityRanker:
|
||||
If `False`, disables scaling of the raw logit predictions.
|
||||
:param score_threshold:
|
||||
Use it to return documents with a score above this threshold only.
|
||||
:param trust_remote_code:
|
||||
If `False`, allows only Hugging Face verified model architectures.
|
||||
If `True`, allows custom models and scripts.
|
||||
:param model_kwargs:
|
||||
Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
|
||||
when loading the model. Refer to specific model documentation for available kwargs.
|
||||
@ -119,6 +123,7 @@ class SentenceTransformersSimilarityRanker:
|
||||
self.embedding_separator = embedding_separator
|
||||
self.scale_score = scale_score
|
||||
self.score_threshold = score_threshold
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.model_kwargs = model_kwargs
|
||||
self.tokenizer_kwargs = tokenizer_kwargs
|
||||
self.config_kwargs = config_kwargs
|
||||
@ -140,6 +145,7 @@ class SentenceTransformersSimilarityRanker:
|
||||
model_name_or_path=self.model,
|
||||
device=self.device.to_torch_str(),
|
||||
token=self.token.resolve_value() if self.token else None,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
config_kwargs=self.config_kwargs,
|
||||
@ -165,6 +171,7 @@ class SentenceTransformersSimilarityRanker:
|
||||
embedding_separator=self.embedding_separator,
|
||||
scale_score=self.scale_score,
|
||||
score_threshold=self.score_threshold,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
config_kwargs=self.config_kwargs,
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Added a `trust_remote_code` parameter to the `SentenceTransformersSimilarityRanker` component.
|
||||
When set to True, this enables execution of custom models and scripts hosted on the Hugging Face Hub.
|
||||
@ -19,7 +19,29 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
SentenceTransformersSimilarityRanker(top_k=-1)
|
||||
|
||||
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
|
||||
def test_init_onnx_backend(self, mocked_cross_encoder):
|
||||
def test_init_warm_up_torch_backend(self, mocked_cross_encoder):
|
||||
ranker = SentenceTransformersSimilarityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
backend="torch",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
ranker.warm_up()
|
||||
mocked_cross_encoder.assert_called_once_with(
|
||||
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
token=None,
|
||||
trust_remote_code=True,
|
||||
model_kwargs=None,
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
backend="torch",
|
||||
)
|
||||
|
||||
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
|
||||
def test_init_warm_up_onnx_backend(self, mocked_cross_encoder):
|
||||
onnx_ranker = SentenceTransformersSimilarityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
@ -32,6 +54,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
token=None,
|
||||
trust_remote_code=False,
|
||||
model_kwargs=None,
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
@ -39,7 +62,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
)
|
||||
|
||||
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
|
||||
def test_init_openvino_backend(self, mocked_cross_encoder):
|
||||
def test_init_warm_up_openvino_backend(self, mocked_cross_encoder):
|
||||
openvino_ranker = SentenceTransformersSimilarityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
token=None,
|
||||
@ -52,6 +75,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||
device="cpu",
|
||||
token=None,
|
||||
trust_remote_code=False,
|
||||
model_kwargs=None,
|
||||
tokenizer_kwargs=None,
|
||||
config_kwargs=None,
|
||||
@ -74,6 +98,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
"embedding_separator": "\n",
|
||||
"scale_score": True,
|
||||
"score_threshold": None,
|
||||
"trust_remote_code": False,
|
||||
"model_kwargs": None,
|
||||
"tokenizer_kwargs": None,
|
||||
"config_kwargs": None,
|
||||
@ -92,6 +117,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
document_prefix="document_instruction: ",
|
||||
scale_score=False,
|
||||
score_threshold=0.01,
|
||||
trust_remote_code=True,
|
||||
model_kwargs={"torch_dtype": torch.float16},
|
||||
tokenizer_kwargs={"model_max_length": 512},
|
||||
batch_size=32,
|
||||
@ -110,6 +136,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
"embedding_separator": "\n",
|
||||
"scale_score": False,
|
||||
"score_threshold": 0.01,
|
||||
"trust_remote_code": True,
|
||||
"model_kwargs": {"torch_dtype": "torch.float16"},
|
||||
"tokenizer_kwargs": {"model_max_length": 512},
|
||||
"config_kwargs": None,
|
||||
@ -141,6 +168,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
"embedding_separator": "\n",
|
||||
"scale_score": True,
|
||||
"score_threshold": None,
|
||||
"trust_remote_code": False,
|
||||
"model_kwargs": {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_use_double_quant": True,
|
||||
@ -168,6 +196,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
"embedding_separator": "\n",
|
||||
"scale_score": False,
|
||||
"score_threshold": 0.01,
|
||||
"trust_remote_code": False,
|
||||
"model_kwargs": {"torch_dtype": "torch.float16"},
|
||||
"tokenizer_kwargs": None,
|
||||
"config_kwargs": None,
|
||||
@ -187,6 +216,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
assert component.embedding_separator == "\n"
|
||||
assert not component.scale_score
|
||||
assert component.score_threshold == 0.01
|
||||
assert component.trust_remote_code is False
|
||||
assert component.model_kwargs == {"torch_dtype": torch.float16}
|
||||
assert component.tokenizer_kwargs is None
|
||||
assert component.config_kwargs is None
|
||||
@ -209,6 +239,7 @@ class TestSentenceTransformersSimilarityRanker:
|
||||
assert component.embedding_separator == "\n"
|
||||
assert component.scale_score
|
||||
assert component.score_threshold is None
|
||||
assert component.trust_remote_code is False
|
||||
assert component.model_kwargs is None
|
||||
assert component.tokenizer_kwargs is None
|
||||
assert component.config_kwargs is None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user