Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-03 03:09:28 +00:00
Add Tapas reader with scores (#1997)
* Add Tapas reader with scores
* Adapt possible answer spans
* Add latest docstring and tutorial changes
* Remove unused imports
* Adapt scoring
* Add latest docstring and tutorial changes
* Fix mypy
* Infer model architecture from config
* Adapt answer score calculation
* Add latest docstring and tutorial changes
* Fix mypy

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent ee6b8d0688
commit bbb65a19bd
@@ -630,7 +630,7 @@ answer = prediction["answers"][0].answer # "10 june 1996"
 #### \_\_init\_\_
 
 ```python
- | __init__(model_name_or_path: str = "google/tapas-base-finetuned-wtq", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, max_seq_len: int = 256)
+ | __init__(model_name_or_path: str = "google/tapas-base-finetuned-wtq", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, top_k_per_candidate: int = 3, return_no_answer: bool = False, max_seq_len: int = 256)
 ```
 
 Load a TableQA model from Transformers.
@@ -638,18 +638,30 @@ Available models include:
 
 - ``'google/tapas-base-finetuned-wtq'``
 - ``'google/tapas-base-finetuned-wikisql-supervised'``
+- ``'deepset/tapas-large-nq-hn-reader'``
+- ``'deepset/tapas-large-nq-reader'``
 
 See https://huggingface.co/models?pipeline_tag=table-question-answering
 for full list of available TableQA models.
 
+The nq-reader models are able to provide confidence scores, but cannot handle questions that need aggregation
+over multiple cells. The returned answers are sorted first by a general table score and then by answer span
+scores.
+All the other models can handle aggregation questions, but don't provide reasonable confidence scores.
+
 **Arguments**:
 
 - `model_name_or_path`: Directory of a saved model or the name of a public model e.g.
 See https://huggingface.co/models?pipeline_tag=table-question-answering for full list of available models.
-- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
+- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name,
+or commit hash.
 - `tokenizer`: Name of the tokenizer (usually the same as model)
 - `use_gpu`: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
 - `top_k`: The maximum number of answers to return
+- `top_k_per_candidate`: How many answers to extract for each candidate table that is coming from
+the retriever.
+- `return_no_answer`: Whether to include no_answer predictions in the results.
+(Only applicable with nq-reader models.)
 - `max_seq_len`: Max sequence length of one input table for the model. If the number of tokens of
 query + table exceed max_seq_len, the table will be truncated by removing rows until the
 input size fits the model.
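For orientation, here is a minimal usage sketch based on the docstring above; it is not part of this commit. The table contents and printed values are invented, and `deepset/tapas-large-nq-hn-reader` is simply one of the nq-reader checkpoints listed above.

```python
import pandas as pd
from haystack.nodes import TableReader
from haystack.schema import Document

# Invented example table; TableReader expects a pandas DataFrame wrapped in a Document
# with content_type="table".
table = pd.DataFrame({"actors": ["brad pitt", "leonardo di caprio", "george clooney"],
                      "age": ["58", "47", "60"]})
document = Document(content=table, content_type="table")

# nq-reader checkpoints return confidence scores but cannot aggregate over multiple cells.
reader = TableReader(model_name_or_path="deepset/tapas-large-nq-hn-reader",
                     top_k=10, top_k_per_candidate=3, return_no_answer=True)

prediction = reader.predict(query="How old is leonardo di caprio?", documents=[document])
for answer in prediction["answers"]:
    print(answer.answer, answer.score)
```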
@@ -6,7 +6,9 @@ import torch
 import numpy as np
 import pandas as pd
 from quantulum3 import parser
-from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, BatchEncoding
+from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, \
+    BatchEncoding, TapasModel, TapasConfig
+from transformers.models.tapas.modeling_tapas import TapasPreTrainedModel
 
 from haystack.schema import Document, Answer, Span
 from haystack.nodes.reader.base import BaseReader
@@ -49,6 +51,8 @@ class TableReader(BaseReader):
         tokenizer: Optional[str] = None,
         use_gpu: bool = True,
         top_k: int = 10,
+        top_k_per_candidate: int = 3,
+        return_no_answer: bool = False,
         max_seq_len: int = 256,
     ):
         """
@@ -57,31 +61,54 @@ class TableReader(BaseReader):
 
         - ``'google/tapas-base-finetuned-wtq'``
         - ``'google/tapas-base-finetuned-wikisql-supervised'``
+        - ``'deepset/tapas-large-nq-hn-reader'``
+        - ``'deepset/tapas-large-nq-reader'``
 
         See https://huggingface.co/models?pipeline_tag=table-question-answering
         for full list of available TableQA models.
 
+        The nq-reader models are able to provide confidence scores, but cannot handle questions that need aggregation
+        over multiple cells. The returned answers are sorted first by a general table score and then by answer span
+        scores.
+        All the other models can handle aggregation questions, but don't provide reasonable confidence scores.
+
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
                                    See https://huggingface.co/models?pipeline_tag=table-question-answering for full list of available models.
-        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
+        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name,
+                              or commit hash.
         :param tokenizer: Name of the tokenizer (usually the same as model)
         :param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
         :param top_k: The maximum number of answers to return
+        :param top_k_per_candidate: How many answers to extract for each candidate table that is coming from
+                                    the retriever.
+        :param return_no_answer: Whether to include no_answer predictions in the results.
+                                 (Only applicable with nq-reader models.)
         :param max_seq_len: Max sequence length of one input table for the model. If the number of tokens of
                             query + table exceed max_seq_len, the table will be truncated by removing rows until the
                             input size fits the model.
         """
+        # Save init parameters to enable export of component config as YAML
+        self.set_config(model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer,
+                        use_gpu=use_gpu, top_k=top_k, top_k_per_candidate=top_k_per_candidate,
+                        return_no_answer=return_no_answer, max_seq_len=max_seq_len)
 
         self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
-        self.model = TapasForQuestionAnswering.from_pretrained(model_name_or_path, revision=model_version)
+        config = TapasConfig.from_pretrained(model_name_or_path)
+        if config.architectures[0] == "TapasForScoredQA":
+            self.model = self.TapasForScoredQA.from_pretrained(model_name_or_path, revision=model_version)
+        else:
+            self.model = TapasForQuestionAnswering.from_pretrained(model_name_or_path, revision=model_version)
         self.model.to(str(self.devices[0]))
 
         if tokenizer is None:
             self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path)
         else:
             self.tokenizer = TapasTokenizer.from_pretrained(tokenizer)
 
         self.top_k = top_k
+        self.top_k_per_candidate = top_k_per_candidate
         self.max_seq_len = max_seq_len
-        self.return_no_answers = False
+        self.return_no_answer = return_no_answer
 
     def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
         """
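A small sketch of the architecture check introduced above, not part of this commit: the scored-QA checkpoints are expected to declare `TapasForScoredQA` in their config's `architectures` field, while everything else falls back to `TapasForQuestionAnswering`. The model name is only an example taken from the docstring, and whether its hosted config actually sets this field is assumed here.

```python
from transformers import TapasConfig

# The config's `architectures` entry decides which head the reader loads.
config = TapasConfig.from_pretrained("deepset/tapas-large-nq-hn-reader")
if config.architectures[0] == "TapasForScoredQA":
    print("nq-reader style checkpoint: scored-QA head with confidence scores")
else:
    print("standard TAPAS checkpoint: TapasForQuestionAnswering with aggregation")
```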
@@ -102,6 +129,7 @@ class TableReader(BaseReader):
             top_k = self.top_k
 
         answers = []
+        no_answer_score = 1.0
         for document in documents:
             if document.content_type != "table":
                 logger.warning(f"Skipping document with id {document.id} in TableReader, as it is not of type table.")
@@ -115,61 +143,27 @@ class TableReader(BaseReader):
                                     return_tensors="pt",
                                     truncation=True)
             inputs.to(self.devices[0])
-            # Forward query and table through model and convert logits to predictions
-            outputs = self.model(**inputs)
-            inputs.to("cpu")
-            if self.model.config.num_aggregation_labels > 0:
-                aggregation_logits = outputs.logits_aggregation.cpu().detach()
-            else:
-                aggregation_logits = None
-
-            predicted_output = self.tokenizer.convert_logits_to_predictions(
-                inputs,
-                outputs.logits.cpu().detach(),
-                aggregation_logits
-            )
-            if len(predicted_output) == 1:
-                predicted_answer_coordinates = predicted_output[0]
-            else:
-                predicted_answer_coordinates, predicted_aggregation_indices = predicted_output
-
-            # Get cell values
-            current_answer_coordinates = predicted_answer_coordinates[0]
-            current_answer_cells = []
-            for coordinate in current_answer_coordinates:
-                current_answer_cells.append(table.iat[coordinate])
-
-            # Get aggregation operator
-            if self.model.config.aggregation_labels is not None:
-                current_aggregation_operator = self.model.config.aggregation_labels[predicted_aggregation_indices[0]]
-            else:
-                current_aggregation_operator = "NONE"
-
-            # Calculate answer score
-            current_score = self._calculate_answer_score(outputs.logits.cpu().detach(), inputs, current_answer_coordinates)
-
-            if current_aggregation_operator == "NONE":
-                answer_str = ", ".join(current_answer_cells)
-            else:
-                answer_str = self._aggregate_answers(current_aggregation_operator, current_answer_cells)
-
-            answer_offsets = self._calculate_answer_offsets(current_answer_coordinates, table)
-
-            answers.append(
-                Answer(
-                    answer=answer_str,
-                    type="extractive",
-                    score=current_score,
-                    context=table,
-                    offsets_in_document=answer_offsets,
-                    offsets_in_context=answer_offsets,
-                    document_id=document.id,
-                    meta={"aggregation_operator": current_aggregation_operator,
-                          "answer_cells": current_answer_cells}
-                )
-            )
-
-        # Sort answers by score and select top-k answers
+            if isinstance(self.model, TapasForQuestionAnswering):
+                current_answer = self._predict_tapas_for_qa(inputs, document)
+                answers.append(current_answer)
+            elif isinstance(self.model, self.TapasForScoredQA):
+                current_answers, current_no_answer_score = self._predict_tapas_for_scored_qa(inputs, document)
+                answers.extend(current_answers)
+                if current_no_answer_score < no_answer_score:
+                    no_answer_score = current_no_answer_score
+
+        if self.return_no_answer and isinstance(self.model, self.TapasForScoredQA):
+            answers.append(Answer(
+                answer="",
+                type="extractive",
+                score=no_answer_score,
+                context=None,
+                offsets_in_context=[Span(start=0, end=0)],
+                offsets_in_document=[Span(start=0, end=0)],
+                document_id=None,
+                meta=None
+            ))
+
         answers = sorted(answers, reverse=True)
         answers = answers[:top_k]
 
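The tokenizer call kept as context in this hunk is also where `max_seq_len` takes effect. A rough standalone sketch of that encoding step, not part of this commit (table contents invented); per the docstring above, when query plus table exceed the limit the tokenizer drops table rows until the encoding fits.

```python
import pandas as pd
from transformers import TapasTokenizer

tokenizer = TapasTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")

# Invented example table; TAPAS expects all cells as strings.
table = pd.DataFrame({"city": ["berlin", "paris", "rome"],
                      "population": ["3644826", "2175601", "2872800"]})

# truncation=True removes rows until query + table fit into max_length tokens.
inputs = tokenizer(table=table,
                   queries="which city has the largest population?",
                   max_length=256,
                   truncation=True,
                   return_tensors="pt")
print(inputs["input_ids"].shape)
```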
@@ -177,6 +171,147 @@ class TableReader(BaseReader):
                    "answers": answers}
 
         return results
 
+    def _predict_tapas_for_qa(self, inputs: BatchEncoding, document: Document) -> Answer:
+        table: pd.DataFrame = document.content
+
+        # Forward query and table through model and convert logits to predictions
+        outputs = self.model(**inputs)
+        inputs.to("cpu")
+        if self.model.config.num_aggregation_labels > 0:
+            aggregation_logits = outputs.logits_aggregation.cpu().detach()
+        else:
+            aggregation_logits = None
+
+        predicted_output = self.tokenizer.convert_logits_to_predictions(
+            inputs,
+            outputs.logits.cpu().detach(),
+            aggregation_logits
+        )
+        if len(predicted_output) == 1:
+            predicted_answer_coordinates = predicted_output[0]
+        else:
+            predicted_answer_coordinates, predicted_aggregation_indices = predicted_output
+
+        # Get cell values
+        current_answer_coordinates = predicted_answer_coordinates[0]
+        current_answer_cells = []
+        for coordinate in current_answer_coordinates:
+            current_answer_cells.append(table.iat[coordinate])
+
+        # Get aggregation operator
+        if self.model.config.aggregation_labels is not None:
+            current_aggregation_operator = self.model.config.aggregation_labels[predicted_aggregation_indices[0]]
+        else:
+            current_aggregation_operator = "NONE"
+
+        # Calculate answer score
+        current_score = self._calculate_answer_score(outputs.logits.cpu().detach(), inputs, current_answer_coordinates)
+
+        if current_aggregation_operator == "NONE":
+            answer_str = ", ".join(current_answer_cells)
+        else:
+            answer_str = self._aggregate_answers(current_aggregation_operator, current_answer_cells)
+
+        answer_offsets = self._calculate_answer_offsets(current_answer_coordinates, document.content)
+
+        answer = Answer(
+            answer=answer_str,
+            type="extractive",
+            score=current_score,
+            context=document.content,
+            offsets_in_document=answer_offsets,
+            offsets_in_context=answer_offsets,
+            document_id=document.id,
+            meta={"aggregation_operator": current_aggregation_operator,
+                  "answer_cells": current_answer_cells}
+        )
+
+        return answer
+
+    def _predict_tapas_for_scored_qa(self, inputs: BatchEncoding, document: Document) -> Tuple[List[Answer], float]:
+        table: pd.DataFrame = document.content
+
+        # Forward pass through model
+        outputs = self.model.tapas(**inputs)
+
+        # Get general table score
+        table_score = self.model.classifier(outputs.pooler_output)
+        table_score_softmax = torch.nn.functional.softmax(table_score, dim=1)
+        table_relevancy_prob = table_score_softmax[0][1].item()
+
+        # Get possible answer spans
+        token_types = [
+            "segment_ids",
+            "column_ids",
+            "row_ids",
+            "prev_labels",
+            "column_ranks",
+            "inv_column_ranks",
+            "numeric_relations",
+        ]
+        row_ids: List[int] = inputs.token_type_ids[:, :, token_types.index("row_ids")].tolist()[0]
+        column_ids: List[int] = inputs.token_type_ids[:, :, token_types.index("column_ids")].tolist()[0]
+
+        possible_answer_spans: List[Tuple[int, int, int, int]] = []  # List of tuples: (row_idx, col_idx, start_token, end_token)
+        current_start_idx = -1
+        current_column_id = -1
+        for idx, (row_id, column_id) in enumerate(zip(row_ids, column_ids)):
+            if row_id == 0 or column_id == 0:
+                continue
+            # Beginning of new cell
+            if column_id != current_column_id:
+                if current_start_idx != -1:
+                    possible_answer_spans.append(
+                        (row_ids[current_start_idx]-1, column_ids[current_start_idx]-1, current_start_idx, idx-1)
+                    )
+                current_start_idx = idx
+                current_column_id = column_id
+        possible_answer_spans.append(
+            (row_ids[current_start_idx]-1, column_ids[current_start_idx]-1, current_start_idx, len(row_ids)-1)
+        )
+
+        # Concat logits of start token and end token of possible answer spans
+        sequence_output = outputs.last_hidden_state
+        concatenated_logits = []
+        for possible_span in possible_answer_spans:
+            start_token_logits = sequence_output[0, possible_span[2], :]
+            end_token_logits = sequence_output[0, possible_span[3], :]
+            concatenated_logits.append(torch.cat((start_token_logits, end_token_logits)))
+        concatenated_logit_tensors = torch.unsqueeze(torch.stack(concatenated_logits), dim=0)
+
+        # Calculate score for each possible span
+        span_logits = torch.einsum("bsj,j->bs", concatenated_logit_tensors, self.model.span_output_weights) \
+                      + self.model.span_output_bias
+        span_logits_softmax = torch.nn.functional.softmax(span_logits, dim=1)
+
+        top_k_answer_spans = torch.topk(span_logits[0], min(self.top_k_per_candidate, len(possible_answer_spans)))
+
+        answers = []
+        for answer_span_idx in top_k_answer_spans.indices:
+            current_answer_span = possible_answer_spans[answer_span_idx]
+            answer_str = table.iat[current_answer_span[:2]]
+            answer_offsets = self._calculate_answer_offsets([current_answer_span[:2]], document.content)
+            # As the general table score is more important for the final score, it is double weighted.
+            current_score = ((2 * table_relevancy_prob) + span_logits_softmax[0, answer_span_idx].item()) / 3
+
+            answers.append(
+                Answer(
+                    answer=answer_str,
+                    type="extractive",
+                    score=current_score,
+                    context=document.content,
+                    offsets_in_document=answer_offsets,
+                    offsets_in_context=answer_offsets,
+                    document_id=document.id,
+                    meta={"aggregation_operator": "NONE",
+                          "answer_cells": table.iat[current_answer_span[:2]]}
+                )
+            )
+
+        no_answer_score = 1 - table_relevancy_prob
+
+        return answers, no_answer_score
+
     def _calculate_answer_score(self, logits: torch.Tensor, inputs: BatchEncoding,
                                 answer_coordinates: List[Tuple[int, int]]) -> float:
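A worked example of the scoring used in `_predict_tapas_for_scored_qa` above, with invented probabilities: the table relevancy probability is weighted twice, the span probability once, and the no-answer score is the complement of the table relevancy probability.

```python
# Invented values for illustration only.
table_relevancy_prob = 0.9   # softmax output of the table-scoring head
span_prob = 0.6              # softmax output of the span logit for one candidate span

final_score = (2 * table_relevancy_prob + span_prob) / 3   # ~0.8
no_answer_score = 1 - table_relevancy_prob                 # ~0.1

print(final_score, no_answer_score)
```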
@@ -253,6 +388,27 @@ class TableReader(BaseReader):
     def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
         raise NotImplementedError("Batch prediction not yet available in TableReader.")
 
+    class TapasForScoredQA(TapasPreTrainedModel):
+
+        def __init__(self, config):
+            super().__init__(config)
+
+            # base model
+            self.tapas = TapasModel(config)
+
+            # dropout (only used when training)
+            self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+
+            # answer selection head
+            self.span_output_weights = torch.nn.Parameter(torch.zeros(2 * config.hidden_size))
+            self.span_output_bias = torch.nn.Parameter(torch.zeros([]))
+
+            # table scoring head
+            self.classifier = torch.nn.Linear(config.hidden_size, 2)
+
+            # Initialize weights
+            self.init_weights()
+
 
 class RCIReader(BaseReader):
     """
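The answer-selection head defined in `TapasForScoredQA` above scores each candidate span by a dot product between the concatenated start/end token hidden states and `span_output_weights`, plus a scalar bias. A shape-only sketch, not part of this commit, with an invented hidden size:

```python
import torch

hidden_size = 4                                            # invented; real TAPAS models use 768 or 1024
num_spans = 3
span_reprs = torch.randn(1, num_spans, 2 * hidden_size)    # concatenated start/end token states per span
span_output_weights = torch.nn.Parameter(torch.zeros(2 * hidden_size))
span_output_bias = torch.nn.Parameter(torch.zeros([]))

# Same contraction as in _predict_tapas_for_scored_qa: one logit per candidate span.
span_logits = torch.einsum("bsj,j->bs", span_reprs, span_output_weights) + span_output_bias
print(span_logits.shape)   # torch.Size([1, 3])
```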