diff --git a/docs/_src/api/api/reader.md b/docs/_src/api/api/reader.md
index 5d959e469..311274291 100644
--- a/docs/_src/api/api/reader.md
+++ b/docs/_src/api/api/reader.md
@@ -48,7 +48,7 @@ While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interf
 #### \_\_init\_\_
 
 ```python
- | __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True)
+ | __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True, proxies=None, local_files_only=False, force_download=False, **kwargs)
 ```
 
 **Arguments**:
@@ -89,6 +89,15 @@
 and that FARM includes no_answer in the sorted list of predictions.
 Can be helpful to disable in production deployments to keep the logs clean.
 - `duplicate_filtering`: Answers are filtered based on their position. Both start and end position of the answers are considered. The higher the value, answers that are more apart are filtered out. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
+- `use_confidence_scores`: Sets the type of score that is returned with every predicted answer.
+  `True` => a scaled confidence / relevance score between [0, 1].
+  This score can also be further calibrated on your dataset via self.eval()
+  (see https://haystack.deepset.ai/components/reader#confidence-scores).
+  `False` => an unscaled, raw score [-inf, +inf] which is the sum of the start and end logits
+  from the model for the predicted span.
+- `proxies`: Dict of proxy servers to use for downloading external models. Example: {'http': 'some.proxy:1234', 'http://hostname': 'my.proxy:3111'}
+- `local_files_only`: Whether to force checking for local files only (and forbid downloads)
+- `force_download`: Whether to force a (re-)download even if the model exists locally in the cache.
 
 #### train
 
@@ -193,14 +202,14 @@ Example:
 ```python
 |{
 |    'query': 'Who is the father of Arya Stark?',
-|    'answers':[
-|                 {'answer': 'Eddard,',
-|                 'context': " She travels with her father, Eddard, to King's Landing when he is ",
-|                 'offset_answer_start': 147,
-|                 'offset_answer_end': 154,
+|    'answers':[Answer(
+|                 'answer': 'Eddard,',
+|                 'context': "She travels with her father, Eddard, to King's Landing when he is",
 |                 'score': 0.9787139466668613,
-|                 'document_id': '1337'
-|                 },...
+|                 'offsets_in_context': [Span(start=29, end=35)],
+|                 'offsets_in_document': [Span(start=347, end=353)],
+|                 'document_id': '88d1ed769d003939d3a0d28034464ab2'
+|                 ),...
 |    ]
 |}
 ```
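A minimal usage sketch of the new download options documented above; the model name and proxy addresses are placeholders for illustration, not part of this PR:

```python
from haystack.reader.farm import FARMReader

# Route the model download through a proxy (addresses are placeholders)
reader = FARMReader(
    "deepset/roberta-base-squad2",
    proxies={"http": "some.proxy:1234", "https": "some.proxy:1234"},
)

# Force a fresh download even if the model is already in the local cache
reader = FARMReader("deepset/roberta-base-squad2", force_download=True)

# Offline mode: only look at the local cache, never hit the network
reader = FARMReader("deepset/roberta-base-squad2", local_files_only=True)
```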
diff --git a/haystack/modeling/model/adaptive_model.py b/haystack/modeling/model/adaptive_model.py
index 18f0381ab..e8b76c29c 100644
--- a/haystack/modeling/model/adaptive_model.py
+++ b/haystack/modeling/model/adaptive_model.py
@@ -463,7 +463,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
 
     @classmethod
     def convert_from_transformers(cls, model_name_or_path: Union[str, Path], device: str, revision: Optional[str] = None,
-                                  task_type: Optional[str] = None, processor: Optional[Processor] = None):
+                                  task_type: Optional[str] = None, processor: Optional[Processor] = None, **kwargs):
        """
        Load a (downstream) model from huggingface's transformers format. Use cases:
         - continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)
@@ -489,7 +489,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
                                                            revision=revision,
                                                            device=device,
                                                            task_type=task_type,
-                                                           processor=processor)
+                                                           processor=processor,
+                                                           **kwargs)
 
 
     @classmethod
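The new `**kwargs` on `convert_from_transformers` are forwarded further down the loading chain (second hunk above), so download options can presumably also be supplied at this level. A sketch under that assumption, with placeholder model name and proxy:

```python
from haystack.modeling.model.adaptive_model import AdaptiveModel

# kwargs such as `proxies` are passed through to the underlying model loading
model = AdaptiveModel.convert_from_transformers(
    "deepset/roberta-base-squad2",
    device="cpu",
    task_type="question_answering",
    proxies={"https": "some.proxy:1234"},
)
```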
""" # save init parameters to enable export of component config as YAML @@ -100,7 +113,8 @@ class FARMReader(BaseReader): batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer, top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample, num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar, - duplicate_filtering=duplicate_filtering, use_confidence_scores=use_confidence_scores + duplicate_filtering=duplicate_filtering, proxies=proxies, local_files_only=local_files_only, + force_download=force_download, use_confidence_scores=use_confidence_scores, **kwargs ) self.return_no_answers = return_no_answer @@ -110,7 +124,11 @@ class FARMReader(BaseReader): task_type="question_answering", max_seq_len=max_seq_len, doc_stride=doc_stride, num_processes=num_processes, revision=model_version, disable_tqdm=not progress_bar, - strict=False) + strict=False, + proxies=proxies, + local_files_only=local_files_only, + force_download=force_download, + **kwargs) self.inferencer.model.prediction_heads[0].context_window_size = context_window_size self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1 # including possible no_answer @@ -346,7 +364,6 @@ class FARMReader(BaseReader): return result def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None): - #TODO update example in docstring """ Use loaded QA model to find answers for a query in the supplied list of Document. @@ -355,14 +372,14 @@ class FARMReader(BaseReader): ```python |{ | 'query': 'Who is the father of Arya Stark?', - | 'answers':[ - | {'answer': 'Eddard,', - | 'context': " She travels with her father, Eddard, to King's Landing when he is ", - | 'offset_answer_start': 147, - | 'offset_answer_end': 154, + | 'answers':[Answer( + | 'answer': 'Eddard,', + | 'context': "She travels with her father, Eddard, to King's Landing when he is", | 'score': 0.9787139466668613, - | 'document_id': '1337' - | },... + | 'offsets_in_context': [Span(start=29, end=35], + | 'offsets_in_context': [Span(start=347, end=353], + | 'document_id': '88d1ed769d003939d3a0d28034464ab2' + | ),... | ] |} ``` diff --git a/test/test_reader.py b/test/test_reader.py index 1a526c496..3abcda167 100644 --- a/test/test_reader.py +++ b/test/test_reader.py @@ -52,6 +52,11 @@ def test_prediction_attributes(prediction): for ag in attributes_gold: assert ag in prediction +@pytest.mark.slow +def test_model_download_options(): + # download disabled and model is not cached locally + with pytest.raises(OSError): + impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True) def test_answer_attributes(prediction): # TODO Transformers answer also has meta key