mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 10:19:23 +00:00
chore: migrate ExtractiveReader to use secret management (#7309)
* chore: migrate `ExtractiveReader` to use secret management * docs: add release notes
This commit is contained in:
parent
50ad1fa2c4
commit
23c65c250f
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from haystack import ComponentError, Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import ComponentDevice, DeviceMap
|
||||
from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs
|
||||
|
||||
with LazyImport("Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
|
||||
@ -51,7 +51,7 @@ class ExtractiveReader:
|
||||
self,
|
||||
model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
|
||||
device: Optional[ComponentDevice] = None,
|
||||
token: Union[bool, str, None] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
top_k: int = 20,
|
||||
score_threshold: Optional[float] = None,
|
||||
max_seq_length: int = 384,
|
||||
@ -140,7 +140,7 @@ class ExtractiveReader:
|
||||
self,
|
||||
model=self.model_name_or_path,
|
||||
device=None,
|
||||
token=self.token if not isinstance(self.token, str) else None,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
max_seq_length=self.max_seq_length,
|
||||
top_k=self.top_k,
|
||||
score_threshold=self.score_threshold,
|
||||
@ -166,6 +166,7 @@ class ExtractiveReader:
|
||||
Deserialized component.
|
||||
"""
|
||||
init_params = data["init_parameters"]
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
if init_params["device"] is not None:
|
||||
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||
deserialize_hf_model_kwargs(init_params["model_kwargs"])
|
||||
@ -179,9 +180,11 @@ class ExtractiveReader:
|
||||
# Take the first device used by `accelerate`. Needed to pass inputs from the tokenizer to the correct device.
|
||||
if self.model is None:
|
||||
self.model = AutoModelForQuestionAnswering.from_pretrained(
|
||||
self.model_name_or_path, token=self.token, **self.model_kwargs
|
||||
self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name_or_path, token=self.token.resolve_value() if self.token else None
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)
|
||||
self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))
|
||||
|
||||
def _flatten_documents(
|
||||
|
||||
@ -0,0 +1,7 @@
|
||||
---
|
||||
upgrade:
|
||||
- |
|
||||
Update secret handling for the `ExtractiveReader` component using the `Secret` type.
|
||||
|
||||
The default init parameter `token` is now required to either use a token or the environment `HF_API_TOKEN` variable
|
||||
if authentication is required - The on-disk local token file is no longer supported.
|
||||
@ -1,17 +1,26 @@
|
||||
import logging
|
||||
from math import ceil, exp
|
||||
from typing import List
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
import torch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from transformers import pipeline
|
||||
|
||||
from haystack import Document, ExtractedAnswer
|
||||
from haystack.components.readers import ExtractiveReader
|
||||
from haystack.utils import Secret
|
||||
from haystack.utils.device import ComponentDevice, DeviceMap
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def initialized_token(monkeypatch: MonkeyPatch) -> Secret:
|
||||
monkeypatch.setenv("HF_API_TOKEN", "secret-token")
|
||||
|
||||
return Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tokenizer():
|
||||
def mock_tokenize(
|
||||
@ -87,8 +96,8 @@ example_documents = [
|
||||
] * 2
|
||||
|
||||
|
||||
def test_to_dict():
|
||||
component = ExtractiveReader("my-model", token="secret-token", model_kwargs={"torch_dtype": torch.float16})
|
||||
def test_to_dict(initialized_token: Secret):
|
||||
component = ExtractiveReader("my-model", token=initialized_token, model_kwargs={"torch_dtype": torch.float16})
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
@ -96,7 +105,7 @@ def test_to_dict():
|
||||
"init_parameters": {
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": None, # don't serialize valid tokens
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"top_k": 20,
|
||||
"score_threshold": None,
|
||||
"max_seq_length": 384,
|
||||
@ -113,8 +122,8 @@ def test_to_dict():
|
||||
}
|
||||
|
||||
|
||||
def test_to_dict_empty_model_kwargs():
|
||||
component = ExtractiveReader("my-model", token="secret-token")
|
||||
def test_to_dict_no_token():
|
||||
component = ExtractiveReader("my-model", token=None, model_kwargs={"torch_dtype": torch.float16})
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
@ -122,7 +131,33 @@ def test_to_dict_empty_model_kwargs():
|
||||
"init_parameters": {
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": None, # don't serialize valid tokens
|
||||
"token": None,
|
||||
"top_k": 20,
|
||||
"score_threshold": None,
|
||||
"max_seq_length": 384,
|
||||
"stride": 128,
|
||||
"max_batch_size": None,
|
||||
"answers_per_seq": None,
|
||||
"no_answer": True,
|
||||
"calibration_factor": 0.1,
|
||||
"model_kwargs": {
|
||||
"torch_dtype": "torch.float16",
|
||||
"device_map": ComponentDevice.resolve_device(None).to_hf(),
|
||||
}, # torch_dtype is correctly serialized
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_to_dict_empty_model_kwargs(initialized_token: Secret):
|
||||
component = ExtractiveReader("my-model", token=initialized_token)
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
||||
"init_parameters": {
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"top_k": 20,
|
||||
"score_threshold": None,
|
||||
"max_seq_length": 384,
|
||||
@ -153,7 +188,7 @@ def test_to_dict_device_map(device_map, expected):
|
||||
"init_parameters": {
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": None,
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"top_k": 20,
|
||||
"score_threshold": None,
|
||||
"max_seq_length": 384,
|
||||
@ -173,7 +208,7 @@ def test_from_dict():
|
||||
"init_parameters": {
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": None,
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"top_k": 20,
|
||||
"score_threshold": None,
|
||||
"max_seq_length": 384,
|
||||
@ -189,7 +224,7 @@ def test_from_dict():
|
||||
component = ExtractiveReader.from_dict(data)
|
||||
assert component.model_name_or_path == "my-model"
|
||||
assert component.device is None
|
||||
assert component.token is None
|
||||
assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert component.top_k == 20
|
||||
assert component.score_threshold is None
|
||||
assert component.max_seq_length == 384
|
||||
@ -205,6 +240,29 @@ def test_from_dict():
|
||||
}
|
||||
|
||||
|
||||
def test_from_dict_no_token():
|
||||
data = {
|
||||
"type": "haystack.components.readers.extractive.ExtractiveReader",
|
||||
"init_parameters": {
|
||||
"model": "my-model",
|
||||
"device": None,
|
||||
"token": None,
|
||||
"top_k": 20,
|
||||
"score_threshold": None,
|
||||
"max_seq_length": 384,
|
||||
"stride": 128,
|
||||
"max_batch_size": None,
|
||||
"answers_per_seq": None,
|
||||
"no_answer": True,
|
||||
"calibration_factor": 0.1,
|
||||
"model_kwargs": {"torch_dtype": "torch.float16"},
|
||||
},
|
||||
}
|
||||
|
||||
component = ExtractiveReader.from_dict(data)
|
||||
assert component.token is None
|
||||
|
||||
|
||||
def test_output(mock_reader: ExtractiveReader):
|
||||
answers = mock_reader.run(example_queries[0], example_documents[0], top_k=3)[
|
||||
"answers"
|
||||
@ -336,8 +394,8 @@ def test_nest_answers(mock_reader: ExtractiveReader):
|
||||
|
||||
@patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained")
|
||||
@patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained")
|
||||
def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
|
||||
reader = ExtractiveReader("deepset/roberta-base-squad2", token="fake-token", device=ComponentDevice.from_str("cpu"))
|
||||
def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer, initialized_token: Secret):
|
||||
reader = ExtractiveReader("deepset/roberta-base-squad2", device=ComponentDevice.from_str("cpu"))
|
||||
|
||||
class MockedModel:
|
||||
def __init__(self):
|
||||
@ -346,8 +404,8 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
|
||||
mocked_automodel.return_value = MockedModel()
|
||||
reader.warm_up()
|
||||
|
||||
mocked_automodel.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token", device_map="cpu")
|
||||
mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")
|
||||
mocked_automodel.assert_called_once_with("deepset/roberta-base-squad2", token="secret-token", device_map="cpu")
|
||||
mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="secret-token")
|
||||
|
||||
|
||||
@patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user