mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-05 19:47:45 +00:00)
parent 59857cb492
commit 8db7dfb884
@@ -1,5 +1,11 @@
from abc import abstractmethod
from typing import List, Optional, Tuple, Dict, Union

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore

import logging
from statistics import mean
import torch
@@ -98,7 +104,7 @@ class TableReader(BaseReader):
or commit hash.
:param tokenizer: Name of the tokenizer (usually the same as model)
:param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
:param top_k: The maximum number of answers to return
:param top_k: The maximum number of answers to return.
:param top_k_per_candidate: How many answers to extract for each candidate table that is coming from
the retriever.
:param return_no_answer: Whether to include no_answer predictions in the results.
@@ -128,27 +134,40 @@ class TableReader(BaseReader):
super().__init__()

self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
config = TapasConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
if len(self.devices) > 1:
logger.warning(
f"Multiple devices are not supported in {self.__class__.__name__} inference, "
f"using the first device {self.devices[0]}."
)

if config.architectures[0] == "TapasForScoredQA":
self.model = self.TapasForScoredQA.from_pretrained(
model_name_or_path, revision=model_version, use_auth_token=use_auth_token
config = TapasConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
self.table_encoder: Union[_TapasEncoder, _TapasScoredEncoder]
if config.architectures[0] == "TapasForQuestionAnswering":
self.table_encoder = _TapasEncoder(
device=self.devices[0],
model_name_or_path=model_name_or_path,
model_version=model_version,
tokenizer=tokenizer,
max_seq_len=max_seq_len,
use_auth_token=use_auth_token,
)
elif config.architectures[0] == "TapasForScoredQA":
self.table_encoder = _TapasScoredEncoder(
device=self.devices[0],
model_name_or_path=model_name_or_path,
model_version=model_version,
tokenizer=tokenizer,
top_k_per_candidate=top_k_per_candidate,
return_no_answer=return_no_answer,
max_seq_len=max_seq_len,
use_auth_token=use_auth_token,
)
else:
self.model = TapasForQuestionAnswering.from_pretrained(
model_name_or_path, revision=model_version, use_auth_token=use_auth_token
logger.error(
"Unrecognized model architecture %s. Only the architectures TapasForQuestionAnswering and TapasForScoredQA are supported",
config.architectures[0],
)
self.model.to(str(self.devices[0]))

if tokenizer is None:
self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
else:
self.tokenizer = TapasTokenizer.from_pretrained(tokenizer, use_auth_token=use_auth_token)
self.table_encoder.model.to(str(self.devices[0]))

self.top_k = top_k
self.top_k_per_candidate = top_k_per_candidate
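Note: after this change the encoder is picked once from `config.architectures[0]` instead of branching inside every method. A minimal usage sketch of the two supported checkpoint types, using the default model names from this diff (the import path is an assumption and may differ between Haystack versions):

```python
from haystack.nodes import TableReader  # assumed import path

# TapasForQuestionAnswering checkpoint -> routed internally to _TapasEncoder
qa_reader = TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")

# TapasForScoredQA checkpoint -> routed internally to _TapasScoredEncoder
scored_reader = TableReader(model_name_or_path="deepset/tapas-large-nq-hn-reader")
```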
@@ -172,9 +191,93 @@ class TableReader(BaseReader):
"""
if top_k is None:
top_k = self.top_k
return self.table_encoder.predict(query=query, documents=documents, top_k=top_k)

answers = []
no_answer_score = 1.0
def predict_batch(
self,
queries: List[str],
documents: Union[List[Document], List[List[Document]]],
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
):
"""
Use loaded TableQA model to find answers for the supplied queries in the supplied Documents
of content_type ``'table'``.

Returns dictionary containing query and list of Answer objects sorted by (desc.) score.
WARNING: The answer scores are not reliable, as they are always extremely high, even if
a question cannot be answered by a given table.

- If you provide a list containing a single query...
- ... and a single list of Documents, the query will be applied to each Document individually.
- ... and a list of lists of Documents, the query will be applied to each list of Documents and the Answers
will be aggregated per Document list.

- If you provide a list of multiple queries...
- ... and a single list of Documents, each query will be applied to each Document individually.
- ... and a list of lists of Documents, each query will be applied to its corresponding list of Documents
and the Answers will be aggregated per query-Document pair.

:param queries: Single query string or list of queries.
:param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
Documents should be of content_type ``'table'``.
:param top_k: The maximum number of answers to return per query.
:param batch_size: Not applicable.
"""
results: Dict = {"queries": queries, "answers": []}

single_doc_list = False
# Docs case 1: single list of Documents -> apply each query to all Documents
if len(documents) > 0 and isinstance(documents[0], Document):
single_doc_list = True
for query in queries:
for doc in documents:
if not isinstance(doc, Document):
raise HaystackError(f"doc was of type {type(doc)}, but expected a Document.")
preds = self.predict(query=query, documents=[doc], top_k=top_k)
results["answers"].append(preds["answers"])

# Docs case 2: list of lists of Documents -> apply each query to corresponding list of Documents, if queries
# contains only one query, apply it to each list of Documents
elif len(documents) > 0 and isinstance(documents[0], list):
if len(queries) == 1:
queries = queries * len(documents)
if len(queries) != len(documents):
raise HaystackError("Number of queries must be equal to number of provided Document lists.")
for query, cur_docs in zip(queries, documents):
if not isinstance(cur_docs, list):
raise HaystackError(f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.")
preds = self.predict(query=query, documents=cur_docs, top_k=top_k)
results["answers"].append(preds["answers"])

# Group answers by question in case of multiple queries and single doc list
if single_doc_list and len(queries) > 1:
answers_per_query = int(len(results["answers"]) / len(queries))
answers = []
for i in range(0, len(results["answers"]), answers_per_query):
answer_group = results["answers"][i : i + answers_per_query]
answers.append(answer_group)
results["answers"] = answers

return results
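To make the query/Document combinations described in the docstring concrete, here is a hedged usage sketch (toy table and ages made up for illustration; `from haystack.nodes import TableReader` is an assumed import path, the Document API follows `haystack.schema` as used in the tests below; outputs not verified):

```python
import pandas as pd
from haystack.nodes import TableReader  # assumed import path
from haystack.schema import Document

reader = TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")

# Toy table Document of content_type "table"
table = pd.DataFrame(
    {"actors": ["brad pitt", "leonardo di caprio", "george clooney"], "age": ["59", "47", "61"]}
)
docs = [Document(content=table, content_type="table")]

# Single query against a list of table Documents.
prediction = reader.predict(query="Who is 47 years old?", documents=docs, top_k=1)

# One query, single Document list: the query is applied to each Document individually.
batch_prediction = reader.predict_batch(queries=["Who is 47 years old?"], documents=docs)

# Two queries, list of Document lists: each query runs against its corresponding list.
batch_prediction = reader.predict_batch(
    queries=["Who is 47 years old?", "Who is 61 years old?"], documents=[docs, docs]
)
```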


class _BaseTapasEncoder:
@staticmethod
def _calculate_answer_offsets(answer_coordinates: List[Tuple[int, int]], table: pd.DataFrame) -> List[Span]:
"""
Calculates the answer cell offsets of the linearized table based on the answer cell coordinates.
"""
answer_offsets = []
n_rows, n_columns = table.shape
for coord in answer_coordinates:
answer_cell_offset = (coord[0] * n_columns) + coord[1]
answer_offsets.append(Span(start=answer_cell_offset, end=answer_cell_offset + 1))
return answer_offsets
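The offsets above index cells of the table linearized row by row. A tiny worked example with made-up coordinates:

```python
# For a table with 3 columns, cell (row, col) linearizes to row * 3 + col.
n_columns = 3
answer_coordinates = [(0, 1), (1, 2)]  # (row, col) pairs
offsets = [(row * n_columns) + col for row, col in answer_coordinates]
assert offsets == [1, 5]  # i.e. Span(1, 2) and Span(5, 6) in the method above
```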

@staticmethod
def _check_documents(documents: List[Document]) -> List[Document]:
table_documents = []
for document in documents:
if document.content_type != "table":
logger.warning("Skipping document with id '%s' in TableReader as it is not of type table.", document.id)
@@ -186,59 +289,62 @@ class TableReader(BaseReader):
"Skipping document with id '%s' in TableReader as it does not contain any rows.", document.id
)
continue
# Tokenize query and current table
inputs = self.tokenizer(
table=table, queries=query, max_length=self.max_seq_len, return_tensors="pt", truncation=True
)
inputs.to(self.devices[0])

if isinstance(self.model, TapasForQuestionAnswering):
current_answer = self._predict_tapas_for_qa(inputs, document)
answers.append(current_answer)
elif isinstance(self.model, self.TapasForScoredQA):
current_answers, current_no_answer_score = self._predict_tapas_for_scored_qa(inputs, document)
answers.extend(current_answers)
if current_no_answer_score < no_answer_score:
no_answer_score = current_no_answer_score
table_documents.append(document)
return table_documents

if self.return_no_answer and isinstance(self.model, self.TapasForScoredQA):
answers.append(
Answer(
answer="",
type="extractive",
score=no_answer_score,
context=None,
offsets_in_context=[Span(start=0, end=0)],
offsets_in_document=[Span(start=0, end=0)],
document_id=None,
meta=None,
)
)
answers = sorted(answers, reverse=True)
answers = answers[:top_k]
@staticmethod
def _preprocess(query: str, table: pd.DataFrame, tokenizer, max_seq_len) -> BatchEncoding:
"""Tokenize the query and table."""
model_inputs = tokenizer(
table=table, queries=query, max_length=max_seq_len, return_tensors="pt", truncation=True
)
return model_inputs

results = {"query": query, "answers": answers}
@abstractmethod
def predict(self, query: str, documents: List[Document], top_k: int) -> Dict:
pass

return results

def _predict_tapas_for_qa(self, inputs: BatchEncoding, document: Document) -> Answer:
class _TapasEncoder(_BaseTapasEncoder):
def __init__(
self,
device: torch.device,
model_name_or_path: str = "google/tapas-base-finetuned-wtq",
model_version: Optional[str] = None,
tokenizer: Optional[str] = None,
max_seq_len: int = 256,
use_auth_token: Optional[Union[str, bool]] = None,
):
self.model = TapasForQuestionAnswering.from_pretrained(
model_name_or_path, revision=model_version, use_auth_token=use_auth_token
)
if tokenizer is None:
self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
else:
self.tokenizer = TapasTokenizer.from_pretrained(tokenizer, use_auth_token=use_auth_token)
self.max_seq_len = max_seq_len
self.device = device

def _predict_tapas(self, inputs: BatchEncoding, document: Document) -> Answer:
table: pd.DataFrame = document.content

# Forward query and table through model and convert logits to predictions
outputs = self.model(**inputs)
inputs.to("cpu")
if self.model.config.num_aggregation_labels > 0:
aggregation_logits = outputs.logits_aggregation.cpu().detach()
else:
aggregation_logits = None
with torch.no_grad():
outputs = self.model(**inputs)

predicted_output = self.tokenizer.convert_logits_to_predictions(
inputs, outputs.logits.cpu().detach(), aggregation_logits
)
if len(predicted_output) == 1:
predicted_answer_coordinates = predicted_output[0]
inputs.to("cpu")
outputs_logits = outputs.logits.cpu()

if self.model.config.num_aggregation_labels > 0:
aggregation_logits = outputs.logits_aggregation.cpu()
predicted_answer_coordinates, predicted_aggregation_indices = self.tokenizer.convert_logits_to_predictions(
inputs, outputs_logits, logits_agg=aggregation_logits, cell_classification_threshold=0.5
)
else:
predicted_answer_coordinates, predicted_aggregation_indices = predicted_output
predicted_answer_coordinates = self.tokenizer.convert_logits_to_predictions(
inputs, outputs_logits, logits_agg=None, cell_classification_threshold=0.5
)

# Get cell values
current_answer_coordinates = predicted_answer_coordinates[0]
@@ -253,7 +359,7 @@ class TableReader(BaseReader):
current_aggregation_operator = "NONE"

# Calculate answer score
current_score = self._calculate_answer_score(outputs.logits.cpu().detach(), inputs, current_answer_coordinates)
current_score = self._calculate_answer_score(outputs_logits, inputs, current_answer_coordinates)

if current_aggregation_operator == "NONE":
answer_str = ", ".join(current_answer_cells)
@@ -272,19 +378,127 @@ class TableReader(BaseReader):
document_id=document.id,
meta={"aggregation_operator": current_aggregation_operator, "answer_cells": current_answer_cells},
)

return answer

def _predict_tapas_for_scored_qa(self, inputs: BatchEncoding, document: Document) -> Tuple[List[Answer], float]:
def _calculate_answer_score(
self, logits: torch.Tensor, inputs: BatchEncoding, answer_coordinates: List[Tuple[int, int]]
) -> float:
# Calculate answer score
# Values over 88.72284 will overflow when passed through exponential, so logits are truncated.
logits[logits < -88.7] = -88.7
token_probabilities = 1 / (1 + np.exp(-logits)) * inputs.attention_mask
token_types = [
"segment_ids",
"column_ids",
"row_ids",
"prev_labels",
"column_ranks",
"inv_column_ranks",
"numeric_relations",
]

segment_ids = inputs.token_type_ids[0, :, token_types.index("segment_ids")].tolist()
column_ids = inputs.token_type_ids[0, :, token_types.index("column_ids")].tolist()
row_ids = inputs.token_type_ids[0, :, token_types.index("row_ids")].tolist()
all_cell_probabilities = self.tokenizer._get_mean_cell_probs(
token_probabilities[0].tolist(), segment_ids, row_ids, column_ids
)
# _get_mean_cell_probs seems to index cells by (col, row). DataFrames are, however, indexed by (row, col).
all_cell_probabilities = {(row, col): prob for (col, row), prob in all_cell_probabilities.items()}
answer_cell_probabilities = [all_cell_probabilities[coord] for coord in answer_coordinates]

return np.mean(answer_cell_probabilities)
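In short, the score is a sigmoid over the token logits (truncated to avoid overflow), averaged per table cell and then over the answer cells. A toy numpy illustration of the first step only (the real method additionally averages per cell via the tokenizer's `_get_mean_cell_probs`; values here are made up):

```python
import numpy as np

# Toy token logits for one 4-token sequence; the attention mask keeps all tokens.
logits = np.array([[2.0, -1.0, 0.5, -120.0]])
attention_mask = np.array([[1, 1, 1, 1]])

logits[logits < -88.7] = -88.7                        # truncate to avoid overflow in exp()
token_probabilities = 1 / (1 + np.exp(-logits)) * attention_mask
answer_score = token_probabilities[0, [0, 2]].mean()  # mean over the tokens of the answer cells
```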

@staticmethod
def _aggregate_answers(agg_operator: Literal["COUNT", "SUM", "AVERAGE"], answer_cells: List[str]) -> str:
if agg_operator == "COUNT":
return str(len(answer_cells))

# No aggregation needed as only one cell selected as answer_cells
if len(answer_cells) == 1:
return answer_cells[0]
# Return empty string if model did not select any cell as answer
if len(answer_cells) == 0:
return ""

# Parse answer cells in order to aggregate numerical values
parsed_answer_cells = [parser.parse(cell) for cell in answer_cells]
# Check if all cells contain at least one numerical value and that all values share the same unit
try:
if all(parsed_answer_cells) and all(
cell[0].unit.name == parsed_answer_cells[0][0].unit.name for cell in parsed_answer_cells
):
numerical_values = [cell[0].value for cell in parsed_answer_cells]
unit = parsed_answer_cells[0][0].unit.symbols[0] if parsed_answer_cells[0][0].unit.symbols else ""

if agg_operator == "SUM":
answer_value = sum(numerical_values)
elif agg_operator == "AVERAGE":
answer_value = mean(numerical_values)
else:
raise ValueError("unknown aggregator")

return f"{answer_value}{' ' + unit if unit else ''}"

except ValueError as e:
if "unknown aggregator" in str(e):
pass

# Not all selected answer cells contain a numerical value or answer cells don't share the same unit
return f"{agg_operator} > {', '.join(answer_cells)}"

def predict(self, query: str, documents: List[Document], top_k: int) -> Dict:
answers = []
table_documents = self._check_documents(documents)
for document in table_documents:
table: pd.DataFrame = document.content
model_inputs = self._preprocess(query, table, self.tokenizer, self.max_seq_len)
model_inputs.to(self.device)

current_answer = self._predict_tapas(model_inputs, document)
answers.append(current_answer)

answers = sorted(answers, reverse=True)
results = {"query": query, "answers": answers[:top_k]}
return results


class _TapasScoredEncoder(_BaseTapasEncoder):
def __init__(
self,
device: torch.device,
model_name_or_path: str = "deepset/tapas-large-nq-hn-reader",
model_version: Optional[str] = None,
tokenizer: Optional[str] = None,
top_k_per_candidate: int = 3,
return_no_answer: bool = False,
max_seq_len: int = 256,
use_auth_token: Optional[Union[str, bool]] = None,
):
self.model = self._TapasForScoredQA.from_pretrained(
model_name_or_path, revision=model_version, use_auth_token=use_auth_token
)
if tokenizer is None:
self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
else:
self.tokenizer = TapasTokenizer.from_pretrained(tokenizer, use_auth_token=use_auth_token)
self.max_seq_len = max_seq_len
self.device = device
self.top_k_per_candidate = top_k_per_candidate
self.return_no_answer = return_no_answer

def _predict_tapas_scored(self, inputs: BatchEncoding, document: Document) -> Tuple[List[Answer], float]:
table: pd.DataFrame = document.content

# Forward pass through model
outputs = self.model.tapas(**inputs)
with torch.no_grad():
outputs = self.model.tapas(**inputs)

# Get general table score
table_score = self.model.classifier(outputs.pooler_output)
table_score_softmax = torch.nn.functional.softmax(table_score, dim=1)
table_relevancy_prob = table_score_softmax[0][1].item()
no_answer_score = table_score_softmax[0][0].item()
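The two-class classifier head yields a relevancy probability and a no-answer probability for the whole table. A toy illustration of that softmax with made-up logits:

```python
import torch

table_score = torch.tensor([[0.2, 1.5]])  # [no_answer_logit, relevancy_logit]
probs = torch.nn.functional.softmax(table_score, dim=1)
no_answer_score = probs[0][0].item()       # ~0.21
table_relevancy_prob = probs[0][1].item()  # ~0.79
```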

# Get possible answer spans
token_types = [
@@ -302,21 +516,31 @@ class TableReader(BaseReader):
possible_answer_spans: List[
Tuple[int, int, int, int]
] = []  # List of tuples: (row_idx, col_idx, start_token, end_token)
current_start_idx = -1
current_start_token_idx = -1
current_column_id = -1
for idx, (row_id, column_id) in enumerate(zip(row_ids, column_ids)):
for token_idx, (row_id, column_id) in enumerate(zip(row_ids, column_ids)):
if row_id == 0 or column_id == 0:
continue
# Beginning of new cell
if column_id != current_column_id:
if current_start_idx != -1:
if current_start_token_idx != -1:
possible_answer_spans.append(
(row_ids[current_start_idx] - 1, column_ids[current_start_idx] - 1, current_start_idx, idx - 1)
(
row_ids[current_start_token_idx] - 1,
column_ids[current_start_token_idx] - 1,
current_start_token_idx,
token_idx - 1,
)
)
current_start_idx = idx
current_start_token_idx = token_idx
current_column_id = column_id
possible_answer_spans.append(
(row_ids[current_start_idx] - 1, column_ids[current_start_idx] - 1, current_start_idx, len(row_ids) - 1)
(
row_ids[current_start_token_idx] - 1,
column_ids[current_start_token_idx] - 1,
current_start_token_idx,
len(row_ids) - 1,
)
)
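The walk above groups consecutive table tokens that belong to the same cell into candidate answer spans. A condensed re-run of that loop on hypothetical token-type ids (ids made up for illustration, not taken from the diff):

```python
# id 0 marks query/special tokens; table tokens carry 1-based row/column ids.
row_ids    = [0, 1, 1, 1, 2, 2]
column_ids = [0, 1, 1, 2, 1, 2]

spans = []  # (row_idx, col_idx, start_token, end_token), rows/cols 0-based
start, current_col = -1, -1
for idx, (row, col) in enumerate(zip(row_ids, column_ids)):
    if row == 0 or col == 0:
        continue
    if col != current_col:  # a new cell starts here
        if start != -1:
            spans.append((row_ids[start] - 1, column_ids[start] - 1, start, idx - 1))
        start, current_col = idx, col
spans.append((row_ids[start] - 1, column_ids[start] - 1, start, len(row_ids) - 1))

print(spans)  # [(0, 0, 1, 2), (0, 1, 3, 3), (1, 0, 4, 4), (1, 1, 5, 5)]
```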

# Concat logits of start token and end token of possible answer spans
@@ -358,160 +582,41 @@ class TableReader(BaseReader):
)
)

no_answer_score = 1 - table_relevancy_prob

return answers, no_answer_score

def _calculate_answer_score(
self, logits: torch.Tensor, inputs: BatchEncoding, answer_coordinates: List[Tuple[int, int]]
) -> float:
"""
Calculates the answer score by computing each cell's probability of being part of the answer
and taking the mean probability of the answer cells.
"""
# Calculate answer score
# Values over 88.72284 will overflow when passed through exponential, so logits are truncated.
logits[logits < -88.7] = -88.7
token_probabilities = 1 / (1 + np.exp(-logits)) * inputs.attention_mask
def predict(self, query: str, documents: List[Document], top_k: int) -> Dict:
answers = []
no_answer_score = 1.0
table_documents = self._check_documents(documents)
for document in table_documents:
table: pd.DataFrame = document.content
model_inputs = self._preprocess(query, table, self.tokenizer, self.max_seq_len)
model_inputs.to(self.device)

segment_ids = inputs.token_type_ids[0, :, 0].tolist()
column_ids = inputs.token_type_ids[0, :, 1].tolist()
row_ids = inputs.token_type_ids[0, :, 2].tolist()
all_cell_probabilities = self.tokenizer._get_mean_cell_probs(
token_probabilities[0].tolist(), segment_ids, row_ids, column_ids
)
# _get_mean_cell_probs seems to index cells by (col, row). DataFrames are, however, indexed by (row, col).
all_cell_probabilities = {(row, col): prob for (col, row), prob in all_cell_probabilities.items()}
answer_cell_probabilities = [all_cell_probabilities[coord] for coord in answer_coordinates]
current_answers, current_no_answer_score = self._predict_tapas_scored(model_inputs, document)
answers.extend(current_answers)
if current_no_answer_score < no_answer_score:
no_answer_score = current_no_answer_score

return np.mean(answer_cell_probabilities)

@staticmethod
def _aggregate_answers(agg_operator: str, answer_cells: List[str]) -> str:
if agg_operator == "COUNT":
return str(len(answer_cells))

# No aggregation needed as only one cell selected as answer_cells
if len(answer_cells) == 1:
return answer_cells[0]
# Return empty string if model did not select any cell as answer
if len(answer_cells) == 0:
return ""

# Parse answer cells in order to aggregate numerical values
parsed_answer_cells = [parser.parse(cell) for cell in answer_cells]
# Check if all cells contain at least one numerical value and that all values share the same unit
try:
if all(parsed_answer_cells) and all(
cell[0].unit.name == parsed_answer_cells[0][0].unit.name for cell in parsed_answer_cells
):
numerical_values = [cell[0].value for cell in parsed_answer_cells]
unit = parsed_answer_cells[0][0].unit.symbols[0] if parsed_answer_cells[0][0].unit.symbols else ""

if agg_operator == "SUM":
answer_value = sum(numerical_values)
elif agg_operator == "AVERAGE":
answer_value = mean(numerical_values)
else:
raise KeyError("unknown aggregator")

return f"{answer_value}{' ' + unit if unit else ''}"

except KeyError as e:
if "unknown aggregator" in str(e):
pass

# Not all selected answer cells contain a numerical value or answer cells don't share the same unit
return f"{agg_operator} > {', '.join(answer_cells)}"

@staticmethod
def _calculate_answer_offsets(answer_coordinates: List[Tuple[int, int]], table: pd.DataFrame) -> List[Span]:
"""
Calculates the answer cell offsets of the linearized table based on the
answer cell coordinates.
"""
answer_offsets = []
n_rows, n_columns = table.shape
for coord in answer_coordinates:
answer_cell_offset = (coord[0] * n_columns) + coord[1]
answer_offsets.append(Span(start=answer_cell_offset, end=answer_cell_offset + 1))

return answer_offsets
def predict_batch(
self,
queries: List[str],
documents: Union[List[Document], List[List[Document]]],
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
):
"""
Use loaded TableQA model to find answers for the supplied queries in the supplied Documents
of content_type ``'table'``.

Returns dictionary containing query and list of Answer objects sorted by (desc.) score.

WARNING: The answer scores are not reliable, as they are always extremely high, even if
a question cannot be answered by a given table.

- If you provide a list containing a single query...

- ... and a single list of Documents, the query will be applied to each Document individually.
- ... and a list of lists of Documents, the query will be applied to each list of Documents and the Answers
will be aggregated per Document list.

- If you provide a list of multiple queries...

- ... and a single list of Documents, each query will be applied to each Document individually.
- ... and a list of lists of Documents, each query will be applied to its corresponding list of Documents
and the Answers will be aggregated per query-Document pair.

:param queries: Single query string or list of queries.
:param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
Documents should be of content_type ``'table'``.
:param top_k: The maximum number of answers to return per query.
:param batch_size: Not applicable.
"""
# TODO: This method currently just calls the predict method multiple times, so there is room for improvement.

results: Dict = {"queries": queries, "answers": []}

single_doc_list = False
# Docs case 1: single list of Documents -> apply each query to all Documents
if len(documents) > 0 and isinstance(documents[0], Document):
single_doc_list = True
for query in queries:
for doc in documents:
if not isinstance(doc, Document):
raise HaystackError(f"doc was of type {type(doc)}, but expected a Document.")
preds = self.predict(query=query, documents=[doc], top_k=top_k)
results["answers"].append(preds["answers"])

# Docs case 2: list of lists of Documents -> apply each query to corresponding list of Documents, if queries
# contains only one query, apply it to each list of Documents
elif len(documents) > 0 and isinstance(documents[0], list):
if len(queries) == 1:
queries = queries * len(documents)
if len(queries) != len(documents):
raise HaystackError("Number of queries must be equal to number of provided Document lists.")
for query, cur_docs in zip(queries, documents):
if not isinstance(cur_docs, list):
raise HaystackError(f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.")
preds = self.predict(query=query, documents=cur_docs, top_k=top_k)
results["answers"].append(preds["answers"])

# Group answers by question in case of multiple queries and single doc list
if single_doc_list and len(queries) > 1:
answers_per_query = int(len(results["answers"]) / len(queries))
answers = []
for i in range(0, len(results["answers"]), answers_per_query):
answer_group = results["answers"][i : i + answers_per_query]
answers.append(answer_group)
results["answers"] = answers
if self.return_no_answer:
answers.append(
Answer(
answer="",
type="extractive",
score=no_answer_score,
context=None,
offsets_in_context=[Span(start=0, end=0)],
offsets_in_document=[Span(start=0, end=0)],
document_id=None,
meta=None,
)
)

answers = sorted(answers, reverse=True)
results = {"query": query, "answers": answers[:top_k]}
return results

class TapasForScoredQA(TapasPreTrainedModel):
class _TapasForScoredQA(TapasPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -684,10 +684,14 @@ def reader(request):
)

@pytest.fixture(params=["tapas", "rci"])
@pytest.fixture(params=["tapas_small", "tapas_base", "tapas_scored", "rci"])
def table_reader(request):
if request.param == "tapas":
if request.param == "tapas_small":
return TableReader(model_name_or_path="google/tapas-small-finetuned-wtq")
elif request.param == "tapas_base":
return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")
elif request.param == "tapas_scored":
return TableReader(model_name_or_path="deepset/tapas-large-nq-hn-reader")
elif request.param == "rci":
return RCIReader(
row_model_name_or_path="michaelrglass/albert-base-rci-wikisql-row",

@@ -7,6 +7,7 @@ from haystack.schema import Document, Answer
from haystack.pipelines.base import Pipeline

@pytest.mark.parametrize("table_reader", ["tapas_small", "rci", "tapas_scored"], indirect=True)
def test_table_reader(table_reader):
data = {
"actors": ["brad pitt", "leonardo di caprio", "george clooney"],
@@ -23,6 +24,7 @@ def test_table_reader(table_reader):
assert prediction["answers"][0].offsets_in_context[0].end == 8

@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True)
def test_table_reader_batch_single_query_single_doc_list(table_reader):
data = {
"actors": ["brad pitt", "leonardo di caprio", "george clooney"],
@@ -41,6 +43,7 @@ def test_table_reader_batch_single_query_single_doc_list(table_reader):
assert len(prediction["answers"]) == 1  # Predictions for 5 docs

@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True)
def test_table_reader_batch_single_query_multiple_doc_lists(table_reader):
data = {
"actors": ["brad pitt", "leonardo di caprio", "george clooney"],
@@ -61,6 +64,7 @@ def test_table_reader_batch_single_query_multiple_doc_lists(table_reader):
assert len(prediction["answers"]) == 1  # Predictions for 1 collection of docs

@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True)
def test_table_reader_batch_multiple_queries_single_doc_list(table_reader):
data = {
"actors": ["brad pitt", "leonardo di caprio", "george clooney"],
@@ -82,6 +86,7 @@ def test_table_reader_batch_multiple_queries_single_doc_list(table_reader):
assert len(prediction["answers"]) == 2  # Predictions for 2 queries

@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True)
def test_table_reader_batch_multiple_queries_multiple_doc_lists(table_reader):
data = {
"actors": ["brad pitt", "leonardo di caprio", "george clooney"],
@@ -103,6 +108,7 @@ def test_table_reader_batch_multiple_queries_multiple_doc_lists(table_reader):
assert len(prediction["answers"]) == 2  # Predictions for 2 collections of documents

@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True)
def test_table_reader_in_pipeline(table_reader):
pipeline = Pipeline()
pipeline.add_node(table_reader, "TableReader", ["Query"])
@@ -123,7 +129,7 @@ def test_table_reader_in_pipeline(table_reader):
assert prediction["answers"][0].offsets_in_context[0].end == 8

@pytest.mark.parametrize("table_reader", ["tapas"], indirect=True)
@pytest.mark.parametrize("table_reader", ["tapas_base"], indirect=True)
def test_table_reader_aggregation(table_reader):
data = {
"Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
@@ -144,6 +150,7 @@ def test_table_reader_aggregation(table_reader):
assert prediction["answers"][0].meta["answer_cells"] == ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"]

@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True)
def test_table_without_rows(caplog, table_reader):
# empty DataFrame
table = pd.DataFrame()
@@ -154,6 +161,7 @@ def test_table_without_rows(caplog, table_reader):
assert len(predictions["answers"]) == 0

@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True)
def test_text_document(caplog, table_reader):
document = Document(content="text", id="text_doc")
with caplog.at_level(logging.WARNING):