diff --git a/haystack/preview/components/readers/extractive.py b/haystack/preview/components/readers/extractive.py
index 055c0c0c6..7da03ca44 100644
--- a/haystack/preview/components/readers/extractive.py
+++ b/haystack/preview/components/readers/extractive.py
@@ -45,10 +45,11 @@ class ExtractiveReader:
         answers_per_seq: Optional[int] = None,
         no_answer: bool = True,
         calibration_factor: float = 0.1,
+        model_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Creates an ExtractiveReader
-        :param model: A HuggingFace transformers question answering model.
+        :param model_name_or_path: A HuggingFace transformers question answering model.
            Can either be a path to a folder containing the model files or an identifier for the HF hub
            Default: `'deepset/roberta-base-squad2-distilled'`
        :param device: Pytorch device string. Uses GPU by default if available
@@ -68,6 +69,8 @@ class ExtractiveReader:
            This is relevant when a document has been split into multiple sequence due to max_seq_length.
        :param no_answer: Whether to return no answer scores
        :param calibration_factor: Factor used for calibrating confidence scores
+        :param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
+            when loading the model specified in `model_name_or_path`.
         """
         torch_and_transformers_import.check()
         self.model_name_or_path = str(model_name_or_path)
@@ -82,6 +85,7 @@ class ExtractiveReader:
         self.answers_per_seq = answers_per_seq
         self.no_answer = no_answer
         self.calibration_factor = calibration_factor
+        self.model_kwargs = model_kwargs or {}
 
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -106,6 +110,7 @@ class ExtractiveReader:
             answers_per_seq=self.answers_per_seq,
             no_answer=self.no_answer,
             calibration_factor=self.calibration_factor,
+            model_kwargs=self.model_kwargs,
         )
 
     def warm_up(self):
@@ -120,9 +125,10 @@ class ExtractiveReader:
                 self.device = self.device or "mps:0"
             else:
                 self.device = self.device or "cpu:0"
-            self.model = AutoModelForQuestionAnswering.from_pretrained(self.model_name_or_path, token=self.token).to(
-                self.device
-            )
+
+            self.model = AutoModelForQuestionAnswering.from_pretrained(
+                self.model_name_or_path, token=self.token, **self.model_kwargs
+            ).to(self.device)
             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)
 
     def _flatten_documents(
diff --git a/releasenotes/notes/add-model-kwargs-extractive-reader-c0b65ab34572408f.yaml b/releasenotes/notes/add-model-kwargs-extractive-reader-c0b65ab34572408f.yaml
new file mode 100644
index 000000000..b647329e7
--- /dev/null
+++ b/releasenotes/notes/add-model-kwargs-extractive-reader-c0b65ab34572408f.yaml
@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    Add a new `model_kwargs` init parameter to `ExtractiveReader` so that additional loading options supported by
+    Hugging Face can be passed when loading the model.
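For reference, below is a minimal usage sketch of the new parameter (not part of the diff itself). It assumes `ExtractiveReader` is importable from `haystack.preview.components.readers` and that any keyword arguments passed, such as `torch_dtype`, are options accepted by `AutoModelForQuestionAnswering.from_pretrained`:

from haystack.preview.components.readers import ExtractiveReader

# model_kwargs are stored on the component and forwarded verbatim to
# AutoModelForQuestionAnswering.from_pretrained when warm_up() loads the model.
reader = ExtractiveReader(
    model_name_or_path="deepset/roberta-base-squad2-distilled",
    model_kwargs={"torch_dtype": "auto"},
)
reader.warm_up()  # the model is loaded here, with model_kwargs applied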
diff --git a/test/preview/components/readers/test_extractive.py b/test/preview/components/readers/test_extractive.py
index 8d2e9cace..060a0e493 100644
--- a/test/preview/components/readers/test_extractive.py
+++ b/test/preview/components/readers/test_extractive.py
@@ -89,6 +89,30 @@ example_documents = [
 
 @pytest.mark.unit
 def test_to_dict():
+    component = ExtractiveReader("my-model", token="secret-token", model_kwargs={"torch_dtype": "auto"})
+    data = component.to_dict()
+
+    assert data == {
+        "type": "ExtractiveReader",
+        "init_parameters": {
+            "model_name_or_path": "my-model",
+            "device": None,
+            "token": None,  # don't serialize valid tokens
+            "top_k": 20,
+            "confidence_threshold": None,
+            "max_seq_length": 384,
+            "stride": 128,
+            "max_batch_size": None,
+            "answers_per_seq": None,
+            "no_answer": True,
+            "calibration_factor": 0.1,
+            "model_kwargs": {"torch_dtype": "auto"},
+        },
+    }
+
+
+@pytest.mark.unit
+def test_to_dict_empty_model_kwargs():
     component = ExtractiveReader("my-model", token="secret-token")
     data = component.to_dict()
 
@@ -106,6 +130,7 @@ def test_to_dict():
             "answers_per_seq": None,
             "no_answer": True,
             "calibration_factor": 0.1,
+            "model_kwargs": {},
         },
     }
 