feat!: Rename model_name_or_path to model in ExtractiveReader (#6736)

* rename model parameter and internal model attribute in ExtractiveReader

* fix tests for ExtractiveReader

* fix e2e

* reno

* another fix

* review feedback

* Update releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml
This commit is contained in:
ZanSara 2024-01-15 14:48:33 +01:00 committed by GitHub
parent b236ea49e3
commit 96c0b59aaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 14 additions and 11 deletions

View File

@@ -13,7 +13,7 @@ def test_extractive_qa_pipeline(tmp_path):
# Create the pipeline # Create the pipeline
qa_pipeline = Pipeline() qa_pipeline = Pipeline()
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader") qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.connect("retriever", "reader") qa_pipeline.connect("retriever", "reader")
# Populate the document store # Populate the document store

View File

@@ -10,7 +10,7 @@ def test_extractive_qa_pipeline(tmp_path):
# Create the pipeline # Create the pipeline
qa_pipeline = Pipeline() qa_pipeline = Pipeline()
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader") qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.connect("retriever", "reader") qa_pipeline.connect("retriever", "reader")
# Draw the pipeline # Draw the pipeline

View File

@@ -38,7 +38,7 @@ class ExtractiveReader:
def __init__( def __init__(
self, self,
model_name_or_path: Union[Path, str] = "deepset/roberta-base-squad2-distilled", model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
device: Optional[str] = None, device: Optional[str] = None,
token: Union[bool, str, None] = None, token: Union[bool, str, None] = None,
top_k: int = 20, top_k: int = 20,
@@ -54,7 +54,7 @@ class ExtractiveReader:
) -> None: ) -> None:
""" """
Creates an ExtractiveReader Creates an ExtractiveReader
:param model_name_or_path: A Hugging Face transformers question answering model. :param model: A Hugging Face transformers question answering model.
Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub. Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub.
Default: `'deepset/roberta-base-squad2-distilled'` Default: `'deepset/roberta-base-squad2-distilled'`
:param device: Pytorch device string. Uses GPU by default, if available. :param device: Pytorch device string. Uses GPU by default, if available.
@@ -83,11 +83,11 @@ class ExtractiveReader:
both of these answers could be kept if this variable is set to 0.24 or lower. both of these answers could be kept if this variable is set to 0.24 or lower.
If None is provided then all answers are kept. If None is provided then all answers are kept.
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained` :param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass, when loading the model specified in `model`. For details on what kwargs you can pass,
see the model's documentation. see the model's documentation.
""" """
torch_and_transformers_import.check() torch_and_transformers_import.check()
self.model_name_or_path = str(model_name_or_path) self.model_name_or_path = str(model)
self.model = None self.model = None
self.device = device self.device = device
self.token = token self.token = token
@@ -114,7 +114,7 @@ class ExtractiveReader:
""" """
serialization_dict = default_to_dict( serialization_dict = default_to_dict(
self, self,
model_name_or_path=self.model_name_or_path, model=self.model_name_or_path,
device=self.device, device=self.device,
token=self.token if not isinstance(self.token, str) else None, token=self.token if not isinstance(self.token, str) else None,
max_seq_length=self.max_seq_length, max_seq_length=self.max_seq_length,

View File

@@ -0,0 +1,3 @@
---
upgrade:
- Rename parameter `model_name_or_path` to `model` in `ExtractiveReader`.

View File

@@ -72,7 +72,7 @@ def mock_reader(mock_tokenizer):
with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model: with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model:
model.return_value = MockModel() model.return_value = MockModel()
reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0") reader = ExtractiveReader(model="mock-model", device="cpu:0")
reader.warm_up() reader.warm_up()
return reader return reader
@@ -94,7 +94,7 @@ def test_to_dict():
assert data == { assert data == {
"type": "haystack.components.readers.extractive.ExtractiveReader", "type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": { "init_parameters": {
"model_name_or_path": "my-model", "model": "my-model",
"device": None, "device": None,
"token": None, # don't serialize valid tokens "token": None, # don't serialize valid tokens
"top_k": 20, "top_k": 20,
@@ -117,7 +117,7 @@ def test_to_dict_empty_model_kwargs():
assert data == { assert data == {
"type": "haystack.components.readers.extractive.ExtractiveReader", "type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": { "init_parameters": {
"model_name_or_path": "my-model", "model": "my-model",
"device": None, "device": None,
"token": None, # don't serialize valid tokens "token": None, # don't serialize valid tokens
"top_k": 20, "top_k": 20,
@@ -137,7 +137,7 @@ def test_from_dict():
data = { data = {
"type": "haystack.components.readers.extractive.ExtractiveReader", "type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": { "init_parameters": {
"model_name_or_path": "my-model", "model": "my-model",
"device": None, "device": None,
"token": None, "token": None,
"top_k": 20, "top_k": 20,