feat: Add model_kwargs to ExtractiveReader to customize model loading (#6257)

* Add ability to pass model_kwargs to AutoModelForQuestionAnswering

* Add tests for the new model_kwargs

* Add spacing

* Add release notes

* Update haystack/preview/components/readers/extractive.py

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>

* Make changes suggested by Stefano

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Sebastian Husch Lee, 2023-11-09 11:25:22 +01:00 (committed by GitHub)
parent cd429a73cd
commit 71d0d92ea2
3 changed files with 40 additions and 4 deletions
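
In practice, the new parameter is used like this. A minimal usage sketch, not part of the commit: the import path is assumed from the preview-era package layout, and the torch_dtype value mirrors the one used in the new unit test.

# A minimal sketch; import path and torch_dtype value are assumptions
# based on the preview API of this era and the unit test below.
from haystack.preview.components.readers import ExtractiveReader

reader = ExtractiveReader(
    "deepset/roberta-base-squad2-distilled",  # the documented default model
    model_kwargs={"torch_dtype": "auto"},     # forwarded to from_pretrained untouched
)
reader.warm_up()  # the kwargs only take effect here, when the model is loaded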

haystack/preview/components/readers/extractive.py

@@ -45,10 +45,11 @@ class ExtractiveReader:
         answers_per_seq: Optional[int] = None,
         no_answer: bool = True,
         calibration_factor: float = 0.1,
+        model_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Creates an ExtractiveReader
-        :param model: A HuggingFace transformers question answering model.
+        :param model_name_or_path: A HuggingFace transformers question answering model.
             Can either be a path to a folder containing the model files or an identifier for the HF hub
             Default: `'deepset/roberta-base-squad2-distilled'`
         :param device: Pytorch device string. Uses GPU by default if available
@@ -68,6 +69,8 @@ class ExtractiveReader:
             This is relevant when a document has been split into multiple sequences due to max_seq_length.
         :param no_answer: Whether to return no answer scores
         :param calibration_factor: Factor used for calibrating confidence scores
+        :param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
+            when loading the model specified in `model_name_or_path`.
         """
         torch_and_transformers_import.check()
         self.model_name_or_path = str(model_name_or_path)
@@ -82,6 +85,7 @@ class ExtractiveReader:
         self.answers_per_seq = answers_per_seq
         self.no_answer = no_answer
         self.calibration_factor = calibration_factor
+        self.model_kwargs = model_kwargs or {}

     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -106,6 +110,7 @@ class ExtractiveReader:
             answers_per_seq=self.answers_per_seq,
             no_answer=self.no_answer,
             calibration_factor=self.calibration_factor,
+            model_kwargs=self.model_kwargs,
         )

     def warm_up(self):
@@ -120,9 +125,10 @@ class ExtractiveReader:
             self.device = self.device or "mps:0"
         else:
             self.device = self.device or "cpu:0"
-        self.model = AutoModelForQuestionAnswering.from_pretrained(self.model_name_or_path, token=self.token).to(
-            self.device
-        )
+        self.model = AutoModelForQuestionAnswering.from_pretrained(
+            self.model_name_or_path, token=self.token, **self.model_kwargs
+        ).to(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)

     def _flatten_documents(
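
For reference, the changed warm_up() is roughly equivalent to the plain transformers call below. A sketch with example values substituted for the instance attributes, not the literal implementation:

# Sketch: what the new loading path does, with example values filled in.
from transformers import AutoModelForQuestionAnswering

model_kwargs = {"torch_dtype": "auto"}  # whatever the caller passed to __init__
model = AutoModelForQuestionAnswering.from_pretrained(
    "deepset/roberta-base-squad2-distilled",  # self.model_name_or_path
    token=None,                               # self.token
    **model_kwargs,                           # the new pass-through
).to("cpu:0")                                 # self.device as resolved by warm_up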

Release note (new file)

@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    Add a new model_kwargs parameter to the ExtractiveReader so that loading options supported by
+    Hugging Face can be passed through when the model is loaded.

ExtractiveReader unit tests

@@ -89,6 +89,30 @@ example_documents = [

 @pytest.mark.unit
 def test_to_dict():
+    component = ExtractiveReader("my-model", token="secret-token", model_kwargs={"torch_dtype": "auto"})
+    data = component.to_dict()
+
+    assert data == {
+        "type": "ExtractiveReader",
+        "init_parameters": {
+            "model_name_or_path": "my-model",
+            "device": None,
+            "token": None,  # don't serialize valid tokens
+            "top_k": 20,
+            "confidence_threshold": None,
+            "max_seq_length": 384,
+            "stride": 128,
+            "max_batch_size": None,
+            "answers_per_seq": None,
+            "no_answer": True,
+            "calibration_factor": 0.1,
+            "model_kwargs": {"torch_dtype": "auto"},
+        },
+    }
+
+
+@pytest.mark.unit
+def test_to_dict_empty_model_kwargs():
     component = ExtractiveReader("my-model", token="secret-token")
     data = component.to_dict()
@@ -106,6 +130,7 @@ def test_to_dict():
             "answers_per_seq": None,
             "no_answer": True,
             "calibration_factor": 0.1,
+            "model_kwargs": {},
         },
     }
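
Taken together, the two tests pin down the serialization contract for the new parameter. Condensed, they assert the following:

# Condensed from the tests above: model_kwargs serializes as given,
# and defaults to an empty dict when omitted.
reader = ExtractiveReader("my-model", token="secret-token", model_kwargs={"torch_dtype": "auto"})
assert reader.to_dict()["init_parameters"]["model_kwargs"] == {"torch_dtype": "auto"}

reader = ExtractiveReader("my-model", token="secret-token")
assert reader.to_dict()["init_parameters"]["model_kwargs"] == {}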