mirror of https://github.com/deepset-ai/haystack.git
synced 2025-08-21 06:58:27 +00:00

Add RCIReader for TableQA (#1909)

* Add RCIReader
* Add latest docstring and tutorial changes
* Add Doc Strings
* Add latest docstring and tutorial changes
* Add Tests
* Add Doc Strings

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

This commit is contained in:
parent 6e8e3c68d9
commit 45df18c416
@@ -609,3 +609,75 @@ WARNING: The answer scores are not reliable, as they are always extremely high,

Dict containing query and answers

<a name="table.RCIReader"></a>
## RCIReader

```python
class RCIReader(BaseReader)
```

Table Reader model based on Glass et al. (2021)'s Row-Column-Intersection model.
See the original paper for more details:
Glass, Michael, et al. (2021): "Capturing Row and Column Semantics in Transformer Based Question Answering over Tables"
(https://aclanthology.org/2021.naacl-main.96/)

Each row and each column is given a score with regard to the query by two separate models. The score of each cell
is then calculated as the sum of the corresponding row score and column score. Accordingly, the predicted answer is
the cell with the highest score.
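
As a toy illustration of this scoring scheme (not part of the Haystack API; the numbers are made up):

```python
import numpy as np

# Hypothetical relevance scores produced by the row and column models
row_scores = np.array([0.2, 1.7, -0.3])     # one score per table row
column_scores = np.array([0.5, -1.0, 2.1])  # one score per table column

# The score of each cell is the sum of its row score and column score
cell_scores = row_scores[:, None] + column_scores[None, :]

# The predicted answer is the cell with the highest score
best_row, best_col = np.unravel_index(cell_scores.argmax(), cell_scores.shape)
print(best_row, best_col)  # -> 1 2
```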

<a name="table.RCIReader.__init__"></a>
#### \_\_init\_\_

```python
 | __init__(row_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-row", column_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-col", row_model_version: Optional[str] = None, column_model_version: Optional[str] = None, row_tokenizer: Optional[str] = None, column_tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, max_seq_len: int = 256)
```

Load an RCI model from Transformers.
Available models include:

- ``'michaelrglass/albert-base-rci-wikisql-row'`` + ``'michaelrglass/albert-base-rci-wikisql-col'``
- ``'michaelrglass/albert-base-rci-wtq-row'`` + ``'michaelrglass/albert-base-rci-wtq-col'``

**Arguments**:

- `row_model_name_or_path`: Directory of a saved row scoring model or the name of a public model
- `column_model_name_or_path`: Directory of a saved column scoring model or the name of a public model
- `row_model_version`: The version of the row model to use from the HuggingFace model hub. Can be a tag name, branch name, or commit hash.
- `column_model_version`: The version of the column model to use from the HuggingFace model hub. Can be a tag name, branch name, or commit hash.
- `row_tokenizer`: Name of the tokenizer for the row model (usually the same as the model)
- `column_tokenizer`: Name of the tokenizer for the column model (usually the same as the model)
- `use_gpu`: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
- `top_k`: The maximum number of answers to return
- `max_seq_len`: Max sequence length of one input table for the model. If the number of tokens of query + table exceeds max_seq_len, the table is truncated by removing rows until the input fits the model.
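
For instance, loading the WTQ-finetuned model pair could look like this (a brief sketch using only the parameters documented above):

```python
from haystack.nodes.reader.table import RCIReader

reader = RCIReader(
    row_model_name_or_path="michaelrglass/albert-base-rci-wtq-row",
    column_model_name_or_path="michaelrglass/albert-base-rci-wtq-col",
    use_gpu=False,   # fall back to CPU
    top_k=5,         # return at most 5 answers
    max_seq_len=256,
)
```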

<a name="table.RCIReader.predict"></a>
#### predict

```python
 | predict(query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict
```

Use loaded RCI models to find answers for a query in the supplied list of Documents
of content_type ``'table'``.

Returns a dictionary containing the query and a list of Answer objects sorted by score in descending order.
The existing RCI models on the HF model hub don't allow aggregation; therefore, the answer is always
composed of a single cell.

**Arguments**:

- `query`: Query string
- `documents`: List of Documents in which to search for the answer. Documents should be
of content_type ``'table'``.
- `top_k`: The maximum number of answers to return

**Returns**:

Dict containing query and answers
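
A minimal end-to-end sketch, mirroring the test added in this commit:

```python
import pandas as pd

from haystack.schema import Document
from haystack.nodes.reader.table import RCIReader

reader = RCIReader()  # uses the default WikiSQL row/column models

# Tables are passed as pandas DataFrames wrapped in Documents of content_type "table"
table = pd.DataFrame({
    "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
    "age": ["58", "47", "60"],
    "number of movies": ["87", "53", "69"],
    "date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"],
})
document = Document(content=table, content_type="table")

prediction = reader.predict(query="When was Di Caprio born?", documents=[document])
print(prediction["answers"][0].answer)  # expected: "11 november 1974"
```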

@@ -23,7 +23,7 @@ from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
 from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
 from haystack.nodes.question_generator import QuestionGenerator
 from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker
-from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader
+from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader, RCIReader
 from haystack.nodes.retriever import (
     BaseRetriever,
     DensePassageRetriever,
@@ -1,4 +1,4 @@
 from haystack.nodes.reader.base import BaseReader
 from haystack.nodes.reader.farm import FARMReader
 from haystack.nodes.reader.transformers import TransformersReader
-from haystack.nodes.reader.table import TableReader
+from haystack.nodes.reader.table import TableReader, RCIReader
@@ -6,7 +6,8 @@ import torch
 import numpy as np
 import pandas as pd
 from quantulum3 import parser
-from transformers import TapasTokenizer, TapasForQuestionAnswering, BatchEncoding
+from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, \
+    BatchEncoding, AutoConfig
 
 from haystack.schema import Document, Answer, Span
 from haystack.nodes.reader.base import BaseReader
@@ -252,3 +253,209 @@ class TableReader(BaseReader):

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        raise NotImplementedError("Batch prediction not yet available in TableReader.")


class RCIReader(BaseReader):
    """
    Table Reader model based on Glass et al. (2021)'s Row-Column-Intersection model.
    See the original paper for more details:
    Glass, Michael, et al. (2021): "Capturing Row and Column Semantics in Transformer Based Question Answering over Tables"
    (https://aclanthology.org/2021.naacl-main.96/)

    Each row and each column is given a score with regard to the query by two separate models. The score of each cell
    is then calculated as the sum of the corresponding row score and column score. Accordingly, the predicted answer is
    the cell with the highest score.

    Pros and Cons of RCIReader compared to TableReader:
    + Provides meaningful confidence scores
    + Allows larger tables as input
    - Does not support aggregation over table cells
    - Slower
    """

    def __init__(self,
                 row_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-row",
                 column_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-col",
                 row_model_version: Optional[str] = None,
                 column_model_version: Optional[str] = None,
                 row_tokenizer: Optional[str] = None,
                 column_tokenizer: Optional[str] = None,
                 use_gpu: bool = True,
                 top_k: int = 10,
                 max_seq_len: int = 256,
                 ):
        """
        Load an RCI model from Transformers.
        Available models include:

        - ``'michaelrglass/albert-base-rci-wikisql-row'`` + ``'michaelrglass/albert-base-rci-wikisql-col'``
        - ``'michaelrglass/albert-base-rci-wtq-row'`` + ``'michaelrglass/albert-base-rci-wtq-col'``

        :param row_model_name_or_path: Directory of a saved row scoring model or the name of a public model
        :param column_model_name_or_path: Directory of a saved column scoring model or the name of a public model
        :param row_model_version: The version of the row model to use from the HuggingFace model hub.
                                  Can be a tag name, branch name, or commit hash.
        :param column_model_version: The version of the column model to use from the HuggingFace model hub.
                                     Can be a tag name, branch name, or commit hash.
        :param row_tokenizer: Name of the tokenizer for the row model (usually the same as the model)
        :param column_tokenizer: Name of the tokenizer for the column model (usually the same as the model)
        :param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
        :param top_k: The maximum number of answers to return
        :param max_seq_len: Max sequence length of one input table for the model. If the number of tokens of
                            query + table exceeds max_seq_len, the table is truncated by removing rows until the
                            input fits the model.
        """
        # Save init parameters to enable export of component config as YAML
        self.set_config(row_model_name_or_path=row_model_name_or_path,
                        column_model_name_or_path=column_model_name_or_path, row_model_version=row_model_version,
                        column_model_version=column_model_version, row_tokenizer=row_tokenizer,
                        column_tokenizer=column_tokenizer, use_gpu=use_gpu, top_k=top_k, max_seq_len=max_seq_len)

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        self.row_model = AutoModelForSequenceClassification.from_pretrained(row_model_name_or_path,
                                                                            revision=row_model_version)
        self.column_model = AutoModelForSequenceClassification.from_pretrained(column_model_name_or_path,
                                                                               revision=column_model_version)
        self.row_model.to(str(self.devices[0]))
        self.column_model.to(str(self.devices[0]))

        if row_tokenizer is None:
            try:
                self.row_tokenizer = AutoTokenizer.from_pretrained(row_model_name_or_path)
            # The existing RCI models on the model hub don't come with tokenizer vocab files.
            except TypeError:
                self.row_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        else:
            self.row_tokenizer = AutoTokenizer.from_pretrained(row_tokenizer)

        if column_tokenizer is None:
            try:
                self.column_tokenizer = AutoTokenizer.from_pretrained(column_model_name_or_path)
            # The existing RCI models on the model hub don't come with tokenizer vocab files.
            except TypeError:
                self.column_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        else:
            self.column_tokenizer = AutoTokenizer.from_pretrained(column_tokenizer)

        self.top_k = top_k
        self.max_seq_len = max_seq_len
        self.return_no_answers = False

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
        """
        Use loaded RCI models to find answers for a query in the supplied list of Documents
        of content_type ``'table'``.

        Returns dictionary containing query and list of Answer objects sorted by (desc.) score.
        The existing RCI models on the HF model hub don't allow aggregation; therefore, the answer is always
        composed of a single cell.

        :param query: Query string
        :param documents: List of Documents in which to search for the answer. Documents should be
                          of content_type ``'table'``.
        :param top_k: The maximum number of answers to return
        :return: Dict containing query and answers
        """
        if top_k is None:
            top_k = self.top_k

        answers = []
        for document in documents:
            if document.content_type != "table":
                logger.warning(f"Skipping document with id {document.id} in RCIReader, as it is not of type table.")
                continue

            table: pd.DataFrame = document.content
            table = table.astype(str)
            # Create row and column representations
            row_reps, column_reps = self._create_row_column_representations(table)

            # Get row logits
            row_inputs = self.row_tokenizer.batch_encode_plus(
                batch_text_or_text_pairs=[(query, row_rep) for row_rep in row_reps],
                max_length=self.max_seq_len,
                return_tensors="pt",
                add_special_tokens=True,
                truncation=True,
                padding=True
            )
            row_inputs.to(self.devices[0])
            row_logits = self.row_model(**row_inputs)[0].detach().cpu().numpy()[:, 1]

            # Get column logits
            column_inputs = self.column_tokenizer.batch_encode_plus(
                batch_text_or_text_pairs=[(query, column_rep) for column_rep in column_reps],
                max_length=self.max_seq_len,
                return_tensors="pt",
                add_special_tokens=True,
                truncation=True,
                padding=True
            )
            column_inputs.to(self.devices[0])
            column_logits = self.column_model(**column_inputs)[0].detach().cpu().numpy()[:, 1]

            # Calculate cell scores
            current_answers: List[Answer] = []
            cell_scores_table: List[List[float]] = []
            for row_idx, row_score in enumerate(row_logits):
                cell_scores_table.append([])
                for col_idx, col_score in enumerate(column_logits):
                    current_cell_score = float(row_score + col_score)
                    cell_scores_table[-1].append(current_cell_score)

                    answer_str = table.iloc[row_idx, col_idx]
                    answer_offsets = self._calculate_answer_offsets(row_idx, col_idx, table)
                    current_answers.append(
                        Answer(
                            answer=answer_str,
                            type="extractive",
                            score=current_cell_score,
                            context=table,
                            offsets_in_document=[answer_offsets],
                            offsets_in_context=[answer_offsets],
                            document_id=document.id,
                        )
                    )

            # Add cell scores to Answers' meta to be able to use as heatmap
            for answer in current_answers:
                answer.meta = {"table_scores": cell_scores_table}
            answers.extend(current_answers)

        # Sort answers by score and select top-k answers
        answers = sorted(answers, reverse=True)
        answers = answers[:top_k]

        results = {"query": query,
                   "answers": answers}

        return results

    @staticmethod
    def _create_row_column_representations(table: pd.DataFrame) -> Tuple[List[str], List[str]]:
        # Linearize each row as "header : cell * header : cell * ..." and each
        # column as "header * cell * cell * ...", e.g. "actors : brad pitt * age : 58"
        row_reps = []
        column_reps = []
        columns = table.columns

        for idx, row in table.iterrows():
            current_row_rep = " * ".join([header + " : " + cell for header, cell in zip(columns, row)])
            row_reps.append(current_row_rep)

        for col_name in columns:
            current_column_rep = f"{col_name} * "
            current_column_rep += " * ".join(table[col_name])
            column_reps.append(current_column_rep)

        return row_reps, column_reps

    @staticmethod
    def _calculate_answer_offsets(row_idx, column_index, table) -> Span:
        # Cells are addressed by their position in the row-major linearization of the
        # table, e.g. row 1, column 3 of a 4-column table sits at offset 1 * 4 + 3 = 7
        n_rows, n_columns = table.shape
        answer_cell_offset = (row_idx * n_columns) + column_index

        return Span(start=answer_cell_offset, end=answer_cell_offset + 1)

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        raise NotImplementedError("Batch prediction not yet available in RCIReader.")
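
Since every Answer carries the full per-cell score matrix in answer.meta["table_scores"], the scores can be rendered as a heatmap. A minimal sketch, assuming matplotlib is available (plot_table_scores is a hypothetical helper, not part of this commit):

```python
import matplotlib.pyplot as plt

def plot_table_scores(answer, table):
    # "table_scores" holds one score per cell, in row-major table layout
    scores = answer.meta["table_scores"]
    fig, ax = plt.subplots()
    im = ax.imshow(scores, cmap="viridis")
    ax.set_xticks(range(len(table.columns)))
    ax.set_xticklabels(table.columns, rotation=45, ha="right")
    fig.colorbar(im, ax=ax, label="cell score")
    plt.show()
```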
@@ -34,7 +34,7 @@ from haystack.document_stores.memory import InMemoryDocumentStore
 from haystack.document_stores.sql import SQLDocumentStore
 from haystack.nodes.reader.farm import FARMReader
 from haystack.nodes.reader.transformers import TransformersReader
-from haystack.nodes.reader.table import TableReader
+from haystack.nodes.reader.table import TableReader, RCIReader
 from haystack.nodes.summarizer.transformers import TransformersSummarizer
 from haystack.nodes.translator import TransformersTranslator
 from haystack.nodes.question_generator import QuestionGenerator

@@ -338,9 +338,13 @@ def reader(request):
 )
 
 
-@pytest.fixture(scope="function")
-def table_reader():
-    return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")
+@pytest.fixture(params=["tapas", "rci"], scope="function")
+def table_reader(request):
+    if request.param == "tapas":
+        return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")
+    elif request.param == "rci":
+        return RCIReader(row_model_name_or_path="michaelrglass/albert-base-rci-wikisql-row",
+                         column_model_name_or_path="michaelrglass/albert-base-rci-wikisql-col")
 
 
 @pytest.fixture(scope="function")
@@ -1,4 +1,5 @@
 import pandas as pd
+import pytest
 
 from haystack.schema import Document
 from haystack.pipelines.base import Pipeline

@@ -7,53 +8,39 @@ from haystack.pipelines.base import Pipeline
 def test_table_reader(table_reader):
     data = {
         "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
-        "age": ["57", "46", "60"],
+        "age": ["58", "47", "60"],
         "number of movies": ["87", "53", "69"],
-        "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
+        "date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"],
     }
     table = pd.DataFrame(data)
 
     query = "When was Di Caprio born?"
     prediction = table_reader.predict(query=query, documents=[Document(content=table, content_type="table")])
-    assert prediction["answers"][0].answer == "10 june 1996"
+    assert prediction["answers"][0].answer == "11 november 1974"
     assert prediction["answers"][0].offsets_in_context[0].start == 7
     assert prediction["answers"][0].offsets_in_context[0].end == 8
 
-    # test aggregation
-    query = "How old are DiCaprio and Pitt on average?"
-    prediction = table_reader.predict(query=query, documents=[Document(content=table, content_type="table")])
-    assert prediction["answers"][0].answer == "51.5"
-    assert prediction["answers"][0].meta["answer_cells"] == ["57", "46"]
-    assert prediction["answers"][0].meta["aggregation_operator"] == "AVERAGE"
-    assert prediction["answers"][0].offsets_in_context[0].start == 1
-    assert prediction["answers"][0].offsets_in_context[0].end == 2
-    assert prediction["answers"][0].offsets_in_context[1].start == 5
-    assert prediction["answers"][0].offsets_in_context[1].end == 6
-
 
 def test_table_reader_in_pipeline(table_reader):
     pipeline = Pipeline()
     pipeline.add_node(table_reader, "TableReader", ["Query"])
     data = {
         "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
-        "age": ["57", "46", "60"],
+        "age": ["58", "47", "60"],
         "number of movies": ["87", "53", "69"],
-        "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
+        "date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"],
     }
 
     table = pd.DataFrame(data)
-    query = "Which actors played in more than 60 movies?"
+    query = "When was Di Caprio born?"
 
     prediction = pipeline.run(query=query, documents=[Document(content=table, content_type="table")])
 
-    assert prediction["answers"][0].answer == "brad pitt, george clooney"
-    assert prediction["answers"][0].meta["aggregation_operator"] == "NONE"
-    assert prediction["answers"][0].offsets_in_context[0].start == 0
-    assert prediction["answers"][0].offsets_in_context[0].end == 1
-    assert prediction["answers"][0].offsets_in_context[1].start == 8
-    assert prediction["answers"][0].offsets_in_context[1].end == 9
+    assert prediction["answers"][0].answer == "11 november 1974"
+    assert prediction["answers"][0].offsets_in_context[0].start == 7
+    assert prediction["answers"][0].offsets_in_context[0].end == 8
 
 
+@pytest.mark.parametrize("table_reader", ["tapas"], indirect=True)
 def test_table_reader_aggregation(table_reader):
     data = {
         "Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],

@@ -72,4 +59,3 @@ def test_table_reader_aggregation(table_reader):
     assert prediction["answers"][0].answer == "43046.0 m"
     assert prediction["answers"][0].meta["aggregation_operator"] == "SUM"
     assert prediction["answers"][0].meta["answer_cells"] == ['8848m', '8,611 m', '8 586m', '8 516 m', '8,485m']
-