	Add Tapas reader with scores (#1997)
* Add Tapas reader with scores
* Adapt possible answer spans
* Add latest docstring and tutorial changes
* Remove unused imports
* Adapt scoring
* Add latest docstring and tutorial changes
* Fix mypy
* Infer model architecture from config
* Adapt answer score calculation
* Add latest docstring and tutorial changes
* Fix mypy

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

parent ee6b8d0688
commit bbb65a19bd
@@ -630,7 +630,7 @@ answer = prediction["answers"][0].answer  # "10 june 1996"
 #### \_\_init\_\_
 
 ```python
- | __init__(model_name_or_path: str = "google/tapas-base-finetuned-wtq", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, max_seq_len: int = 256)
+ | __init__(model_name_or_path: str = "google/tapas-base-finetuned-wtq", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, top_k_per_candidate: int = 3, return_no_answer: bool = False, max_seq_len: int = 256)
 ```
 
 Load a TableQA model from Transformers.
@@ -638,18 +638,30 @@ Available models include:
 
 - ``'google/tapas-base-finetuned-wtq`'``
 - ``'google/tapas-base-finetuned-wikisql-supervised``'
+- ``'deepset/tapas-large-nq-hn-reader'``
+- ``'deepset/tapas-large-nq-reader'``
 
 See https://huggingface.co/models?pipeline_tag=table-question-answering
 for full list of available TableQA models.
 
+The nq-reader models are able to provide confidence scores, but cannot handle questions that need aggregation
+over multiple cells. The returned answers are sorted first by a general table score and then by answer span
+scores.
+All the other models can handle aggregation questions, but don't provide reasonable confidence scores.
+
 **Arguments**:
 
 - `model_name_or_path`: Directory of a saved model or the name of a public model e.g.
 See https://huggingface.co/models?pipeline_tag=table-question-answering for full list of available models.
-- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
+- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name,
+                      or commit hash.
- `tokenizer`: Name of the tokenizer (usually the same as model)
 - `use_gpu`: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
 - `top_k`: The maximum number of answers to return
+- `top_k_per_candidate`: How many answers to extract for each candidate table that is coming from
+                            the retriever.
+- `return_no_answer`: Whether to include no_answer predictions in the results.
+                         (Only applicable with nq-reader models.)
 - `max_seq_len`: Max sequence length of one input table for the model. If the number of tokens of
                     query + table exceed max_seq_len, the table will be truncated by removing rows until the
                     input size fits the model.
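The documented signature and the nq-reader notes above translate into roughly the following usage. This is a minimal sketch, not part of the commit: the `haystack.nodes` import path is assumed for Haystack 1.x, and the table contents and query are made up for illustration.

```python
import pandas as pd

from haystack.schema import Document
from haystack.nodes import TableReader  # import path assumed for Haystack 1.x

# An nq-reader checkpoint, so the returned scores are meaningful confidence estimates.
reader = TableReader(
    model_name_or_path="deepset/tapas-large-nq-hn-reader",
    top_k=5,
    top_k_per_candidate=3,
    return_no_answer=True,
)

# Toy table wrapped as a table-type Document (values are illustrative).
table = pd.DataFrame(
    {"actor": ["Brad Pitt", "Leonardo DiCaprio"], "date of birth": ["18 december 1963", "11 november 1974"]}
)
document = Document(content=table, content_type="table")

prediction = reader.predict(query="When was Leonardo DiCaprio born?", documents=[document])
for answer in prediction["answers"]:
    print(answer.answer, answer.score)
```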
@@ -6,7 +6,9 @@ import torch
 import numpy as np
 import pandas as pd
 from quantulum3 import parser
-from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, BatchEncoding
+from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, \
+    BatchEncoding, TapasModel, TapasConfig
+from transformers.models.tapas.modeling_tapas import TapasPreTrainedModel
 
 from haystack.schema import Document, Answer, Span
 from haystack.nodes.reader.base import BaseReader
@@ -49,6 +51,8 @@ class TableReader(BaseReader):
             tokenizer: Optional[str] = None,
             use_gpu: bool = True,
             top_k: int = 10,
+            top_k_per_candidate: int = 3,
+            return_no_answer: bool = False,
             max_seq_len: int = 256,
     ):
         """
@@ -57,31 +61,54 @@ class TableReader(BaseReader):
 
         - ``'google/tapas-base-finetuned-wtq`'``
         - ``'google/tapas-base-finetuned-wikisql-supervised``'
+        - ``'deepset/tapas-large-nq-hn-reader'``
+        - ``'deepset/tapas-large-nq-reader'``
 
         See https://huggingface.co/models?pipeline_tag=table-question-answering
         for full list of available TableQA models.
 
+        The nq-reader models are able to provide confidence scores, but cannot handle questions that need aggregation
+        over multiple cells. The returned answers are sorted first by a general table score and then by answer span
+        scores.
+        All the other models can handle aggregation questions, but don't provide reasonable confidence scores.
+
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
         See https://huggingface.co/models?pipeline_tag=table-question-answering for full list of available models.
-        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
+        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name,
+                              or commit hash.
         :param tokenizer: Name of the tokenizer (usually the same as model)
         :param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
         :param top_k: The maximum number of answers to return
+        :param top_k_per_candidate: How many answers to extract for each candidate table that is coming from
+                                    the retriever.
+        :param return_no_answer: Whether to include no_answer predictions in the results.
+                                 (Only applicable with nq-reader models.)
         :param max_seq_len: Max sequence length of one input table for the model. If the number of tokens of
                             query + table exceed max_seq_len, the table will be truncated by removing rows until the
                             input size fits the model.
         """
+        # Save init parameters to enable export of component config as YAML
+        self.set_config(model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer,
+                        use_gpu=use_gpu, top_k=top_k, top_k_per_candidate=top_k_per_candidate,
+                        return_no_answer=return_no_answer, max_seq_len=max_seq_len)
+
         self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
+        config = TapasConfig.from_pretrained(model_name_or_path)
+        if config.architectures[0] == "TapasForScoredQA":
+            self.model = self.TapasForScoredQA.from_pretrained(model_name_or_path, revision=model_version)
+        else:
             self.model = TapasForQuestionAnswering.from_pretrained(model_name_or_path, revision=model_version)
         self.model.to(str(self.devices[0]))
 
         if tokenizer is None:
             self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path)
         else:
             self.tokenizer = TapasTokenizer.from_pretrained(tokenizer)
 
         self.top_k = top_k
+        self.top_k_per_candidate = top_k_per_candidate
         self.max_seq_len = max_seq_len
-        self.return_no_answers = False
+        self.return_no_answer = return_no_answer
 
     def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
         """
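The `config.architectures` check above is what decides between the two heads. A quick way to see what a given checkpoint reports, as a sketch: the model names come from the docstring, but the printed values are an assumption about those checkpoints' `config.json`.

```python
from transformers import TapasConfig

# The reader dispatches on the first entry of `architectures` in the checkpoint's config.json.
for name in ["google/tapas-base-finetuned-wtq", "deepset/tapas-large-nq-hn-reader"]:
    config = TapasConfig.from_pretrained(name)
    print(name, "->", config.architectures)

# Assumed output: the google checkpoint reports ["TapasForQuestionAnswering"],
# while the deepset nq-reader checkpoints report ["TapasForScoredQA"].
```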
@@ -102,6 +129,7 @@ class TableReader(BaseReader):
             top_k = self.top_k
 
         answers = []
+        no_answer_score = 1.0
         for document in documents:
             if document.content_type != "table":
                 logger.warning(f"Skipping document with id {document.id} in TableReader, as it is not of type table.")
@@ -115,6 +143,38 @@ class TableReader(BaseReader):
                                     return_tensors="pt",
                                     truncation=True)
             inputs.to(self.devices[0])
+
+            if isinstance(self.model, TapasForQuestionAnswering):
+                current_answer = self._predict_tapas_for_qa(inputs, document)
+                answers.append(current_answer)
+            elif isinstance(self.model, self.TapasForScoredQA):
+                current_answers, current_no_answer_score = self._predict_tapas_for_scored_qa(inputs, document)
+                answers.extend(current_answers)
+                if current_no_answer_score < no_answer_score:
+                    no_answer_score = current_no_answer_score
+
+        if self.return_no_answer and isinstance(self.model, self.TapasForScoredQA):
+            answers.append(Answer(
+                answer="",
+                type="extractive",
+                score=no_answer_score,
+                context=None,
+                offsets_in_context=[Span(start=0, end=0)],
+                offsets_in_document=[Span(start=0, end=0)],
+                document_id=None,
+                meta=None
+            ))
+        answers = sorted(answers, reverse=True)
+        answers = answers[:top_k]
+
+        results = {"query": query,
+                   "answers": answers}
+
+        return results
+
+    def _predict_tapas_for_qa(self, inputs: BatchEncoding, document: Document) -> Answer:
+        table: pd.DataFrame = document.content
+
         # Forward query and table through model and convert logits to predictions
         outputs = self.model(**inputs)
         inputs.to("cpu")
@@ -153,30 +213,105 @@ class TableReader(BaseReader):
         else:
             answer_str = self._aggregate_answers(current_aggregation_operator, current_answer_cells)
 
-            answer_offsets = self._calculate_answer_offsets(current_answer_coordinates, table)
-
-            answers.append(
-                Answer(
-                    answer=answer_str,
-                    type="extractive",
-                    score=current_score,
-                    context=table,
-                    offsets_in_document=answer_offsets,
-                    offsets_in_context=answer_offsets,
-                    document_id=document.id,
-                    meta={"aggregation_operator": current_aggregation_operator,
-                          "answer_cells": current_answer_cells}
-                )
-            )
-
-        # Sort answers by score and select top-k answers
-        answers = sorted(answers, reverse=True)
-        answers = answers[:top_k]
-
-        results = {"query": query,
-                   "answers": answers}
-
-        return results
+        answer_offsets = self._calculate_answer_offsets(current_answer_coordinates, document.content)
+
+        answer = Answer(
+            answer=answer_str,
+            type="extractive",
+            score=current_score,
+            context=document.content,
+            offsets_in_document=answer_offsets,
+            offsets_in_context=answer_offsets,
+            document_id=document.id,
+            meta={"aggregation_operator": current_aggregation_operator,
+                  "answer_cells": current_answer_cells}
+        )
+
+        return answer
+
+    def _predict_tapas_for_scored_qa(self, inputs: BatchEncoding, document: Document) -> Tuple[List[Answer], float]:
+        table: pd.DataFrame = document.content
+
+        # Forward pass through model
+        outputs = self.model.tapas(**inputs)
+
+        # Get general table score
+        table_score = self.model.classifier(outputs.pooler_output)
+        table_score_softmax = torch.nn.functional.softmax(table_score, dim=1)
+        table_relevancy_prob = table_score_softmax[0][1].item()
+
+        # Get possible answer spans
+        token_types = [
+            "segment_ids",
+            "column_ids",
+            "row_ids",
+            "prev_labels",
+            "column_ranks",
+            "inv_column_ranks",
+            "numeric_relations",
+        ]
+        row_ids: List[int] = inputs.token_type_ids[:, :, token_types.index("row_ids")].tolist()[0]
+        column_ids: List[int] = inputs.token_type_ids[:, :, token_types.index("column_ids")].tolist()[0]
+
+        possible_answer_spans: List[Tuple[int, int, int, int]] = []  # List of tuples: (row_idx, col_idx, start_token, end_token)
+        current_start_idx = -1
+        current_column_id = -1
+        for idx, (row_id, column_id) in enumerate(zip(row_ids, column_ids)):
+            if row_id == 0 or column_id == 0:
+                continue
+            # Beginning of new cell
+            if column_id != current_column_id:
+                if current_start_idx != -1:
+                    possible_answer_spans.append(
+                        (row_ids[current_start_idx]-1, column_ids[current_start_idx]-1, current_start_idx, idx-1)
+                    )
+                current_start_idx = idx
+                current_column_id = column_id
+        possible_answer_spans.append(
+            (row_ids[current_start_idx]-1, column_ids[current_start_idx]-1, current_start_idx, len(row_ids)-1)
+        )
+
+        # Concat logits of start token and end token of possible answer spans
+        sequence_output = outputs.last_hidden_state
+        concatenated_logits = []
+        for possible_span in possible_answer_spans:
+            start_token_logits = sequence_output[0, possible_span[2], :]
+            end_token_logits = sequence_output[0, possible_span[3], :]
+            concatenated_logits.append(torch.cat((start_token_logits, end_token_logits)))
+        concatenated_logit_tensors = torch.unsqueeze(torch.stack(concatenated_logits), dim=0)
+
+        # Calculate score for each possible span
+        span_logits = torch.einsum("bsj,j->bs", concatenated_logit_tensors, self.model.span_output_weights) \
+                      + self.model.span_output_bias
+        span_logits_softmax = torch.nn.functional.softmax(span_logits, dim=1)
+
+        top_k_answer_spans = torch.topk(span_logits[0], min(self.top_k_per_candidate, len(possible_answer_spans)))
+
+        answers = []
+        for answer_span_idx in top_k_answer_spans.indices:
+            current_answer_span = possible_answer_spans[answer_span_idx]
+            answer_str = table.iat[current_answer_span[:2]]
+            answer_offsets = self._calculate_answer_offsets([current_answer_span[:2]], document.content)
+            # As the general table score is more important for the final score, it is double weighted.
+            current_score = ((2 * table_relevancy_prob) + span_logits_softmax[0, answer_span_idx].item()) / 3
+
+            answers.append(
+                Answer(
+                    answer=answer_str,
+                    type="extractive",
+                    score=current_score,
+                    context=document.content,
+                    offsets_in_document=answer_offsets,
+                    offsets_in_context=answer_offsets,
+                    document_id=document.id,
+                    meta={"aggregation_operator": "NONE",
+                          "answer_cells": table.iat[current_answer_span[:2]]}
+                )
+            )
+
+        no_answer_score = 1 - table_relevancy_prob
+
+        return answers, no_answer_score
 
     def _calculate_answer_score(self, logits: torch.Tensor, inputs: BatchEncoding,
                                 answer_coordinates: List[Tuple[int, int]]) -> float:
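The inline comment above explains the weighting of the final answer score. As a quick numeric check of that formula (the two probabilities here are made-up values, not model output):

```python
# Double-weighted combination of table relevancy and span probability,
# as in _predict_tapas_for_scored_qa; input probabilities are illustrative.
table_relevancy_prob = 0.9
span_prob = 0.6
current_score = ((2 * table_relevancy_prob) + span_prob) / 3
print(round(current_score, 3))  # 0.8
```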
@@ -253,6 +388,27 @@ class TableReader(BaseReader):
     def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
         raise NotImplementedError("Batch prediction not yet available in TableReader.")
 
+    class TapasForScoredQA(TapasPreTrainedModel):
+
+        def __init__(self, config):
+            super().__init__(config)
+
+            # base model
+            self.tapas = TapasModel(config)
+
+            # dropout (only used when training)
+            self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+
+            # answer selection head
+            self.span_output_weights = torch.nn.Parameter(torch.zeros(2 * config.hidden_size))
+            self.span_output_bias = torch.nn.Parameter(torch.zeros([]))
+
+            # table scoring head
+            self.classifier = torch.nn.Linear(config.hidden_size, 2)
+
+            # Initialize weights
+            self.init_weights()
+
 
 class RCIReader(BaseReader):
     """
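To make the two heads added above concrete, here is a standalone shape sketch of the span-scoring contraction used in `_predict_tapas_for_scored_qa`. The sizes are dummy values and the tensors are random, not real model weights.

```python
import torch

hidden_size, num_spans = 4, 3

# Heads as initialised in TapasForScoredQA (zeros before fine-tuning).
span_output_weights = torch.nn.Parameter(torch.zeros(2 * hidden_size))
span_output_bias = torch.nn.Parameter(torch.zeros([]))

# Concatenated start/end hidden states for each candidate span: [batch, spans, 2 * hidden_size].
concatenated_logit_tensors = torch.randn(1, num_spans, 2 * hidden_size)

# Same einsum as in the reader: one scalar logit per candidate span.
span_logits = torch.einsum("bsj,j->bs", concatenated_logit_tensors, span_output_weights) + span_output_bias
print(span_logits.shape)  # torch.Size([1, 3])
```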
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user