Add more flexible options for model downloads (Proxies, resume_download, local_files_only...) (#1256)

* allow passing more options for model/tokenizer download from remote

* temporarily change dependency to current farm master

* Add latest docstring and tutorial changes

* add kwargs

* add docstrings

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Malte Pietsch 2021-10-18 15:47:36 +02:00 committed by GitHub
parent 3d58e81b5e
commit eb95f0e8aa
4 changed files with 53 additions and 21 deletions


@ -48,7 +48,7 @@ While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interf
#### \_\_init\_\_
```python
| __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True)
| __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True, proxies=None, local_files_only=False, force_download=False, **kwargs)
```
**Arguments**:
@ -89,6 +89,15 @@ and that FARM includes no_answer in the sorted list of predictions.
Can be helpful to disable in production deployments to keep the logs clean.
- `duplicate_filtering`: Answers are filtered based on their position. Both start and end position of the answers are considered.
The higher the value, the further apart answers can be and still be filtered out as duplicates. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
- `use_confidence_scores`: Sets the type of score that is returned with every predicted answer.
`True` => a scaled confidence / relevance score between [0, 1].
This score can also be further calibrated on your dataset via self.eval()
(see https://haystack.deepset.ai/components/reader#confidence-scores).
`False` => an unscaled, raw score [-inf, +inf] which is the sum of the start and end logits
from the model for the predicted span.
- `proxies`: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
- `local_files_only`: Whether to force checking for local files only (and forbid downloads)
- `force_download`: Whether to force a (re-)download even if the model exists locally in the cache.
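
For example, the new download options can be passed straight to the constructor. A minimal sketch, assuming the `FARMReader` import path of this Haystack version and using `deepset/roberta-base-squad2` purely as an illustrative model:

```python
from haystack.reader.farm import FARMReader  # import path may differ in other Haystack versions

# First run: fetch the model through a proxy (host/port are placeholders).
reader = FARMReader(
    "deepset/roberta-base-squad2",
    proxies={"http": "some.proxy:1234", "https": "some.proxy:1234"},
)

# Subsequent runs: use only the local cache and fail fast if the model is missing.
offline_reader = FARMReader("deepset/roberta-base-squad2", local_files_only=True)
```
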
<a name="farm.FARMReader.train"></a>
#### train
@ -193,14 +202,14 @@ Example:
```python
|{
| 'query': 'Who is the father of Arya Stark?',
| 'answers':[
| {'answer': 'Eddard,',
| 'context': " She travels with her father, Eddard, to King's Landing when he is ",
| 'offset_answer_start': 147,
| 'offset_answer_end': 154,
| 'answers':[Answer(
| 'answer': 'Eddard,',
| 'context': "She travels with her father, Eddard, to King's Landing when he is",
| 'score': 0.9787139466668613,
| 'document_id': '1337'
| },...
| 'offsets_in_context': [Span(start=29, end=35)],
| 'offsets_in_document': [Span(start=347, end=353)],
| 'document_id': '88d1ed769d003939d3a0d28034464ab2'
| ),...
| ]
|}
```
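
An output of this shape is produced by calling `predict()` with a list of `Document` objects. A minimal sketch, assuming `Document` is importable from the top-level `haystack` package in this version (older releases use `text=` instead of `content=` for the document body) and an illustrative model name:

```python
from haystack import Document
from haystack.reader.farm import FARMReader  # import path may differ in other Haystack versions

docs = [Document(content="Arya Stark is the daughter of Eddard Stark and Catelyn Stark.")]
reader = FARMReader("deepset/roberta-base-squad2")

# predict() returns a dict with the query and a list of Answer objects as shown above.
prediction = reader.predict(query="Who is the father of Arya Stark?", documents=docs, top_k=1)
print(prediction["answers"][0].answer)
```
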


@ -463,7 +463,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
@classmethod
def convert_from_transformers(cls, model_name_or_path: Union[str, Path], device: str, revision: Optional[str] = None,
task_type: Optional[str] = None, processor: Optional[Processor] = None):
task_type: Optional[str] = None, processor: Optional[Processor] = None, **kwargs):
"""
Load a (downstream) model from huggingface's transformers format. Use cases:
- continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)
@ -489,7 +489,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
revision=revision,
device=device,
task_type=task_type,
processor=processor)
processor=processor,
**kwargs)
@classmethod

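The `**kwargs` added to `convert_from_transformers` are simply forwarded to the underlying loading call, so download options such as `proxies` reach it unchanged. A minimal sketch of the pass-through, assuming the `AdaptiveModel` import path of the FARM version in use and an illustrative model name:

```python
from farm.modeling.adaptive_model import AdaptiveModel  # import path may differ (e.g. after the FARM code moved into Haystack)

# Extra keyword arguments such as proxies travel through **kwargs to the download machinery.
model = AdaptiveModel.convert_from_transformers(
    "deepset/roberta-base-squad2",
    device="cpu",
    task_type="question_answering",
    proxies={"http": "some.proxy:1234"},
)
```
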

@ -52,7 +52,11 @@ class FARMReader(BaseReader):
doc_stride: int = 128,
progress_bar: bool = True,
duplicate_filtering: int = 0,
use_confidence_scores: bool = True
use_confidence_scores: bool = True,
proxies=None,
local_files_only=False,
force_download=False,
**kwargs
):
"""
@ -92,6 +96,15 @@ class FARMReader(BaseReader):
Can be helpful to disable in production deployments to keep the logs clean.
:param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
The higher the value, the further apart answers can be and still be filtered out as duplicates. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
:param use_confidence_scores: Sets the type of score that is returned with every predicted answer.
`True` => a scaled confidence / relevance score between [0, 1].
This score can also be further calibrated on your dataset via self.eval()
(see https://haystack.deepset.ai/components/reader#confidence-scores).
`False` => an unscaled, raw score [-inf, +inf] which is the sum of the start and end logits
from the model for the predicted span.
:param proxies: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
:param local_files_only: Whether to force checking for local files only (and forbid downloads)
:param force_download: Whether to force a (re-)download even if the model exists locally in the cache.
"""
# save init parameters to enable export of component config as YAML
@ -100,7 +113,8 @@ class FARMReader(BaseReader):
batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer,
top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample,
num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar,
duplicate_filtering=duplicate_filtering, use_confidence_scores=use_confidence_scores
duplicate_filtering=duplicate_filtering, proxies=proxies, local_files_only=local_files_only,
force_download=force_download, use_confidence_scores=use_confidence_scores, **kwargs
)
self.return_no_answers = return_no_answer
@ -110,7 +124,11 @@ class FARMReader(BaseReader):
task_type="question_answering", max_seq_len=max_seq_len,
doc_stride=doc_stride, num_processes=num_processes, revision=model_version,
disable_tqdm=not progress_bar,
strict=False)
strict=False,
proxies=proxies,
local_files_only=local_files_only,
force_download=force_download,
**kwargs)
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1 # including possible no_answer
@ -346,7 +364,6 @@ class FARMReader(BaseReader):
return result
def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
#TODO update example in docstring
"""
Use loaded QA model to find answers for a query in the supplied list of Document.
@ -355,14 +372,14 @@ class FARMReader(BaseReader):
```python
|{
| 'query': 'Who is the father of Arya Stark?',
| 'answers':[
| {'answer': 'Eddard,',
| 'context': " She travels with her father, Eddard, to King's Landing when he is ",
| 'offset_answer_start': 147,
| 'offset_answer_end': 154,
| 'answers':[Answer(
| 'answer': 'Eddard,',
| 'context': "She travels with her father, Eddard, to King's Landing when he is",
| 'score': 0.9787139466668613,
| 'document_id': '1337'
| },...
| 'offsets_in_context': [Span(start=29, end=35)],
| 'offsets_in_document': [Span(start=347, end=353)],
| 'document_id': '88d1ed769d003939d3a0d28034464ab2'
| ),...
| ]
|}
```


@ -52,6 +52,11 @@ def test_prediction_attributes(prediction):
    for ag in attributes_gold:
        assert ag in prediction


@pytest.mark.slow
def test_model_download_options():
    # download disabled and model is not cached locally
    with pytest.raises(OSError):
        impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True)


def test_answer_attributes(prediction):
    # TODO Transformers answer also has meta key