mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-26 17:36:34 +00:00
feat!: Rename model_name_or_path
to model
in ExtractiveReader
(#6736)
* rename model parameter and internal model attribute in ExtractiveReader * fix tests for ExtractiveReader * fix e2e * reno * another fix * review feedback * Update releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml
This commit is contained in:
parent
b236ea49e3
commit
96c0b59aaa
@ -13,7 +13,7 @@ def test_extractive_qa_pipeline(tmp_path):
|
|||||||
# Create the pipeline
|
# Create the pipeline
|
||||||
qa_pipeline = Pipeline()
|
qa_pipeline = Pipeline()
|
||||||
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
||||||
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
|
qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
|
||||||
qa_pipeline.connect("retriever", "reader")
|
qa_pipeline.connect("retriever", "reader")
|
||||||
|
|
||||||
# Populate the document store
|
# Populate the document store
|
||||||
|
@ -10,7 +10,7 @@ def test_extractive_qa_pipeline(tmp_path):
|
|||||||
# Create the pipeline
|
# Create the pipeline
|
||||||
qa_pipeline = Pipeline()
|
qa_pipeline = Pipeline()
|
||||||
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
||||||
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
|
qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
|
||||||
qa_pipeline.connect("retriever", "reader")
|
qa_pipeline.connect("retriever", "reader")
|
||||||
|
|
||||||
# Draw the pipeline
|
# Draw the pipeline
|
||||||
|
@ -38,7 +38,7 @@ class ExtractiveReader:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name_or_path: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
|
model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
token: Union[bool, str, None] = None,
|
token: Union[bool, str, None] = None,
|
||||||
top_k: int = 20,
|
top_k: int = 20,
|
||||||
@ -54,7 +54,7 @@ class ExtractiveReader:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Creates an ExtractiveReader
|
Creates an ExtractiveReader
|
||||||
:param model_name_or_path: A Hugging Face transformers question answering model.
|
:param model: A Hugging Face transformers question answering model.
|
||||||
Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub.
|
Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub.
|
||||||
Default: `'deepset/roberta-base-squad2-distilled'`
|
Default: `'deepset/roberta-base-squad2-distilled'`
|
||||||
:param device: Pytorch device string. Uses GPU by default, if available.
|
:param device: Pytorch device string. Uses GPU by default, if available.
|
||||||
@ -83,11 +83,11 @@ class ExtractiveReader:
|
|||||||
both of these answers could be kept if this variable is set to 0.24 or lower.
|
both of these answers could be kept if this variable is set to 0.24 or lower.
|
||||||
If None is provided then all answers are kept.
|
If None is provided then all answers are kept.
|
||||||
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
|
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
|
||||||
when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass,
|
when loading the model specified in `model`. For details on what kwargs you can pass,
|
||||||
see the model's documentation.
|
see the model's documentation.
|
||||||
"""
|
"""
|
||||||
torch_and_transformers_import.check()
|
torch_and_transformers_import.check()
|
||||||
self.model_name_or_path = str(model_name_or_path)
|
self.model_name_or_path = str(model)
|
||||||
self.model = None
|
self.model = None
|
||||||
self.device = device
|
self.device = device
|
||||||
self.token = token
|
self.token = token
|
||||||
@ -114,7 +114,7 @@ class ExtractiveReader:
|
|||||||
"""
|
"""
|
||||||
serialization_dict = default_to_dict(
|
serialization_dict = default_to_dict(
|
||||||
self,
|
self,
|
||||||
model_name_or_path=self.model_name_or_path,
|
model=self.model_name_or_path,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
token=self.token if not isinstance(self.token, str) else None,
|
token=self.token if not isinstance(self.token, str) else None,
|
||||||
max_seq_length=self.max_seq_length,
|
max_seq_length=self.max_seq_length,
|
||||||
|
@ -0,0 +1,3 @@
|
|||||||
|
---
|
||||||
|
upgrade:
|
||||||
|
- Rename parameter `model_name_or_path` to `model` in `ExtractiveReader`.
|
@ -72,7 +72,7 @@ def mock_reader(mock_tokenizer):
|
|||||||
|
|
||||||
with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model:
|
with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model:
|
||||||
model.return_value = MockModel()
|
model.return_value = MockModel()
|
||||||
reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
|
reader = ExtractiveReader(model="mock-model", device="cpu:0")
|
||||||
reader.warm_up()
|
reader.warm_up()
|
||||||
return reader
|
return reader
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ def test_to_dict():
|
|||||||
assert data == {
|
assert data == {
|
||||||
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"model_name_or_path": "my-model",
|
"model": "my-model",
|
||||||
"device": None,
|
"device": None,
|
||||||
"token": None, # don't serialize valid tokens
|
"token": None, # don't serialize valid tokens
|
||||||
"top_k": 20,
|
"top_k": 20,
|
||||||
@ -117,7 +117,7 @@ def test_to_dict_empty_model_kwargs():
|
|||||||
assert data == {
|
assert data == {
|
||||||
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"model_name_or_path": "my-model",
|
"model": "my-model",
|
||||||
"device": None,
|
"device": None,
|
||||||
"token": None, # don't serialize valid tokens
|
"token": None, # don't serialize valid tokens
|
||||||
"top_k": 20,
|
"top_k": 20,
|
||||||
@ -137,7 +137,7 @@ def test_from_dict():
|
|||||||
data = {
|
data = {
|
||||||
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"model_name_or_path": "my-model",
|
"model": "my-model",
|
||||||
"device": None,
|
"device": None,
|
||||||
"token": None,
|
"token": None,
|
||||||
"top_k": 20,
|
"top_k": 20,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user