Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-08-21 06:58:27 +00:00
Add more flexible options for model downloads (Proxies, resume_download, local_files_only...) (#1256)
* allow passing more options for model/tokenizer download from remote
* temporarily change dependency to current farm master
* Add latest docstring and tutorial changes
* add kwargs
* add docstrings

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
commit eb95f0e8aa
parent 3d58e81b5e
@@ -48,7 +48,7 @@ While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interf
 #### \_\_init\_\_
 
 ```python
-| __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True)
+| __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True, proxies=None, local_files_only=False, force_download=False, **kwargs)
 ```
 
 **Arguments**:
@@ -89,6 +89,15 @@ and that FARM includes no_answer in the sorted list of predictions.
                         Can be helpful to disable in production deployments to keep the logs clean.
 - `duplicate_filtering`: Answers are filtered based on their position. Both start and end position of the answers are considered.
                          The higher the value, answers that are more apart are filtered out. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
+- `use_confidence_scores`: Sets the type of score that is returned with every predicted answer.
+                           `True` => a scaled confidence / relevance score between [0, 1].
+                           This score can also be further calibrated on your dataset via self.eval()
+                           (see https://haystack.deepset.ai/components/reader#confidence-scores).
+                           `False` => an unscaled, raw score [-inf, +inf] which is the sum of start and end logit
+                           from the model for the predicted span.
+- `proxies`: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
+- `local_files_only`: Whether to force checking for local files only (and forbid downloads)
+- `force_download`: Whether to force a (re-)download even if the model exists locally in the cache.
 
 <a name="farm.FARMReader.train"></a>
 
 #### train
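For reference, a minimal usage sketch of the new download options (the import path, model name, and proxy address are illustrative assumptions, not taken from this commit):

```python
from haystack.reader.farm import FARMReader  # assumed import path for this Haystack version

# Route the model download through a proxy; reuse the local cache when possible.
reader = FARMReader(
    model_name_or_path="deepset/roberta-base-squad2",  # illustrative model
    proxies={"http": "some.proxy:1234"},               # forwarded to the Hugging Face download
    local_files_only=False,                            # set True to forbid any download
    force_download=False,                              # set True to re-download despite a cached copy
)
```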
@@ -193,14 +202,14 @@ Example:
 ```python
 |{
 |    'query': 'Who is the father of Arya Stark?',
-|    'answers':[
-|                 {'answer': 'Eddard,',
-|                 'context': " She travels with her father, Eddard, to King's Landing when he is ",
-|                 'offset_answer_start': 147,
-|                 'offset_answer_end': 154,
+|    'answers':[Answer(
+|                 'answer': 'Eddard,',
+|                 'context': "She travels with her father, Eddard, to King's Landing when he is",
 |                 'score': 0.9787139466668613,
-|                 'document_id': '1337'
-|                 },...
+|                 'offsets_in_context': [Span(start=29, end=35)],
+|                 'offsets_in_document': [Span(start=347, end=353)],
+|                 'document_id': '88d1ed769d003939d3a0d28034464ab2'
+|                 ),...
 |              ]
 |}
 ```
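A hedged sketch of consuming the new return format (attribute access on `Answer`, and the `reader` and `docs` objects, are assumptions; the commit itself only shows the rendered example above):

```python
# Sketch: inspect the top answer returned by FARMReader.predict()
prediction = reader.predict(query="Who is the father of Arya Stark?", documents=docs, top_k=5)
top_answer = prediction["answers"][0]
print(top_answer.answer)              # e.g. "Eddard,"
print(top_answer.score)               # e.g. 0.9787...
print(top_answer.offsets_in_context)  # e.g. [Span(start=29, end=35)]
```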
@@ -463,7 +463,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
 
     @classmethod
     def convert_from_transformers(cls, model_name_or_path: Union[str, Path], device: str, revision: Optional[str] = None,
-                                  task_type: Optional[str] = None, processor: Optional[Processor] = None):
+                                  task_type: Optional[str] = None, processor: Optional[Processor] = None, **kwargs):
         """
         Load a (downstream) model from huggingface's transformers format. Use cases:
         - continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)
@@ -489,7 +489,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
                                                    revision=revision,
                                                    device=device,
                                                    task_type=task_type,
-                                                   processor=processor)
+                                                   processor=processor,
+                                                   **kwargs)
 
 
     @classmethod
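A hedged sketch of what forwarding `**kwargs` enables here (the `farm.modeling.adaptive_model` import path, model name, and proxy value are assumptions; the extra kwargs presumably reach Hugging Face's `from_pretrained` downloads):

```python
from farm.modeling.adaptive_model import AdaptiveModel  # assumed import path

# Extra keyword arguments now flow through convert_from_transformers
# down to the underlying model/tokenizer downloads.
model = AdaptiveModel.convert_from_transformers(
    "deepset/roberta-base-squad2",        # illustrative model
    device="cpu",
    task_type="question_answering",
    proxies={"http": "some.proxy:1234"},  # passed via **kwargs
)
```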
@@ -52,7 +52,11 @@ class FARMReader(BaseReader):
                  doc_stride: int = 128,
                  progress_bar: bool = True,
                  duplicate_filtering: int = 0,
-                 use_confidence_scores: bool = True
+                 use_confidence_scores: bool = True,
+                 proxies=None,
+                 local_files_only=False,
+                 force_download=False,
+                 **kwargs
                  ):
 
         """
@@ -92,6 +96,15 @@ class FARMReader(BaseReader):
                              Can be helpful to disable in production deployments to keep the logs clean.
         :param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
                                     The higher the value, answers that are more apart are filtered out. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
+        :param use_confidence_scores: Sets the type of score that is returned with every predicted answer.
+                                      `True` => a scaled confidence / relevance score between [0, 1].
+                                      This score can also be further calibrated on your dataset via self.eval()
+                                      (see https://haystack.deepset.ai/components/reader#confidence-scores).
+                                      `False` => an unscaled, raw score [-inf, +inf] which is the sum of start and end logit
+                                      from the model for the predicted span.
+        :param proxies: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
+        :param local_files_only: Whether to force checking for local files only (and forbid downloads)
+        :param force_download: Whether to force a (re-)download even if the model exists locally in the cache.
         """
 
         # save init parameters to enable export of component config as YAML
@@ -100,7 +113,8 @@ class FARMReader(BaseReader):
             batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer,
             top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample,
             num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar,
-            duplicate_filtering=duplicate_filtering, use_confidence_scores=use_confidence_scores
+            duplicate_filtering=duplicate_filtering, proxies=proxies, local_files_only=local_files_only,
+            force_download=force_download, use_confidence_scores=use_confidence_scores, **kwargs
         )
 
         self.return_no_answers = return_no_answer
@@ -110,7 +124,11 @@ class FARMReader(BaseReader):
                                         task_type="question_answering", max_seq_len=max_seq_len,
                                         doc_stride=doc_stride, num_processes=num_processes, revision=model_version,
                                         disable_tqdm=not progress_bar,
-                                        strict=False)
+                                        strict=False,
+                                        proxies=proxies,
+                                        local_files_only=local_files_only,
+                                        force_download=force_download,
+                                        **kwargs)
         self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
         self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
         self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1  # including possible no_answer
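The commit title also lists resume_download among the new options, yet it is not a named parameter anywhere in this diff; presumably it rides along in `**kwargs` down to the Hugging Face download. A hedged sketch of that assumption:

```python
# Sketch: extra download options accepted further down the stack
# can be passed through FARMReader's **kwargs.
reader = FARMReader(
    "deepset/roberta-base-squad2",  # illustrative model
    resume_download=True,           # resume an interrupted download (assumed to be forwarded)
)
```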
@@ -346,7 +364,6 @@ class FARMReader(BaseReader):
         return result
 
     def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
-        #TODO update example in docstring
         """
         Use loaded QA model to find answers for a query in the supplied list of Document.
 
@@ -355,14 +372,14 @@
         ```python
         |{
         |    'query': 'Who is the father of Arya Stark?',
-        |    'answers':[
-        |                 {'answer': 'Eddard,',
-        |                 'context': " She travels with her father, Eddard, to King's Landing when he is ",
-        |                 'offset_answer_start': 147,
-        |                 'offset_answer_end': 154,
+        |    'answers':[Answer(
+        |                 'answer': 'Eddard,',
+        |                 'context': "She travels with her father, Eddard, to King's Landing when he is",
         |                 'score': 0.9787139466668613,
-        |                 'document_id': '1337'
-        |                 },...
+        |                 'offsets_in_context': [Span(start=29, end=35)],
+        |                 'offsets_in_document': [Span(start=347, end=353)],
+        |                 'document_id': '88d1ed769d003939d3a0d28034464ab2'
+        |                 ),...
         |              ]
         |}
         ```
|
@ -52,6 +52,11 @@ def test_prediction_attributes(prediction):
|
|||||||
for ag in attributes_gold:
|
for ag in attributes_gold:
|
||||||
assert ag in prediction
|
assert ag in prediction
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_model_download_options():
|
||||||
|
# download disabled and model is not cached locally
|
||||||
|
with pytest.raises(OSError):
|
||||||
|
impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True)
|
||||||
|
|
||||||
def test_answer_attributes(prediction):
|
def test_answer_attributes(prediction):
|
||||||
# TODO Transformers answer also has meta key
|
# TODO Transformers answer also has meta key
|
||||||
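A hedged companion to the new test (not part of the commit): once a model has been downloaded and cached, `local_files_only=True` should load it without network access. The model name and the assertion are assumptions for illustration.

```python
@pytest.mark.slow
def test_model_download_options_cached():
    # First load downloads the model and populates the local cache.
    FARMReader("deepset/minilm-uncased-squad2")
    # Second load must succeed purely from the cached files.
    offline_reader = FARMReader("deepset/minilm-uncased-squad2", local_files_only=True)
    assert offline_reader.inferencer is not None
```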