diff --git a/haystack/components/rankers/sentence_transformers_similarity.py b/haystack/components/rankers/sentence_transformers_similarity.py index 82e649814..2040da3ec 100644 --- a/haystack/components/rankers/sentence_transformers_similarity.py +++ b/haystack/components/rankers/sentence_transformers_similarity.py @@ -52,6 +52,7 @@ class SentenceTransformersSimilarityRanker: embedding_separator: str = "\n", scale_score: bool = True, score_threshold: Optional[float] = None, + trust_remote_code: bool = False, model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None, @@ -84,6 +85,9 @@ class SentenceTransformersSimilarityRanker: If `False`, disables scaling of the raw logit predictions. :param score_threshold: Use it to return documents with a score above this threshold only. + :param trust_remote_code: + If `False`, allows only Hugging Face verified model architectures. + If `True`, allows custom models and scripts. :param model_kwargs: Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained` when loading the model. Refer to specific model documentation for available kwargs. 
@@ -119,6 +123,7 @@ class SentenceTransformersSimilarityRanker: self.embedding_separator = embedding_separator self.scale_score = scale_score self.score_threshold = score_threshold + self.trust_remote_code = trust_remote_code self.model_kwargs = model_kwargs self.tokenizer_kwargs = tokenizer_kwargs self.config_kwargs = config_kwargs @@ -140,6 +145,7 @@ class SentenceTransformersSimilarityRanker: model_name_or_path=self.model, device=self.device.to_torch_str(), token=self.token.resolve_value() if self.token else None, + trust_remote_code=self.trust_remote_code, model_kwargs=self.model_kwargs, tokenizer_kwargs=self.tokenizer_kwargs, config_kwargs=self.config_kwargs, @@ -165,6 +171,7 @@ class SentenceTransformersSimilarityRanker: embedding_separator=self.embedding_separator, scale_score=self.scale_score, score_threshold=self.score_threshold, + trust_remote_code=self.trust_remote_code, model_kwargs=self.model_kwargs, tokenizer_kwargs=self.tokenizer_kwargs, config_kwargs=self.config_kwargs, diff --git a/releasenotes/notes/st-similarity-ranker-trust-remote-code-7e00abfc96afa698.yaml b/releasenotes/notes/st-similarity-ranker-trust-remote-code-7e00abfc96afa698.yaml new file mode 100644 index 000000000..4e8532847 --- /dev/null +++ b/releasenotes/notes/st-similarity-ranker-trust-remote-code-7e00abfc96afa698.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Added a `trust_remote_code` parameter to the `SentenceTransformersSimilarityRanker` component. + When set to `True`, this enables execution of custom models and scripts hosted on the Hugging Face Hub. 
diff --git a/test/components/rankers/test_sentence_transformers_similarity.py b/test/components/rankers/test_sentence_transformers_similarity.py index 8d0836cc5..d7e68589f 100644 --- a/test/components/rankers/test_sentence_transformers_similarity.py +++ b/test/components/rankers/test_sentence_transformers_similarity.py @@ -19,7 +19,29 @@ class TestSentenceTransformersSimilarityRanker: SentenceTransformersSimilarityRanker(top_k=-1) @patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder") - def test_init_onnx_backend(self, mocked_cross_encoder): + def test_init_warm_up_torch_backend(self, mocked_cross_encoder): + ranker = SentenceTransformersSimilarityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", + token=None, + device=ComponentDevice.from_str("cpu"), + backend="torch", + trust_remote_code=True, + ) + + ranker.warm_up() + mocked_cross_encoder.assert_called_once_with( + model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", + device="cpu", + token=None, + trust_remote_code=True, + model_kwargs=None, + tokenizer_kwargs=None, + config_kwargs=None, + backend="torch", + ) + + @patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder") + def test_init_warm_up_onnx_backend(self, mocked_cross_encoder): onnx_ranker = SentenceTransformersSimilarityRanker( model="sentence-transformers/all-MiniLM-L6-v2", token=None, @@ -32,6 +54,7 @@ class TestSentenceTransformersSimilarityRanker: model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu", token=None, + trust_remote_code=False, model_kwargs=None, tokenizer_kwargs=None, config_kwargs=None, @@ -39,7 +62,7 @@ class TestSentenceTransformersSimilarityRanker: ) @patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder") - def test_init_openvino_backend(self, mocked_cross_encoder): + def test_init_warm_up_openvino_backend(self, mocked_cross_encoder): openvino_ranker = SentenceTransformersSimilarityRanker( 
model="sentence-transformers/all-MiniLM-L6-v2", token=None, @@ -52,6 +75,7 @@ class TestSentenceTransformersSimilarityRanker: model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu", token=None, + trust_remote_code=False, model_kwargs=None, tokenizer_kwargs=None, config_kwargs=None, @@ -74,6 +98,7 @@ class TestSentenceTransformersSimilarityRanker: "embedding_separator": "\n", "scale_score": True, "score_threshold": None, + "trust_remote_code": False, "model_kwargs": None, "tokenizer_kwargs": None, "config_kwargs": None, @@ -92,6 +117,7 @@ class TestSentenceTransformersSimilarityRanker: document_prefix="document_instruction: ", scale_score=False, score_threshold=0.01, + trust_remote_code=True, model_kwargs={"torch_dtype": torch.float16}, tokenizer_kwargs={"model_max_length": 512}, batch_size=32, @@ -110,6 +136,7 @@ class TestSentenceTransformersSimilarityRanker: "embedding_separator": "\n", "scale_score": False, "score_threshold": 0.01, + "trust_remote_code": True, "model_kwargs": {"torch_dtype": "torch.float16"}, "tokenizer_kwargs": {"model_max_length": 512}, "config_kwargs": None, @@ -141,6 +168,7 @@ class TestSentenceTransformersSimilarityRanker: "embedding_separator": "\n", "scale_score": True, "score_threshold": None, + "trust_remote_code": False, "model_kwargs": { "load_in_4bit": True, "bnb_4bit_use_double_quant": True, @@ -168,6 +196,7 @@ class TestSentenceTransformersSimilarityRanker: "embedding_separator": "\n", "scale_score": False, "score_threshold": 0.01, + "trust_remote_code": False, "model_kwargs": {"torch_dtype": "torch.float16"}, "tokenizer_kwargs": None, "config_kwargs": None, @@ -187,6 +216,7 @@ class TestSentenceTransformersSimilarityRanker: assert component.embedding_separator == "\n" assert not component.scale_score assert component.score_threshold == 0.01 + assert component.trust_remote_code is False assert component.model_kwargs == {"torch_dtype": torch.float16} assert component.tokenizer_kwargs is None assert 
component.config_kwargs is None @@ -209,6 +239,7 @@ class TestSentenceTransformersSimilarityRanker: assert component.embedding_separator == "\n" assert component.scale_score assert component.score_threshold is None + assert component.trust_remote_code is False assert component.model_kwargs is None assert component.tokenizer_kwargs is None assert component.config_kwargs is None