mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
feat: Add model_kwargs to ExtractiveReader to impact model loading (#6257)
* Add ability to pass model_kwargs to AutoModelForQuestionAnswering * Add testing for new model_kwargs * Add spacing * Add release notes * Update haystack/preview/components/readers/extractive.py Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> * Make changes suggested by Stefano --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
This commit is contained in:
parent
cd429a73cd
commit
71d0d92ea2
@ -45,10 +45,11 @@ class ExtractiveReader:
|
||||
answers_per_seq: Optional[int] = None,
|
||||
no_answer: bool = True,
|
||||
calibration_factor: float = 0.1,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Creates an ExtractiveReader
|
||||
:param model: A HuggingFace transformers question answering model.
|
||||
:param model_name_or_path: A HuggingFace transformers question answering model.
|
||||
Can either be a path to a folder containing the model files or an identifier for the HF hub
|
||||
Default: `'deepset/roberta-base-squad2-distilled'`
|
||||
:param device: PyTorch device string. Uses GPU by default if available.
|
||||
@ -68,6 +69,8 @@ class ExtractiveReader:
|
||||
This is relevant when a document has been split into multiple sequences due to max_seq_length.
|
||||
:param no_answer: Whether to return no answer scores
|
||||
:param calibration_factor: Factor used for calibrating confidence scores
|
||||
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
|
||||
when loading the model specified in `model_name_or_path`.
|
||||
"""
|
||||
torch_and_transformers_import.check()
|
||||
self.model_name_or_path = str(model_name_or_path)
|
||||
@ -82,6 +85,7 @@ class ExtractiveReader:
|
||||
self.answers_per_seq = answers_per_seq
|
||||
self.no_answer = no_answer
|
||||
self.calibration_factor = calibration_factor
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -106,6 +110,7 @@ class ExtractiveReader:
|
||||
answers_per_seq=self.answers_per_seq,
|
||||
no_answer=self.no_answer,
|
||||
calibration_factor=self.calibration_factor,
|
||||
model_kwargs=self.model_kwargs,
|
||||
)
|
||||
|
||||
def warm_up(self):
|
||||
@ -120,9 +125,10 @@ class ExtractiveReader:
|
||||
self.device = self.device or "mps:0"
|
||||
else:
|
||||
self.device = self.device or "cpu:0"
|
||||
self.model = AutoModelForQuestionAnswering.from_pretrained(self.model_name_or_path, token=self.token).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
self.model = AutoModelForQuestionAnswering.from_pretrained(
|
||||
self.model_name_or_path, token=self.token, **self.model_kwargs
|
||||
).to(self.device)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)
|
||||
|
||||
def _flatten_documents(
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Add a new `model_kwargs` parameter to the ExtractiveReader so we can pass different model loading options supported by
|
||||
HuggingFace.
|
||||
@ -89,6 +89,30 @@ example_documents = [
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict():
|
||||
component = ExtractiveReader("my-model", token="secret-token", model_kwargs={"torch_dtype": "auto"})
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
"type": "ExtractiveReader",
|
||||
"init_parameters": {
|
||||
"model_name_or_path": "my-model",
|
||||
"device": None,
|
||||
"token": None, # don't serialize valid tokens
|
||||
"top_k": 20,
|
||||
"confidence_threshold": None,
|
||||
"max_seq_length": 384,
|
||||
"stride": 128,
|
||||
"max_batch_size": None,
|
||||
"answers_per_seq": None,
|
||||
"no_answer": True,
|
||||
"calibration_factor": 0.1,
|
||||
"model_kwargs": {"torch_dtype": "auto"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_empty_model_kwargs():
|
||||
component = ExtractiveReader("my-model", token="secret-token")
|
||||
data = component.to_dict()
|
||||
|
||||
@ -106,6 +130,7 @@ def test_to_dict():
|
||||
"answers_per_seq": None,
|
||||
"no_answer": True,
|
||||
"calibration_factor": 0.1,
|
||||
"model_kwargs": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user