feat!: Rename model_name_or_path to model in ExtractiveReader (#6736)

* rename model parameter and internal model attribute in ExtractiveReader

* fix tests for ExtractiveReader

* fix e2e

* reno

* another fix

* review feedback

* Update releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml
This commit is contained in:
ZanSara 2024-01-15 14:48:33 +01:00 committed by GitHub
parent b236ea49e3
commit 96c0b59aaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 14 additions and 11 deletions

View File

@ -13,7 +13,7 @@ def test_extractive_qa_pipeline(tmp_path):
# Create the pipeline
qa_pipeline = Pipeline()
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.connect("retriever", "reader")
# Populate the document store

View File

@ -10,7 +10,7 @@ def test_extractive_qa_pipeline(tmp_path):
# Create the pipeline
qa_pipeline = Pipeline()
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.connect("retriever", "reader")
# Draw the pipeline

View File

@ -38,7 +38,7 @@ class ExtractiveReader:
def __init__(
self,
model_name_or_path: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
device: Optional[str] = None,
token: Union[bool, str, None] = None,
top_k: int = 20,
@ -54,7 +54,7 @@ class ExtractiveReader:
) -> None:
"""
Creates an ExtractiveReader
:param model_name_or_path: A Hugging Face transformers question answering model.
:param model: A Hugging Face transformers question answering model.
Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub.
Default: `'deepset/roberta-base-squad2-distilled'`
:param device: Pytorch device string. Uses GPU by default, if available.
@ -83,11 +83,11 @@ class ExtractiveReader:
both of these answers could be kept if this variable is set to 0.24 or lower.
If None is provided then all answers are kept.
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass,
when loading the model specified in `model`. For details on what kwargs you can pass,
see the model's documentation.
"""
torch_and_transformers_import.check()
self.model_name_or_path = str(model_name_or_path)
self.model_name_or_path = str(model)
self.model = None
self.device = device
self.token = token
@ -114,7 +114,7 @@ class ExtractiveReader:
"""
serialization_dict = default_to_dict(
self,
model_name_or_path=self.model_name_or_path,
model=self.model_name_or_path,
device=self.device,
token=self.token if not isinstance(self.token, str) else None,
max_seq_length=self.max_seq_length,

View File

@ -0,0 +1,3 @@
---
upgrade:
- Rename parameter `model_name_or_path` to `model` in `ExtractiveReader`.

View File

@ -72,7 +72,7 @@ def mock_reader(mock_tokenizer):
with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model:
model.return_value = MockModel()
reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
reader = ExtractiveReader(model="mock-model", device="cpu:0")
reader.warm_up()
return reader
@ -94,7 +94,7 @@ def test_to_dict():
assert data == {
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model_name_or_path": "my-model",
"model": "my-model",
"device": None,
"token": None, # don't serialize valid tokens
"top_k": 20,
@ -117,7 +117,7 @@ def test_to_dict_empty_model_kwargs():
assert data == {
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model_name_or_path": "my-model",
"model": "my-model",
"device": None,
"token": None, # don't serialize valid tokens
"top_k": 20,
@ -137,7 +137,7 @@ def test_from_dict():
data = {
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model_name_or_path": "my-model",
"model": "my-model",
"device": None,
"token": None,
"top_k": 20,