Mirror of https://github.com/deepset-ai/haystack.git
chore: migrate ExtractiveReader to use secret management (#7309)
* chore: migrate `ExtractiveReader` to use secret management
* docs: add release notes

This commit is contained in:
parent 50ad1fa2c4
commit 23c65c250f
haystack/components/readers/extractive.py

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from haystack import ComponentError, Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
-from haystack.utils import ComponentDevice, DeviceMap
+from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
 from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs


 with LazyImport("Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
@@ -51,7 +51,7 @@ class ExtractiveReader:
         self,
         model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
         device: Optional[ComponentDevice] = None,
-        token: Union[bool, str, None] = None,
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         top_k: int = 20,
         score_threshold: Optional[float] = None,
         max_seq_length: int = 384,
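With this change `token` is a `Secret` rather than a raw string or boolean, and the default reads `HF_API_TOKEN` lazily. A minimal usage sketch (the `hf_...` value is a placeholder, not a real token):

    from haystack.components.readers import ExtractiveReader
    from haystack.utils import Secret

    # Default: resolve the token from HF_API_TOKEN when needed; strict=False
    # means an unset variable resolves to None instead of raising.
    reader = ExtractiveReader("deepset/roberta-base-squad2-distilled")

    # Explicit token, e.g. injected from a secret store.
    reader = ExtractiveReader(
        "deepset/roberta-base-squad2-distilled",
        token=Secret.from_token("hf_..."),  # placeholder value
    )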
@@ -140,7 +140,7 @@ class ExtractiveReader:
             self,
             model=self.model_name_or_path,
             device=None,
-            token=self.token if not isinstance(self.token, str) else None,
+            token=self.token.to_dict() if self.token else None,
             max_seq_length=self.max_seq_length,
             top_k=self.top_k,
             score_threshold=self.score_threshold,
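`to_dict` now delegates to `Secret.to_dict()`, replacing the old special-casing of string tokens. An env-var secret serializes to a descriptor of where the token lives, never the value itself; the shape below matches the expected dicts in the tests further down:

    from haystack.utils import Secret

    secret = Secret.from_env_var("HF_API_TOKEN", strict=False)
    print(secret.to_dict())
    # {'type': 'env_var', 'env_vars': ['HF_API_TOKEN'], 'strict': False}

Token-based secrets (`Secret.from_token`) refuse to serialize, so a real token cannot leak into a saved pipeline.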
@@ -166,6 +166,7 @@ class ExtractiveReader:
             Deserialized component.
         """
         init_params = data["init_parameters"]
+        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         if init_params["device"] is not None:
             init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         deserialize_hf_model_kwargs(init_params["model_kwargs"])
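On the way back in, `deserialize_secrets_inplace` rebuilds `Secret` objects from their dict form before the rest of deserialization runs. A sketch of its effect on the parameter dict:

    from haystack.utils import Secret, deserialize_secrets_inplace

    params = {"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False}}
    deserialize_secrets_inplace(params, keys=["token"])
    assert params["token"] == Secret.from_env_var("HF_API_TOKEN", strict=False)
    # A None value is left untouched, which keeps token=None working.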
@@ -179,9 +180,11 @@ class ExtractiveReader:
         # Take the first device used by `accelerate`. Needed to pass inputs from the tokenizer to the correct device.
         if self.model is None:
             self.model = AutoModelForQuestionAnswering.from_pretrained(
-                self.model_name_or_path, token=self.token, **self.model_kwargs
+                self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
+            )
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name_or_path, token=self.token.resolve_value() if self.token else None
             )
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)
             self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))

     def _flatten_documents(
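The secret is only resolved at warm-up time: `resolve_value()` returns the underlying string, and just that plain value is handed to `transformers`. With `strict=False`, an unset variable resolves to `None`, so anonymous access keeps working; a sketch:

    import os
    from haystack.utils import Secret

    token = Secret.from_env_var("HF_API_TOKEN", strict=False)

    os.environ["HF_API_TOKEN"] = "hf_example"  # placeholder value
    assert token.resolve_value() == "hf_example"

    os.environ.pop("HF_API_TOKEN")
    assert token.resolve_value() is None  # strict=False: unset -> None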
Release note (new file under releasenotes/notes/):

@@ -0,0 +1,7 @@
+---
+upgrade:
+  - |
+    Update secret handling for the `ExtractiveReader` component using the `Secret` type.
+
+    The default init parameter `token` now requires either an explicit token or the `HF_API_TOKEN` environment
+    variable if authentication is required. The on-disk local token file is no longer supported.
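For downstream code the migration is mechanical; a hedged before/after sketch (placeholder token value):

    from haystack.components.readers import ExtractiveReader
    from haystack.utils import Secret

    # Before (#7309): raw string, boolean, or the on-disk token file
    # reader = ExtractiveReader("my-model", token="hf_...")

    # After: wrap the value in a Secret, or rely on HF_API_TOKEN being set
    reader = ExtractiveReader("my-model", token=Secret.from_token("hf_..."))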
test/components/readers/test_extractive.py

@@ -1,17 +1,26 @@
+import logging
 from math import ceil, exp
 from typing import List
 from unittest.mock import Mock, patch

 import pytest
-import logging
 import torch
+from _pytest.monkeypatch import MonkeyPatch
 from transformers import pipeline

 from haystack import Document, ExtractedAnswer
 from haystack.components.readers import ExtractiveReader
+from haystack.utils import Secret
 from haystack.utils.device import ComponentDevice, DeviceMap


+@pytest.fixture()
+def initialized_token(monkeypatch: MonkeyPatch) -> Secret:
+    monkeypatch.setenv("HF_API_TOKEN", "secret-token")
+
+    return Secret.from_env_var("HF_API_TOKEN", strict=False)
+
+
 @pytest.fixture
 def mock_tokenizer():
     def mock_tokenize(
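The new `initialized_token` fixture pins `HF_API_TOKEN` through `monkeypatch`, so every test that requests it gets a `Secret` that resolves deterministically, and the environment is restored afterwards. Assuming pytest collects it from this module:

    def test_token_resolves(initialized_token: Secret) -> None:
        # the fixture has already set HF_API_TOKEN=secret-token
        assert initialized_token.resolve_value() == "secret-token"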
@@ -87,8 +96,8 @@ example_documents = [
 ] * 2


-def test_to_dict():
-    component = ExtractiveReader("my-model", token="secret-token", model_kwargs={"torch_dtype": torch.float16})
+def test_to_dict(initialized_token: Secret):
+    component = ExtractiveReader("my-model", token=initialized_token, model_kwargs={"torch_dtype": torch.float16})
     data = component.to_dict()

     assert data == {
@@ -96,7 +105,7 @@ def test_to_dict():
         "init_parameters": {
             "model": "my-model",
             "device": None,
-            "token": None,  # don't serialize valid tokens
+            "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
             "top_k": 20,
             "score_threshold": None,
             "max_seq_length": 384,
@@ -113,8 +122,8 @@ def test_to_dict():
     }


-def test_to_dict_empty_model_kwargs():
-    component = ExtractiveReader("my-model", token="secret-token")
+def test_to_dict_no_token():
+    component = ExtractiveReader("my-model", token=None, model_kwargs={"torch_dtype": torch.float16})
     data = component.to_dict()

     assert data == {
@@ -122,7 +131,33 @@ def test_to_dict_empty_model_kwargs():
         "init_parameters": {
             "model": "my-model",
             "device": None,
-            "token": None,  # don't serialize valid tokens
+            "token": None,
+            "top_k": 20,
+            "score_threshold": None,
+            "max_seq_length": 384,
+            "stride": 128,
+            "max_batch_size": None,
+            "answers_per_seq": None,
+            "no_answer": True,
+            "calibration_factor": 0.1,
+            "model_kwargs": {
+                "torch_dtype": "torch.float16",
+                "device_map": ComponentDevice.resolve_device(None).to_hf(),
+            },  # torch_dtype is correctly serialized
+        },
+    }
+
+
+def test_to_dict_empty_model_kwargs(initialized_token: Secret):
+    component = ExtractiveReader("my-model", token=initialized_token)
+    data = component.to_dict()
+
+    assert data == {
+        "type": "haystack.components.readers.extractive.ExtractiveReader",
+        "init_parameters": {
+            "model": "my-model",
+            "device": None,
+            "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
             "top_k": 20,
             "score_threshold": None,
             "max_seq_length": 384,
@@ -153,7 +188,7 @@ def test_to_dict_device_map(device_map, expected):
         "init_parameters": {
             "model": "my-model",
             "device": None,
-            "token": None,
+            "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
             "top_k": 20,
             "score_threshold": None,
             "max_seq_length": 384,
@@ -173,7 +208,7 @@ def test_from_dict():
         "init_parameters": {
             "model": "my-model",
             "device": None,
-            "token": None,
+            "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
             "top_k": 20,
             "score_threshold": None,
             "max_seq_length": 384,
@@ -189,7 +224,7 @@ def test_from_dict():
     component = ExtractiveReader.from_dict(data)
     assert component.model_name_or_path == "my-model"
     assert component.device is None
-    assert component.token is None
+    assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
     assert component.top_k == 20
     assert component.score_threshold is None
     assert component.max_seq_length == 384
@@ -205,6 +240,29 @@ def test_from_dict():
     }


+def test_from_dict_no_token():
+    data = {
+        "type": "haystack.components.readers.extractive.ExtractiveReader",
+        "init_parameters": {
+            "model": "my-model",
+            "device": None,
+            "token": None,
+            "top_k": 20,
+            "score_threshold": None,
+            "max_seq_length": 384,
+            "stride": 128,
+            "max_batch_size": None,
+            "answers_per_seq": None,
+            "no_answer": True,
+            "calibration_factor": 0.1,
+            "model_kwargs": {"torch_dtype": "torch.float16"},
+        },
+    }
+
+    component = ExtractiveReader.from_dict(data)
+    assert component.token is None
+
+
 def test_output(mock_reader: ExtractiveReader):
     answers = mock_reader.run(example_queries[0], example_documents[0], top_k=3)[
         "answers"
@@ -336,8 +394,8 @@ def test_nest_answers(mock_reader: ExtractiveReader):

 @patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained")
 @patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained")
-def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
-    reader = ExtractiveReader("deepset/roberta-base-squad2", token="fake-token", device=ComponentDevice.from_str("cpu"))
+def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer, initialized_token: Secret):
+    reader = ExtractiveReader("deepset/roberta-base-squad2", device=ComponentDevice.from_str("cpu"))

     class MockedModel:
         def __init__(self):
@@ -346,8 +404,8 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
     mocked_automodel.return_value = MockedModel()
     reader.warm_up()

-    mocked_automodel.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token", device_map="cpu")
-    mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")
+    mocked_automodel.assert_called_once_with("deepset/roberta-base-squad2", token="secret-token", device_map="cpu")
+    mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="secret-token")


 @patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained")