Add more flexible options for model downloads (Proxies, resume_download, local_files_only...) (#1256)

* allow passing more options for model/tokenizer download from remote

* temporarily change dependency to current farm master

* Add latest docstring and tutorial changes

* add kwargs

* add docstrings

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Malte Pietsch 2021-10-18 15:47:36 +02:00 committed by GitHub
parent 3d58e81b5e
commit eb95f0e8aa
4 changed files with 53 additions and 21 deletions
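
Before the diffs, a minimal usage sketch of what these options enable (hedged: the `haystack.nodes` import path and the model name are illustrative and may differ by release):

```python
from haystack.nodes import FARMReader  # import path may differ in older releases

# Download the model through a proxy, forcing a fresh copy even if cached.
reader = FARMReader(
    model_name_or_path="deepset/roberta-base-squad2",
    proxies={"http": "some.proxy:1234", "https": "some.proxy:1234"},
    force_download=True,
)

# Run fully offline against the local cache; raises OSError if the model
# was never downloaded.
offline_reader = FARMReader(
    model_name_or_path="deepset/roberta-base-squad2",
    local_files_only=True,
)
```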


@@ -48,7 +48,7 @@ While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interf
#### \_\_init\_\_

```python
-| __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True)
+| __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True, proxies=None, local_files_only=False, force_download=False, **kwargs)
```

**Arguments**:
@@ -89,6 +89,15 @@ and that FARM includes no_answer in the sorted list of predictions.
Can be helpful to disable in production deployments to keep the logs clean.
- `duplicate_filtering`: Answers are filtered based on their position. Both start and end position of the answers are considered.
The higher the value, the farther apart answers can be and still be filtered out as duplicates. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
+- `use_confidence_scores`: Sets the type of score that is returned with every predicted answer.
+`True` => a scaled confidence / relevance score between [0, 1].
+This score can also be further calibrated on your dataset via self.eval()
+(see https://haystack.deepset.ai/components/reader#confidence-scores).
+`False` => an unscaled, raw score [-inf, +inf] which is the sum of start and end logit
+from the model for the predicted span.
+- `proxies`: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
+- `local_files_only`: Whether to force checking for local files only (and forbid downloads).
+- `force_download`: Whether to force a (re-)download even if the model exists locally in the cache.

<a name="farm.FARMReader.train"></a>
#### train
@@ -193,14 +202,14 @@ Example:
```python
|{
| 'query': 'Who is the father of Arya Stark?',
-| 'answers':[
-| {'answer': 'Eddard,',
+| 'answers':[Answer(
+| 'answer': 'Eddard,',
| 'context': "She travels with her father, Eddard, to King's Landing when he is",
-| 'offset_answer_start': 147,
-| 'offset_answer_end': 154,
| 'score': 0.9787139466668613,
-| 'document_id': '1337'
-| },...
+| 'offsets_in_context': [Span(start=29, end=35)],
+| 'offsets_in_document': [Span(start=347, end=353)],
+| 'document_id': '88d1ed769d003939d3a0d28034464ab2'
+| ),...
| ]
|}
```
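
The updated example reflects the newer `Answer`/`Span` return types. A short consumption sketch, assuming `Document` lives in `haystack.schema` and the attribute names shown in the docstring above:

```python
from haystack.nodes import FARMReader  # import path may differ in older releases
from haystack.schema import Document   # assumed location of Document/Answer/Span

reader = FARMReader("deepset/roberta-base-squad2")
docs = [Document(content="Arya Stark is the daughter of Eddard Stark.")]

result = reader.predict(query="Who is the father of Arya Stark?", documents=docs)
for answer in result["answers"]:
    # Answers are Answer objects with attribute access, not plain dicts.
    span = answer.offsets_in_document[0]
    print(f"{answer.answer!r} (score={answer.score:.3f}) "
          f"at document offsets {span.start}-{span.end} in doc {answer.document_id}")
```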


@@ -463,7 +463,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
    @classmethod
    def convert_from_transformers(cls, model_name_or_path: Union[str, Path], device: str, revision: Optional[str] = None,
-                                  task_type: Optional[str] = None, processor: Optional[Processor] = None):
+                                  task_type: Optional[str] = None, processor: Optional[Processor] = None, **kwargs):
        """
        Load a (downstream) model from huggingface's transformers format. Use cases:
        - continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)

@@ -489,7 +489,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
            revision=revision,
            device=device,
            task_type=task_type,
-            processor=processor)
+            processor=processor,
+            **kwargs)

    @classmethod
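
The new `**kwargs` here are forwarded verbatim to the wrapped loader, which is how the extra download options ultimately reach Hugging Face's `from_pretrained`. A sketch under that assumption; the import path is a guess:

```python
from haystack.modeling.model.adaptive_model import AdaptiveModel  # assumed path

# resume_download is a transformers from_pretrained option; it travels
# down via the new **kwargs pass-through.
model = AdaptiveModel.convert_from_transformers(
    "deepset/roberta-base-squad2",
    device="cpu",
    task_type="question_answering",
    resume_download=True,
)
```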


@@ -52,7 +52,11 @@ class FARMReader(BaseReader):
                 doc_stride: int = 128,
                 progress_bar: bool = True,
                 duplicate_filtering: int = 0,
-                 use_confidence_scores: bool = True
+                 use_confidence_scores: bool = True,
+                 proxies=None,
+                 local_files_only=False,
+                 force_download=False,
+                 **kwargs
                 ):
        """
@@ -92,6 +96,15 @@ class FARMReader(BaseReader):
                             Can be helpful to disable in production deployments to keep the logs clean.
        :param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
                                    The higher the value, the farther apart answers can be and still be filtered out as duplicates. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
+        :param use_confidence_scores: Sets the type of score that is returned with every predicted answer.
+                                      `True` => a scaled confidence / relevance score between [0, 1].
+                                      This score can also be further calibrated on your dataset via self.eval()
+                                      (see https://haystack.deepset.ai/components/reader#confidence-scores).
+                                      `False` => an unscaled, raw score [-inf, +inf] which is the sum of start and end logit
+                                      from the model for the predicted span.
+        :param proxies: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
+        :param local_files_only: Whether to force checking for local files only (and forbid downloads).
+        :param force_download: Whether to force a (re-)download even if the model exists locally in the cache.
        """
        # save init parameters to enable export of component config as YAML
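
Taken together, the three parameters support a cache-then-go-offline workflow; a sketch (import path and model name illustrative):

```python
from haystack.nodes import FARMReader  # import path may differ in older releases

# 1) With network access: populate (or refresh) the local model cache.
FARMReader("deepset/minilm-uncased-squad2", force_download=True)

# 2) In an air-gapped deployment: load strictly from the cache, never
#    touching the network.
reader = FARMReader("deepset/minilm-uncased-squad2", local_files_only=True)
```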
@@ -100,7 +113,8 @@ class FARMReader(BaseReader):
            batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer,
            top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample,
            num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar,
-            duplicate_filtering=duplicate_filtering, use_confidence_scores=use_confidence_scores
+            duplicate_filtering=duplicate_filtering, proxies=proxies, local_files_only=local_files_only,
+            force_download=force_download, use_confidence_scores=use_confidence_scores, **kwargs
        )

        self.return_no_answers = return_no_answer
@@ -110,7 +124,11 @@ class FARMReader(BaseReader):
                                        task_type="question_answering", max_seq_len=max_seq_len,
                                        doc_stride=doc_stride, num_processes=num_processes, revision=model_version,
                                        disable_tqdm=not progress_bar,
-                                        strict=False)
+                                        strict=False,
+                                        proxies=proxies,
+                                        local_files_only=local_files_only,
+                                        force_download=force_download,
+                                        **kwargs)
        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
        self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
        self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1  # including possible no_answer
@@ -346,7 +364,6 @@ class FARMReader(BaseReader):
        return result

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
-        #TODO update example in docstring
        """
        Use loaded QA model to find answers for a query in the supplied list of Document.
@@ -355,14 +372,14 @@ class FARMReader(BaseReader):
```python
|{
| 'query': 'Who is the father of Arya Stark?',
-| 'answers':[
-| {'answer': 'Eddard,',
+| 'answers':[Answer(
+| 'answer': 'Eddard,',
| 'context': "She travels with her father, Eddard, to King's Landing when he is",
-| 'offset_answer_start': 147,
-| 'offset_answer_end': 154,
| 'score': 0.9787139466668613,
-| 'document_id': '1337'
-| },...
+| 'offsets_in_context': [Span(start=29, end=35)],
+| 'offsets_in_document': [Span(start=347, end=353)],
+| 'document_id': '88d1ed769d003939d3a0d28034464ab2'
+| ),...
| ]
|}
```


@@ -52,6 +52,11 @@ def test_prediction_attributes(prediction):
    for ag in attributes_gold:
        assert ag in prediction

+@pytest.mark.slow
+def test_model_download_options():
+    # download disabled and model is not cached locally
+    with pytest.raises(OSError):
+        impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True)

def test_answer_attributes(prediction):
    # TODO Transformers answer also has meta key
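
A happy-path counterpart is not part of this commit, but could look roughly like this (hypothetical test; assumes the model has room in the local cache and that imports mirror the test module above):

```python
import pytest
from haystack.nodes import FARMReader  # import path may differ in older releases

@pytest.mark.slow
def test_model_loads_from_cache():
    # Hypothetical companion test: after a normal download, loading with
    # local_files_only=True should succeed straight from the cache.
    FARMReader("deepset/minilm-uncased-squad2")  # populates the cache
    cached_reader = FARMReader("deepset/minilm-uncased-squad2", local_files_only=True)
    assert cached_reader.inferencer is not None
```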