Add more flexible options for model downloads (Proxies, resume_download, local_files_only...) (#1256)
* allow passing more options for model/tokenizer download from remote
* temporarily change dependency to current farm master
* Add latest docstring and tutorial changes
* add kwargs
* add docstrings
* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
parent 3d58e81b5e
commit eb95f0e8aa

@@ -48,7 +48,7 @@ While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interf

#### \_\_init\_\_

```python
| __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True)
| __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True, proxies=None, local_files_only=False, force_download=False, **kwargs)
```

**Arguments**:

@@ -89,6 +89,15 @@ and that FARM includes no_answer in the sorted list of predictions.

Can be helpful to disable in production deployments to keep the logs clean.
- `duplicate_filtering`: Answers are filtered based on their position. Both start and end position of the answers are considered.
The higher the value, the further apart answers can be and still be filtered out as duplicates. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
- `use_confidence_scores`: Sets the type of score that is returned with every predicted answer.
`True` => a scaled confidence / relevance score between [0, 1].
This score can also be further calibrated on your dataset via self.eval()
(see https://haystack.deepset.ai/components/reader#confidence-scores).
`False` => an unscaled, raw score [-inf, +inf] which is the sum of the start and end logits
from the model for the predicted span.
- `proxies`: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
- `local_files_only`: Whether to force checking for local files only (and forbid downloads)
- `force_download`: Whether to force a (re-)download even if the model exists locally in the cache.
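For illustration only, a minimal sketch of how these download options might be passed when constructing a reader. The model name and proxy address are placeholders (the proxy value reuses the format from the `proxies` example above), and the import path assumes the module layout of this Haystack version; extra keyword arguments are forwarded via `**kwargs` to the underlying FARM loader.

```python
from haystack.reader.farm import FARMReader  # import path assumed for this Haystack version

# Placeholder proxy, using the same format as the `proxies` example above.
proxies = {'http': 'some.proxy:1234'}

reader = FARMReader(
    "deepset/roberta-base-squad2",  # example model, not specific to this change
    use_gpu=False,
    proxies=proxies,
    local_files_only=False,   # set to True to forbid downloads and use the local cache only
    force_download=False,     # set to True to re-download even if the model is cached
)
```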
<a name="farm.FARMReader.train"></a>
#### train
@@ -193,14 +202,14 @@ Example:

```python
|{
| 'query': 'Who is the father of Arya Stark?',
| 'answers':[
| {'answer': 'Eddard,',
| 'context': " She travels with her father, Eddard, to King's Landing when he is ",
| 'offset_answer_start': 147,
| 'offset_answer_end': 154,
| 'answers':[Answer(
| 'answer': 'Eddard,',
| 'context': "She travels with her father, Eddard, to King's Landing when he is",
| 'score': 0.9787139466668613,
| 'document_id': '1337'
| },...
| 'offsets_in_context': [Span(start=29, end=35)],
| 'offsets_in_document': [Span(start=347, end=353)],
| 'document_id': '88d1ed769d003939d3a0d28034464ab2'
| ),...
| ]
|}
```

@@ -463,7 +463,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):

@classmethod
def convert_from_transformers(cls, model_name_or_path: Union[str, Path], device: str, revision: Optional[str] = None,
task_type: Optional[str] = None, processor: Optional[Processor] = None):
task_type: Optional[str] = None, processor: Optional[Processor] = None, **kwargs):
"""
Load a (downstream) model from huggingface's transformers format. Use cases:
- continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)

@@ -489,7 +489,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):

revision=revision,
device=device,
task_type=task_type,
processor=processor)
processor=processor,
**kwargs)

@classmethod
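As a rough sketch of what the added `**kwargs` enables here (the import path and all values are assumptions for illustration, not part of this commit), download options can now be forwarded when converting a transformers model:

```python
from farm.modeling.adaptive_model import AdaptiveModel  # assumed FARM import path

# Sketch: extra keyword arguments such as proxies or local_files_only are passed
# through **kwargs down to the underlying transformers from_pretrained() downloads.
model = AdaptiveModel.convert_from_transformers(
    "deepset/roberta-base-squad2",        # example model
    device="cpu",
    task_type="question_answering",
    proxies={'http': 'some.proxy:1234'},  # placeholder proxy
    local_files_only=False,
)
```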
@@ -52,7 +52,11 @@ class FARMReader(BaseReader):

doc_stride: int = 128,
progress_bar: bool = True,
duplicate_filtering: int = 0,
use_confidence_scores: bool = True
use_confidence_scores: bool = True,
proxies=None,
local_files_only=False,
force_download=False,
**kwargs
):

"""
@@ -92,6 +96,15 @@ class FARMReader(BaseReader):

Can be helpful to disable in production deployments to keep the logs clean.
:param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
The higher the value, the further apart answers can be and still be filtered out as duplicates. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
:param use_confidence_scores: Sets the type of score that is returned with every predicted answer.
`True` => a scaled confidence / relevance score between [0, 1].
This score can also be further calibrated on your dataset via self.eval()
(see https://haystack.deepset.ai/components/reader#confidence-scores).
`False` => an unscaled, raw score [-inf, +inf] which is the sum of the start and end logits
from the model for the predicted span.
:param proxies: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
:param local_files_only: Whether to force checking for local files only (and forbid downloads)
:param force_download: Whether to force a (re-)download even if the model exists locally in the cache.
"""

# save init parameters to enable export of component config as YAML
@@ -100,7 +113,8 @@ class FARMReader(BaseReader):

batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer,
top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample,
num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar,
duplicate_filtering=duplicate_filtering, use_confidence_scores=use_confidence_scores
duplicate_filtering=duplicate_filtering, proxies=proxies, local_files_only=local_files_only,
force_download=force_download, use_confidence_scores=use_confidence_scores, **kwargs
)

self.return_no_answers = return_no_answer
@@ -110,7 +124,11 @@ class FARMReader(BaseReader):

task_type="question_answering", max_seq_len=max_seq_len,
doc_stride=doc_stride, num_processes=num_processes, revision=model_version,
disable_tqdm=not progress_bar,
strict=False)
strict=False,
proxies=proxies,
local_files_only=local_files_only,
force_download=force_download,
**kwargs)
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1 # including possible no_answer
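In other words, the constructor options end up in FARM's `Inferencer.load`. A hedged sketch of the equivalent direct call (import path and values are assumptions for illustration):

```python
from farm.infer import Inferencer  # assumed FARM import path

# Sketch only: this mirrors what FARMReader.__init__ now does internally,
# forwarding the download options to Inferencer.load.
inferencer = Inferencer.load(
    "deepset/roberta-base-squad2",        # example model
    task_type="question_answering",
    proxies={'http': 'some.proxy:1234'},  # placeholder proxy
    local_files_only=False,
    force_download=False,
)
```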
@@ -346,7 +364,6 @@ class FARMReader(BaseReader):

return result

def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
#TODO update example in docstring
"""
Use loaded QA model to find answers for a query in the supplied list of Document.

@@ -355,14 +372,14 @@ class FARMReader(BaseReader):
```python
|{
| 'query': 'Who is the father of Arya Stark?',
| 'answers':[
| {'answer': 'Eddard,',
| 'context': " She travels with her father, Eddard, to King's Landing when he is ",
| 'offset_answer_start': 147,
| 'offset_answer_end': 154,
| 'answers':[Answer(
| 'answer': 'Eddard,',
| 'context': "She travels with her father, Eddard, to King's Landing when he is",
| 'score': 0.9787139466668613,
| 'document_id': '1337'
| },...
| 'offsets_in_context': [Span(start=29, end=35)],
| 'offsets_in_document': [Span(start=347, end=353)],
| 'document_id': '88d1ed769d003939d3a0d28034464ab2'
| ),...
| ]
|}
```

@@ -52,6 +52,11 @@ def test_prediction_attributes(prediction):

    for ag in attributes_gold:
        assert ag in prediction


@pytest.mark.slow
def test_model_download_options():
    # download disabled and model is not cached locally
    with pytest.raises(OSError):
        impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True)

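A hypothetical companion check, not part of this commit, sketching the opposite case; it assumes the model has already been fetched into the local cache by an earlier run:

```python
@pytest.mark.slow
def test_model_download_options_cached():
    # Hypothetical sketch: once the model is present in the local cache,
    # loading with local_files_only=True should succeed without network access.
    cached_reader = FARMReader("deepset/roberta-base-squad2", use_gpu=False, local_files_only=True)
    assert cached_reader is not None
```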
def test_answer_attributes(prediction):
    # TODO Transformers answer also has meta key