mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-16 18:13:54 +00:00
feat: add trust_remote_code parameter to SentenceTransformersSimilarityRanker (#9546)
This commit is contained in:
parent
556dcc9e46
commit
d14f5dca0e
@ -52,6 +52,7 @@ class SentenceTransformersSimilarityRanker:
|
|||||||
embedding_separator: str = "\n",
|
embedding_separator: str = "\n",
|
||||||
scale_score: bool = True,
|
scale_score: bool = True,
|
||||||
score_threshold: Optional[float] = None,
|
score_threshold: Optional[float] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
config_kwargs: Optional[Dict[str, Any]] = None,
|
config_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
@ -84,6 +85,9 @@ class SentenceTransformersSimilarityRanker:
|
|||||||
If `False`, disables scaling of the raw logit predictions.
|
If `False`, disables scaling of the raw logit predictions.
|
||||||
:param score_threshold:
|
:param score_threshold:
|
||||||
Use it to return documents with a score above this threshold only.
|
Use it to return documents with a score above this threshold only.
|
||||||
|
:param trust_remote_code:
|
||||||
|
If `False`, allows only Hugging Face verified model architectures.
|
||||||
|
If `True`, allows custom models and scripts.
|
||||||
:param model_kwargs:
|
:param model_kwargs:
|
||||||
Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
|
Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
|
||||||
when loading the model. Refer to specific model documentation for available kwargs.
|
when loading the model. Refer to specific model documentation for available kwargs.
|
||||||
@ -119,6 +123,7 @@ class SentenceTransformersSimilarityRanker:
|
|||||||
self.embedding_separator = embedding_separator
|
self.embedding_separator = embedding_separator
|
||||||
self.scale_score = scale_score
|
self.scale_score = scale_score
|
||||||
self.score_threshold = score_threshold
|
self.score_threshold = score_threshold
|
||||||
|
self.trust_remote_code = trust_remote_code
|
||||||
self.model_kwargs = model_kwargs
|
self.model_kwargs = model_kwargs
|
||||||
self.tokenizer_kwargs = tokenizer_kwargs
|
self.tokenizer_kwargs = tokenizer_kwargs
|
||||||
self.config_kwargs = config_kwargs
|
self.config_kwargs = config_kwargs
|
||||||
@ -140,6 +145,7 @@ class SentenceTransformersSimilarityRanker:
|
|||||||
model_name_or_path=self.model,
|
model_name_or_path=self.model,
|
||||||
device=self.device.to_torch_str(),
|
device=self.device.to_torch_str(),
|
||||||
token=self.token.resolve_value() if self.token else None,
|
token=self.token.resolve_value() if self.token else None,
|
||||||
|
trust_remote_code=self.trust_remote_code,
|
||||||
model_kwargs=self.model_kwargs,
|
model_kwargs=self.model_kwargs,
|
||||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||||
config_kwargs=self.config_kwargs,
|
config_kwargs=self.config_kwargs,
|
||||||
@ -165,6 +171,7 @@ class SentenceTransformersSimilarityRanker:
|
|||||||
embedding_separator=self.embedding_separator,
|
embedding_separator=self.embedding_separator,
|
||||||
scale_score=self.scale_score,
|
scale_score=self.scale_score,
|
||||||
score_threshold=self.score_threshold,
|
score_threshold=self.score_threshold,
|
||||||
|
trust_remote_code=self.trust_remote_code,
|
||||||
model_kwargs=self.model_kwargs,
|
model_kwargs=self.model_kwargs,
|
||||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||||
config_kwargs=self.config_kwargs,
|
config_kwargs=self.config_kwargs,
|
||||||
|
|||||||
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
enhancements:
|
||||||
|
- |
|
||||||
|
Added a `trust_remote_code` parameter to the `SentenceTransformersSimilarityRanker` component.
|
||||||
|
When set to True, this enables execution of custom models and scripts hosted on the Hugging Face Hub.
|
||||||
@ -19,7 +19,29 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
SentenceTransformersSimilarityRanker(top_k=-1)
|
SentenceTransformersSimilarityRanker(top_k=-1)
|
||||||
|
|
||||||
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
|
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
|
||||||
def test_init_onnx_backend(self, mocked_cross_encoder):
|
def test_init_warm_up_torch_backend(self, mocked_cross_encoder):
|
||||||
|
ranker = SentenceTransformersSimilarityRanker(
|
||||||
|
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
token=None,
|
||||||
|
device=ComponentDevice.from_str("cpu"),
|
||||||
|
backend="torch",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
ranker.warm_up()
|
||||||
|
mocked_cross_encoder.assert_called_once_with(
|
||||||
|
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
device="cpu",
|
||||||
|
token=None,
|
||||||
|
trust_remote_code=True,
|
||||||
|
model_kwargs=None,
|
||||||
|
tokenizer_kwargs=None,
|
||||||
|
config_kwargs=None,
|
||||||
|
backend="torch",
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
|
||||||
|
def test_init_warm_up_onnx_backend(self, mocked_cross_encoder):
|
||||||
onnx_ranker = SentenceTransformersSimilarityRanker(
|
onnx_ranker = SentenceTransformersSimilarityRanker(
|
||||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
token=None,
|
token=None,
|
||||||
@ -32,6 +54,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
device="cpu",
|
device="cpu",
|
||||||
token=None,
|
token=None,
|
||||||
|
trust_remote_code=False,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
tokenizer_kwargs=None,
|
tokenizer_kwargs=None,
|
||||||
config_kwargs=None,
|
config_kwargs=None,
|
||||||
@ -39,7 +62,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
|
@patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder")
|
||||||
def test_init_openvino_backend(self, mocked_cross_encoder):
|
def test_init_warm_up_openvino_backend(self, mocked_cross_encoder):
|
||||||
openvino_ranker = SentenceTransformersSimilarityRanker(
|
openvino_ranker = SentenceTransformersSimilarityRanker(
|
||||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
token=None,
|
token=None,
|
||||||
@ -52,6 +75,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
device="cpu",
|
device="cpu",
|
||||||
token=None,
|
token=None,
|
||||||
|
trust_remote_code=False,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
tokenizer_kwargs=None,
|
tokenizer_kwargs=None,
|
||||||
config_kwargs=None,
|
config_kwargs=None,
|
||||||
@ -74,6 +98,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
"embedding_separator": "\n",
|
"embedding_separator": "\n",
|
||||||
"scale_score": True,
|
"scale_score": True,
|
||||||
"score_threshold": None,
|
"score_threshold": None,
|
||||||
|
"trust_remote_code": False,
|
||||||
"model_kwargs": None,
|
"model_kwargs": None,
|
||||||
"tokenizer_kwargs": None,
|
"tokenizer_kwargs": None,
|
||||||
"config_kwargs": None,
|
"config_kwargs": None,
|
||||||
@ -92,6 +117,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
document_prefix="document_instruction: ",
|
document_prefix="document_instruction: ",
|
||||||
scale_score=False,
|
scale_score=False,
|
||||||
score_threshold=0.01,
|
score_threshold=0.01,
|
||||||
|
trust_remote_code=True,
|
||||||
model_kwargs={"torch_dtype": torch.float16},
|
model_kwargs={"torch_dtype": torch.float16},
|
||||||
tokenizer_kwargs={"model_max_length": 512},
|
tokenizer_kwargs={"model_max_length": 512},
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@ -110,6 +136,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
"embedding_separator": "\n",
|
"embedding_separator": "\n",
|
||||||
"scale_score": False,
|
"scale_score": False,
|
||||||
"score_threshold": 0.01,
|
"score_threshold": 0.01,
|
||||||
|
"trust_remote_code": True,
|
||||||
"model_kwargs": {"torch_dtype": "torch.float16"},
|
"model_kwargs": {"torch_dtype": "torch.float16"},
|
||||||
"tokenizer_kwargs": {"model_max_length": 512},
|
"tokenizer_kwargs": {"model_max_length": 512},
|
||||||
"config_kwargs": None,
|
"config_kwargs": None,
|
||||||
@ -141,6 +168,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
"embedding_separator": "\n",
|
"embedding_separator": "\n",
|
||||||
"scale_score": True,
|
"scale_score": True,
|
||||||
"score_threshold": None,
|
"score_threshold": None,
|
||||||
|
"trust_remote_code": False,
|
||||||
"model_kwargs": {
|
"model_kwargs": {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"bnb_4bit_use_double_quant": True,
|
"bnb_4bit_use_double_quant": True,
|
||||||
@ -168,6 +196,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
"embedding_separator": "\n",
|
"embedding_separator": "\n",
|
||||||
"scale_score": False,
|
"scale_score": False,
|
||||||
"score_threshold": 0.01,
|
"score_threshold": 0.01,
|
||||||
|
"trust_remote_code": False,
|
||||||
"model_kwargs": {"torch_dtype": "torch.float16"},
|
"model_kwargs": {"torch_dtype": "torch.float16"},
|
||||||
"tokenizer_kwargs": None,
|
"tokenizer_kwargs": None,
|
||||||
"config_kwargs": None,
|
"config_kwargs": None,
|
||||||
@ -187,6 +216,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
assert component.embedding_separator == "\n"
|
assert component.embedding_separator == "\n"
|
||||||
assert not component.scale_score
|
assert not component.scale_score
|
||||||
assert component.score_threshold == 0.01
|
assert component.score_threshold == 0.01
|
||||||
|
assert component.trust_remote_code is False
|
||||||
assert component.model_kwargs == {"torch_dtype": torch.float16}
|
assert component.model_kwargs == {"torch_dtype": torch.float16}
|
||||||
assert component.tokenizer_kwargs is None
|
assert component.tokenizer_kwargs is None
|
||||||
assert component.config_kwargs is None
|
assert component.config_kwargs is None
|
||||||
@ -209,6 +239,7 @@ class TestSentenceTransformersSimilarityRanker:
|
|||||||
assert component.embedding_separator == "\n"
|
assert component.embedding_separator == "\n"
|
||||||
assert component.scale_score
|
assert component.scale_score
|
||||||
assert component.score_threshold is None
|
assert component.score_threshold is None
|
||||||
|
assert component.trust_remote_code is False
|
||||||
assert component.model_kwargs is None
|
assert component.model_kwargs is None
|
||||||
assert component.tokenizer_kwargs is None
|
assert component.tokenizer_kwargs is None
|
||||||
assert component.config_kwargs is None
|
assert component.config_kwargs is None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user