diff --git a/e2e/pipelines/test_eval_extractive_qa_pipeline.py b/e2e/pipelines/test_eval_extractive_qa_pipeline.py index 5b5df26b5..d5f8fcf3d 100644 --- a/e2e/pipelines/test_eval_extractive_qa_pipeline.py +++ b/e2e/pipelines/test_eval_extractive_qa_pipeline.py @@ -13,7 +13,7 @@ def test_extractive_qa_pipeline(tmp_path): # Create the pipeline qa_pipeline = Pipeline() qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") - qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader") + qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader") qa_pipeline.connect("retriever", "reader") # Populate the document store diff --git a/e2e/pipelines/test_extractive_qa_pipeline.py b/e2e/pipelines/test_extractive_qa_pipeline.py index 71b540d0f..46b3b6cc8 100644 --- a/e2e/pipelines/test_extractive_qa_pipeline.py +++ b/e2e/pipelines/test_extractive_qa_pipeline.py @@ -10,7 +10,7 @@ def test_extractive_qa_pipeline(tmp_path): # Create the pipeline qa_pipeline = Pipeline() qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") - qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader") + qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader") qa_pipeline.connect("retriever", "reader") # Draw the pipeline diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index a26f66e9c..dacb4fa88 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -38,7 +38,7 @@ class ExtractiveReader: def __init__( self, - model_name_or_path: Union[Path, str] = "deepset/roberta-base-squad2-distilled", + model: Union[Path, str] = "deepset/roberta-base-squad2-distilled", device: Optional[str] = None, token: Union[bool, str, None] = None, top_k: int = 20, @@ -54,7 +54,7 @@ class ExtractiveReader: ) -> None: """ Creates an ExtractiveReader - :param model_name_or_path: A Hugging Face transformers question answering model. + :param model: A Hugging Face transformers question answering model. Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub. Default: `'deepset/roberta-base-squad2-distilled'` :param device: Pytorch device string. Uses GPU by default, if available. @@ -83,11 +83,11 @@ class ExtractiveReader: both of these answers could be kept if this variable is set to 0.24 or lower. If None is provided then all answers are kept. :param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained` - when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass, + when loading the model specified in `model`. For details on what kwargs you can pass, see the model's documentation. """ torch_and_transformers_import.check() - self.model_name_or_path = str(model_name_or_path) + self.model_name_or_path = str(model) self.model = None self.device = device self.token = token @@ -114,7 +114,7 @@ class ExtractiveReader: """ serialization_dict = default_to_dict( self, - model_name_or_path=self.model_name_or_path, + model=self.model_name_or_path, device=self.device, token=self.token if not isinstance(self.token, str) else None, max_seq_length=self.max_seq_length, diff --git a/releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml b/releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml new file mode 100644 index 000000000..6a5bd3eef --- /dev/null +++ b/releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml @@ -0,0 +1,3 @@ +--- +upgrade: + - Rename parameter `model_name_or_path` to `model` in `ExtractiveReader`. diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index 5d193db36..d520d34ce 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -72,7 +72,7 @@ def mock_reader(mock_tokenizer): with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model: model.return_value = MockModel() - reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0") + reader = ExtractiveReader(model="mock-model", device="cpu:0") reader.warm_up() return reader @@ -94,7 +94,7 @@ def test_to_dict(): assert data == { "type": "haystack.components.readers.extractive.ExtractiveReader", "init_parameters": { - "model_name_or_path": "my-model", + "model": "my-model", "device": None, "token": None, # don't serialize valid tokens "top_k": 20, @@ -117,7 +117,7 @@ def test_to_dict_empty_model_kwargs(): assert data == { "type": "haystack.components.readers.extractive.ExtractiveReader", "init_parameters": { - "model_name_or_path": "my-model", + "model": "my-model", "device": None, "token": None, # don't serialize valid tokens "top_k": 20, @@ -137,7 +137,7 @@ def test_from_dict(): data = { "type": "haystack.components.readers.extractive.ExtractiveReader", "init_parameters": { - "model_name_or_path": "my-model", + "model": "my-model", "device": None, "token": None, "top_k": 20,