diff --git a/docs/_src/api/api/reader.md b/docs/_src/api/api/reader.md
index 26e658568..806bfff0a 100644
--- a/docs/_src/api/api/reader.md
+++ b/docs/_src/api/api/reader.md
@@ -630,7 +630,7 @@ answer = prediction["answers"][0].answer # "10 june 1996"
 #### \_\_init\_\_
 
 ```python
- | __init__(model_name_or_path: str = "google/tapas-base-finetuned-wtq", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, max_seq_len: int = 256)
+ | __init__(model_name_or_path: str = "google/tapas-base-finetuned-wtq", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, top_k_per_candidate: int = 3, return_no_answer: bool = False, max_seq_len: int = 256)
 ```
 
 Load a TableQA model from Transformers.
@@ -638,18 +638,30 @@ Available models include:
 
 - ``'google/tapas-base-finetuned-wtq'``
 - ``'google/tapas-base-finetuned-wikisql-supervised'``
+- ``'deepset/tapas-large-nq-hn-reader'``
+- ``'deepset/tapas-large-nq-reader'``
 
 See https://huggingface.co/models?pipeline_tag=table-question-answering
 for full list of available TableQA models.
 
+The nq-reader models are able to provide confidence scores, but cannot handle questions that need aggregation
+over multiple cells. The returned answers are sorted first by a general table score and then by answer span
+scores.
+All the other models can handle aggregation questions, but don't provide reasonable confidence scores.
+
 **Arguments**:
 
 - `model_name_or_path`: Directory of a saved model or the name of a public model e.g.
 See https://huggingface.co/models?pipeline_tag=table-question-answering for full list of available models.
-- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
+- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name,
+or commit hash.
 - `tokenizer`: Name of the tokenizer (usually the same as model)
 - `use_gpu`: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
 - `top_k`: The maximum number of answers to return
+- `top_k_per_candidate`: How many answers to extract for each candidate table that is coming from
+the retriever.
+- `return_no_answer`: Whether to include no_answer predictions in the results.
+(Only applicable with nq-reader models.)
 - `max_seq_len`: Max sequence length of one input table for the model. If the number of tokens of query + table
 exceeds max_seq_len, the table will be truncated by removing rows until the input size fits the model.
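For orientation, here is a minimal usage sketch of the extended `__init__` signature documented above. It is an illustration only: it assumes Haystack 1.x import paths (`haystack.Document`, `haystack.nodes.TableReader`) and uses a made-up toy table; the `deepset/tapas-large-nq-hn-reader` checkpoint is one of the nq-reader models listed in the docstring.

```python
import pandas as pd

from haystack import Document
from haystack.nodes import TableReader

# nq-reader checkpoint: provides confidence scores and no_answer predictions,
# but cannot aggregate over multiple cells.
reader = TableReader(
    model_name_or_path="deepset/tapas-large-nq-hn-reader",
    top_k=5,
    top_k_per_candidate=3,   # answers extracted per candidate table
    return_no_answer=True,   # only applicable with nq-reader models
)

# Toy table wrapped as a Haystack Document of content_type "table".
table = pd.DataFrame({"Actor": ["Brad Pitt", "Leonardo DiCaprio"], "Age": ["58", "47"]})
prediction = reader.predict(
    query="How old is Leonardo DiCaprio?",
    documents=[Document(content=table, content_type="table")],
)
for answer in prediction["answers"]:
    print(answer.answer, answer.score)
```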
diff --git a/haystack/nodes/reader/table.py b/haystack/nodes/reader/table.py
index 5bebe0eb7..2c2a4cbdc 100644
--- a/haystack/nodes/reader/table.py
+++ b/haystack/nodes/reader/table.py
@@ -6,7 +6,9 @@ import torch
 import numpy as np
 import pandas as pd
 from quantulum3 import parser
-from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, BatchEncoding
+from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, \
+    BatchEncoding, TapasModel, TapasConfig
+from transformers.models.tapas.modeling_tapas import TapasPreTrainedModel
 
 from haystack.schema import Document, Answer, Span
 from haystack.nodes.reader.base import BaseReader
@@ -49,6 +51,8 @@ class TableReader(BaseReader):
         tokenizer: Optional[str] = None,
         use_gpu: bool = True,
         top_k: int = 10,
+        top_k_per_candidate: int = 3,
+        return_no_answer: bool = False,
         max_seq_len: int = 256,
     ):
         """
@@ -57,31 +61,54 @@ class TableReader(BaseReader):
 
         - ``'google/tapas-base-finetuned-wtq'``
         - ``'google/tapas-base-finetuned-wikisql-supervised'``
+        - ``'deepset/tapas-large-nq-hn-reader'``
+        - ``'deepset/tapas-large-nq-reader'``
 
         See https://huggingface.co/models?pipeline_tag=table-question-answering
         for full list of available TableQA models.
 
+        The nq-reader models are able to provide confidence scores, but cannot handle questions that need aggregation
+        over multiple cells. The returned answers are sorted first by a general table score and then by answer span
+        scores.
+        All the other models can handle aggregation questions, but don't provide reasonable confidence scores.
+
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
                                    See https://huggingface.co/models?pipeline_tag=table-question-answering for full list of available models.
-        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
+        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name,
+                              or commit hash.
         :param tokenizer: Name of the tokenizer (usually the same as model)
         :param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
         :param top_k: The maximum number of answers to return
+        :param top_k_per_candidate: How many answers to extract for each candidate table that is coming from
+                                    the retriever.
+        :param return_no_answer: Whether to include no_answer predictions in the results.
+                                 (Only applicable with nq-reader models.)
         :param max_seq_len: Max sequence length of one input table for the model. If the number of tokens of query + table
                             exceeds max_seq_len, the table will be truncated by removing rows until the input size fits the model.
""" + # Save init parameters to enable export of component config as YAML + self.set_config(model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer, + use_gpu=use_gpu, top_k=top_k, top_k_per_candidate=top_k_per_candidate, + return_no_answer=return_no_answer, max_seq_len=max_seq_len) self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False) - self.model = TapasForQuestionAnswering.from_pretrained(model_name_or_path, revision=model_version) + config = TapasConfig.from_pretrained(model_name_or_path) + if config.architectures[0] == "TapasForScoredQA": + self.model = self.TapasForScoredQA.from_pretrained(model_name_or_path, revision=model_version) + else: + self.model = TapasForQuestionAnswering.from_pretrained(model_name_or_path, revision=model_version) self.model.to(str(self.devices[0])) + if tokenizer is None: self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path) else: self.tokenizer = TapasTokenizer.from_pretrained(tokenizer) + self.top_k = top_k + self.top_k_per_candidate = top_k_per_candidate self.max_seq_len = max_seq_len - self.return_no_answers = False + self.return_no_answer = return_no_answer def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict: """ @@ -102,6 +129,7 @@ class TableReader(BaseReader): top_k = self.top_k answers = [] + no_answer_score = 1.0 for document in documents: if document.content_type != "table": logger.warning(f"Skipping document with id {document.id} in TableReader, as it is not of type table.") @@ -115,61 +143,27 @@ class TableReader(BaseReader): return_tensors="pt", truncation=True) inputs.to(self.devices[0]) - # Forward query and table through model and convert logits to predictions - outputs = self.model(**inputs) - inputs.to("cpu") - if self.model.config.num_aggregation_labels > 0: - aggregation_logits = outputs.logits_aggregation.cpu().detach() - else: - aggregation_logits = None - predicted_output = self.tokenizer.convert_logits_to_predictions( - inputs, - outputs.logits.cpu().detach(), - aggregation_logits - ) - if len(predicted_output) == 1: - predicted_answer_coordinates = predicted_output[0] - else: - predicted_answer_coordinates, predicted_aggregation_indices = predicted_output + if isinstance(self.model, TapasForQuestionAnswering): + current_answer = self._predict_tapas_for_qa(inputs, document) + answers.append(current_answer) + elif isinstance(self.model, self.TapasForScoredQA): + current_answers, current_no_answer_score = self._predict_tapas_for_scored_qa(inputs, document) + answers.extend(current_answers) + if current_no_answer_score < no_answer_score: + no_answer_score = current_no_answer_score - # Get cell values - current_answer_coordinates = predicted_answer_coordinates[0] - current_answer_cells = [] - for coordinate in current_answer_coordinates: - current_answer_cells.append(table.iat[coordinate]) - - # Get aggregation operator - if self.model.config.aggregation_labels is not None: - current_aggregation_operator = self.model.config.aggregation_labels[predicted_aggregation_indices[0]] - else: - current_aggregation_operator = "NONE" - - # Calculate answer score - current_score = self._calculate_answer_score(outputs.logits.cpu().detach(), inputs, current_answer_coordinates) - - if current_aggregation_operator == "NONE": - answer_str = ", ".join(current_answer_cells) - else: - answer_str = self._aggregate_answers(current_aggregation_operator, current_answer_cells) - - answer_offsets = 
-
-            answers.append(
-                Answer(
-                    answer=answer_str,
-                    type="extractive",
-                    score=current_score,
-                    context=table,
-                    offsets_in_document=answer_offsets,
-                    offsets_in_context=answer_offsets,
-                    document_id=document.id,
-                    meta={"aggregation_operator": current_aggregation_operator,
-                          "answer_cells": current_answer_cells}
-                )
-            )
-
-        # Sort answers by score and select top-k answers
+        if self.return_no_answer and isinstance(self.model, self.TapasForScoredQA):
+            answers.append(Answer(
+                answer="",
+                type="extractive",
+                score=no_answer_score,
+                context=None,
+                offsets_in_context=[Span(start=0, end=0)],
+                offsets_in_document=[Span(start=0, end=0)],
+                document_id=None,
+                meta=None
+            ))
         answers = sorted(answers, reverse=True)
         answers = answers[:top_k]
@@ -177,6 +171,147 @@ class TableReader(BaseReader):
                    "answers": answers}
 
         return results
+
+    def _predict_tapas_for_qa(self, inputs: BatchEncoding, document: Document) -> Answer:
+        table: pd.DataFrame = document.content
+
+        # Forward query and table through model and convert logits to predictions
+        outputs = self.model(**inputs)
+        inputs.to("cpu")
+        if self.model.config.num_aggregation_labels > 0:
+            aggregation_logits = outputs.logits_aggregation.cpu().detach()
+        else:
+            aggregation_logits = None
+
+        predicted_output = self.tokenizer.convert_logits_to_predictions(
+            inputs,
+            outputs.logits.cpu().detach(),
+            aggregation_logits
+        )
+        if len(predicted_output) == 1:
+            predicted_answer_coordinates = predicted_output[0]
+        else:
+            predicted_answer_coordinates, predicted_aggregation_indices = predicted_output
+
+        # Get cell values
+        current_answer_coordinates = predicted_answer_coordinates[0]
+        current_answer_cells = []
+        for coordinate in current_answer_coordinates:
+            current_answer_cells.append(table.iat[coordinate])
+
+        # Get aggregation operator
+        if self.model.config.aggregation_labels is not None:
+            current_aggregation_operator = self.model.config.aggregation_labels[predicted_aggregation_indices[0]]
+        else:
+            current_aggregation_operator = "NONE"
+
+        # Calculate answer score
+        current_score = self._calculate_answer_score(outputs.logits.cpu().detach(), inputs, current_answer_coordinates)
+
+        if current_aggregation_operator == "NONE":
+            answer_str = ", ".join(current_answer_cells)
+        else:
+            answer_str = self._aggregate_answers(current_aggregation_operator, current_answer_cells)
+
+        answer_offsets = self._calculate_answer_offsets(current_answer_coordinates, document.content)
+
+        answer = Answer(
+            answer=answer_str,
+            type="extractive",
+            score=current_score,
+            context=document.content,
+            offsets_in_document=answer_offsets,
+            offsets_in_context=answer_offsets,
+            document_id=document.id,
+            meta={"aggregation_operator": current_aggregation_operator,
+                  "answer_cells": current_answer_cells}
+        )
+
+        return answer
+
+    def _predict_tapas_for_scored_qa(self, inputs: BatchEncoding, document: Document) -> Tuple[List[Answer], float]:
+        table: pd.DataFrame = document.content
+
+        # Forward pass through model
+        outputs = self.model.tapas(**inputs)
+
+        # Get general table score
+        table_score = self.model.classifier(outputs.pooler_output)
+        table_score_softmax = torch.nn.functional.softmax(table_score, dim=1)
+        table_relevancy_prob = table_score_softmax[0][1].item()
+
+        # Get possible answer spans
+        token_types = [
+            "segment_ids",
+            "column_ids",
+            "row_ids",
+            "prev_labels",
+            "column_ranks",
+            "inv_column_ranks",
+            "numeric_relations",
+        ]
+        row_ids: List[int] = inputs.token_type_ids[:, :, token_types.index("row_ids")].tolist()[0]
+        column_ids: List[int] = inputs.token_type_ids[:, :, token_types.index("column_ids")].tolist()[0]
+
+        possible_answer_spans: List[Tuple[int, int, int, int]] = []  # List of tuples: (row_idx, col_idx, start_token, end_token)
+        current_start_idx = -1
+        current_column_id = -1
+        for idx, (row_id, column_id) in enumerate(zip(row_ids, column_ids)):
+            if row_id == 0 or column_id == 0:
+                continue
+            # Beginning of new cell
+            if column_id != current_column_id:
+                if current_start_idx != -1:
+                    possible_answer_spans.append(
+                        (row_ids[current_start_idx]-1, column_ids[current_start_idx]-1, current_start_idx, idx-1)
+                    )
+                current_start_idx = idx
+                current_column_id = column_id
+        possible_answer_spans.append(
+            (row_ids[current_start_idx]-1, column_ids[current_start_idx]-1, current_start_idx, len(row_ids)-1)
+        )
+
+        # Concat logits of start token and end token of possible answer spans
+        sequence_output = outputs.last_hidden_state
+        concatenated_logits = []
+        for possible_span in possible_answer_spans:
+            start_token_logits = sequence_output[0, possible_span[2], :]
+            end_token_logits = sequence_output[0, possible_span[3], :]
+            concatenated_logits.append(torch.cat((start_token_logits, end_token_logits)))
+        concatenated_logit_tensors = torch.unsqueeze(torch.stack(concatenated_logits), dim=0)
+
+        # Calculate score for each possible span
+        span_logits = torch.einsum("bsj,j->bs", concatenated_logit_tensors, self.model.span_output_weights) \
+                      + self.model.span_output_bias
+        span_logits_softmax = torch.nn.functional.softmax(span_logits, dim=1)
+
+        top_k_answer_spans = torch.topk(span_logits[0], min(self.top_k_per_candidate, len(possible_answer_spans)))
+
+        answers = []
+        for answer_span_idx in top_k_answer_spans.indices:
+            current_answer_span = possible_answer_spans[answer_span_idx]
+            answer_str = table.iat[current_answer_span[:2]]
+            answer_offsets = self._calculate_answer_offsets([current_answer_span[:2]], document.content)
+            # As the general table score is more important for the final score, it is double weighted.
+            current_score = ((2 * table_relevancy_prob) + span_logits_softmax[0, answer_span_idx].item()) / 3
+
+            answers.append(
+                Answer(
+                    answer=answer_str,
+                    type="extractive",
+                    score=current_score,
+                    context=document.content,
+                    offsets_in_document=answer_offsets,
+                    offsets_in_context=answer_offsets,
+                    document_id=document.id,
+                    meta={"aggregation_operator": "NONE",
+                          "answer_cells": table.iat[current_answer_span[:2]]}
+                )
+            )
+
+        no_answer_score = 1 - table_relevancy_prob
+
+        return answers, no_answer_score
 
     def _calculate_answer_score(self, logits: torch.Tensor, inputs: BatchEncoding,
                                 answer_coordinates: List[Tuple[int, int]]) -> float:
@@ -253,6 +388,27 @@ class TableReader(BaseReader):
     def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
         raise NotImplementedError("Batch prediction not yet available in TableReader.")
 
+    class TapasForScoredQA(TapasPreTrainedModel):
+
+        def __init__(self, config):
+            super().__init__(config)
+
+            # base model
+            self.tapas = TapasModel(config)
+
+            # dropout (only used when training)
+            self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+
+            # answer selection head
+            self.span_output_weights = torch.nn.Parameter(torch.zeros(2 * config.hidden_size))
+            self.span_output_bias = torch.nn.Parameter(torch.zeros([]))
+
+            # table scoring head
+            self.classifier = torch.nn.Linear(config.hidden_size, 2)
+
+            # Initialize weights
+            self.init_weights()
+
 
 class RCIReader(BaseReader):
     """
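To make the scoring rule in `_predict_tapas_for_scored_qa` concrete, here is a small worked example with invented probabilities (not taken from the PR): the table relevancy probability is double weighted against the span probability, and the no_answer score is simply its complement.

```python
# Invented example values, purely illustrative.
table_relevancy_prob = 0.9   # softmax output of the table scoring head for the "relevant" class
span_prob = 0.6              # softmax probability of one candidate answer span

answer_score = (2 * table_relevancy_prob + span_prob) / 3   # = 0.8
no_answer_score = 1 - table_relevancy_prob                  # = 0.1, shared by all spans of this table

print(answer_score, no_answer_score)
```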