mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-25 08:58:40 +00:00
feat!: Rename model_name_or_path
to model
in ExtractiveReader
(#6736)
* rename model parameter and internam model attribute in ExtractiveReader * fix tests for ExtractiveReader * fix e2e * reno * another fix * review feedback * Update releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml
This commit is contained in:
parent
b236ea49e3
commit
96c0b59aaa
@ -13,7 +13,7 @@ def test_extractive_qa_pipeline(tmp_path):
|
||||
# Create the pipeline
|
||||
qa_pipeline = Pipeline()
|
||||
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
||||
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
|
||||
qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
|
||||
qa_pipeline.connect("retriever", "reader")
|
||||
|
||||
# Populate the document store
|
||||
|
@ -10,7 +10,7 @@ def test_extractive_qa_pipeline(tmp_path):
|
||||
# Create the pipeline
|
||||
qa_pipeline = Pipeline()
|
||||
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
||||
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
|
||||
qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
|
||||
qa_pipeline.connect("retriever", "reader")
|
||||
|
||||
# Draw the pipeline
|
||||
|
@ -38,7 +38,7 @@ class ExtractiveReader:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
|
||||
model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
|
||||
device: Optional[str] = None,
|
||||
token: Union[bool, str, None] = None,
|
||||
top_k: int = 20,
|
||||
@ -54,7 +54,7 @@ class ExtractiveReader:
|
||||
) -> None:
|
||||
"""
|
||||
Creates an ExtractiveReader
|
||||
:param model_name_or_path: A Hugging Face transformers question answering model.
|
||||
:param model: A Hugging Face transformers question answering model.
|
||||
Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub.
|
||||
Default: `'deepset/roberta-base-squad2-distilled'`
|
||||
:param device: Pytorch device string. Uses GPU by default, if available.
|
||||
@ -83,11 +83,11 @@ class ExtractiveReader:
|
||||
both of these answers could be kept if this variable is set to 0.24 or lower.
|
||||
If None is provided then all answers are kept.
|
||||
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
|
||||
when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass,
|
||||
when loading the model specified in `model`. For details on what kwargs you can pass,
|
||||
see the model's documentation.
|
||||
"""
|
||||
torch_and_transformers_import.check()
|
||||
self.model_name_or_path = str(model_name_or_path)
|
||||
self.model_name_or_path = str(model)
|
||||
self.model = None
|
||||
self.device = device
|
||||
self.token = token
|
||||
@ -114,7 +114,7 @@ class ExtractiveReader:
|
||||
"""
|
||||
serialization_dict = default_to_dict(
|
||||
self,
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
model=self.model_name_or_path,
|
||||
device=self.device,
|
||||
token=self.token if not isinstance(self.token, str) else None,
|
||||
max_seq_length=self.max_seq_length,
|
||||
|
@ -0,0 +1,3 @@
|
||||
---
|
||||
upgrade:
|
||||
- Rename parameter `model_name_or_path` to `model` in `ExtractiveReader`.
|
@ -72,7 +72,7 @@ def mock_reader(mock_tokenizer):
|
||||
|
||||
with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model:
|
||||
model.return_value = MockModel()
|
||||
reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
|
||||
reader = ExtractiveReader(model="mock-model", device="cpu:0")
|
||||
reader.warm_up()
|
||||
return reader
|
||||
|
||||
@ -94,7 +94,7 @@ def test_to_dict():
|
||||
assert data == {
|
||||
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
||||
"init_parameters": {
|
||||
"model_name_or_path": "my-model",
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": None, # don't serialize valid tokens
|
||||
"top_k": 20,
|
||||
@ -117,7 +117,7 @@ def test_to_dict_empty_model_kwargs():
|
||||
assert data == {
|
||||
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
||||
"init_parameters": {
|
||||
"model_name_or_path": "my-model",
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": None, # don't serialize valid tokens
|
||||
"top_k": 20,
|
||||
@ -137,7 +137,7 @@ def test_from_dict():
|
||||
data = {
|
||||
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
||||
"init_parameters": {
|
||||
"model_name_or_path": "my-model",
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": None,
|
||||
"top_k": 20,
|
||||
|
Loading…
x
Reference in New Issue
Block a user