diff --git a/docs/_src/api/api/reader.md b/docs/_src/api/api/reader.md
index 1aa0c4883..a6d816255 100644
--- a/docs/_src/api/api/reader.md
+++ b/docs/_src/api/api/reader.md
@@ -609,3 +609,75 @@ WARNING: The answer scores are not reliable, as they are always extremely high,
Dict containing query and answers
+
+## RCIReader
+
+```python
+class RCIReader(BaseReader)
+```
+
+Table Reader model based on Glass et al. (2021)'s Row-Column-Intersection model.
+See the original paper for more details:
+Glass, Michael, et al. (2021): "Capturing Row and Column Semantics in Transformer Based Question Answering over Tables"
+(https://aclanthology.org/2021.naacl-main.96/)
+
+Each row and each column is given a score with regard to the query by two separate models. The score of each cell
+is then calculated as the sum of the corresponding row score and column score. Accordingly, the predicted answer is
+the cell with the highest score.
+
+
+#### \_\_init\_\_
+
+```python
+ | __init__(row_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-row", column_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-col", row_model_version: Optional[str] = None, column_model_version: Optional[str] = None, row_tokenizer: Optional[str] = None, column_tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, max_seq_len: int = 256)
+```
+
+Load an RCI model from Transformers.
+Available models include:
+
+- ``'michaelrglass/albert-base-rci-wikisql-row'`` + ``'michaelrglass/albert-base-rci-wikisql-col'``
+- ``'michaelrglass/albert-base-rci-wtq-row'`` + ``'michaelrglass/albert-base-rci-wtq-col'``
+
+
+
+**Arguments**:
+
+- `row_model_name_or_path`: Directory of a saved row scoring model or the name of a public model
+- `column_model_name_or_path`: Directory of a saved column scoring model or the name of a public model
+- `row_model_version`: The version of row model to use from the HuggingFace model hub.
+ Can be tag name, branch name, or commit hash.
+- `column_model_version`: The version of column model to use from the HuggingFace model hub.
+ Can be tag name, branch name, or commit hash.
+- `row_tokenizer`: Name of the tokenizer for the row model (usually the same as model)
+- `column_tokenizer`: Name of the tokenizer for the column model (usually the same as model)
+- `use_gpu`: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
+- `top_k`: The maximum number of answers to return
+- `max_seq_len`: Max sequence length of one input table for the model. If the number of tokens of
+ query + table exceed max_seq_len, the table will be truncated by removing rows until the
+ input size fits the model.
+
+
+#### predict
+
+```python
+ | predict(query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict
+```
+
+Use loaded RCI models to find answers for a query in the supplied list of Documents
+of content_type ``'table'``.
+
+Returns dictionary containing query and list of Answer objects sorted by (desc.) score.
+The existing RCI models on the HF model hub don"t allow aggregation, therefore, the answer will always be
+composed of a single cell.
+
+**Arguments**:
+
+- `query`: Query string
+- `documents`: List of Document in which to search for the answer. Documents should be
+ of content_type ``'table'``.
+- `top_k`: The maximum number of answers to return
+
+**Returns**:
+
+Dict containing query and answers
+
diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py
index 807728fe9..f9cd1008b 100644
--- a/haystack/nodes/__init__.py
+++ b/haystack/nodes/__init__.py
@@ -23,7 +23,7 @@ from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker
-from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader
+from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader, RCIReader
from haystack.nodes.retriever import (
BaseRetriever,
DensePassageRetriever,
diff --git a/haystack/nodes/reader/__init__.py b/haystack/nodes/reader/__init__.py
index e06975298..0d80527e0 100644
--- a/haystack/nodes/reader/__init__.py
+++ b/haystack/nodes/reader/__init__.py
@@ -1,4 +1,4 @@
from haystack.nodes.reader.base import BaseReader
from haystack.nodes.reader.farm import FARMReader
from haystack.nodes.reader.transformers import TransformersReader
-from haystack.nodes.reader.table import TableReader
+from haystack.nodes.reader.table import TableReader, RCIReader
diff --git a/haystack/nodes/reader/table.py b/haystack/nodes/reader/table.py
index b4453383a..711e256e1 100644
--- a/haystack/nodes/reader/table.py
+++ b/haystack/nodes/reader/table.py
@@ -6,7 +6,8 @@ import torch
import numpy as np
import pandas as pd
from quantulum3 import parser
-from transformers import TapasTokenizer, TapasForQuestionAnswering, BatchEncoding
+from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, \
+ BatchEncoding, AutoConfig
from haystack.schema import Document, Answer, Span
from haystack.nodes.reader.base import BaseReader
@@ -252,3 +253,209 @@ class TableReader(BaseReader):
def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
raise NotImplementedError("Batch prediction not yet available in TableReader.")
+
+
+class RCIReader(BaseReader):
+ """
+ Table Reader model based on Glass et al. (2021)'s Row-Column-Intersection model.
+ See the original paper for more details:
+ Glass, Michael, et al. (2021): "Capturing Row and Column Semantics in Transformer Based Question Answering over Tables"
+ (https://aclanthology.org/2021.naacl-main.96/)
+
+ Each row and each column is given a score with regard to the query by two separate models. The score of each cell
+ is then calculated as the sum of the corresponding row score and column score. Accordingly, the predicted answer is
+ the cell with the highest score.
+
+ Pros and Cons of RCIReader compared to TableReader:
+ + Provides meaningful confidence scores
+ + Allows larger tables as input
+ - Does not support aggregation over table cells
+ - Slower
+ """
+
+ def __init__(self,
+ row_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-row",
+ column_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-col",
+ row_model_version: Optional[str] = None,
+ column_model_version: Optional[str] = None,
+ row_tokenizer: Optional[str] = None,
+ column_tokenizer: Optional[str] = None,
+ use_gpu: bool = True,
+ top_k: int = 10,
+ max_seq_len: int = 256,
+ ):
+ """
+ Load an RCI model from Transformers.
+ Available models include:
+
+ - ``'michaelrglass/albert-base-rci-wikisql-row'`` + ``'michaelrglass/albert-base-rci-wikisql-col'``
+ - ``'michaelrglass/albert-base-rci-wtq-row'`` + ``'michaelrglass/albert-base-rci-wtq-col'``
+
+
+
+ :param row_model_name_or_path: Directory of a saved row scoring model or the name of a public model
+ :param column_model_name_or_path: Directory of a saved column scoring model or the name of a public model
+ :param row_model_version: The version of row model to use from the HuggingFace model hub.
+ Can be tag name, branch name, or commit hash.
+ :param column_model_version: The version of column model to use from the HuggingFace model hub.
+ Can be tag name, branch name, or commit hash.
+ :param row_tokenizer: Name of the tokenizer for the row model (usually the same as model)
+ :param column_tokenizer: Name of the tokenizer for the column model (usually the same as model)
+ :param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
+ :param top_k: The maximum number of answers to return
+ :param max_seq_len: Max sequence length of one input table for the model. If the number of tokens of
+ query + table exceed max_seq_len, the table will be truncated by removing rows until the
+ input size fits the model.
+ """
+ # Save init parameters to enable export of component config as YAML
+ self.set_config(row_model_name_or_path=row_model_name_or_path,
+ column_model_name_or_path=column_model_name_or_path, row_model_version=row_model_version,
+ column_model_version=column_model_version, row_tokenizer=row_tokenizer,
+ column_tokenizer=column_tokenizer, use_gpu=use_gpu, top_k=top_k, max_seq_len=max_seq_len)
+
+ self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
+ self.row_model = AutoModelForSequenceClassification.from_pretrained(row_model_name_or_path,
+ revision=row_model_version)
+ self.column_model = AutoModelForSequenceClassification.from_pretrained(row_model_name_or_path,
+ revision=column_model_version)
+ self.row_model.to(str(self.devices[0]))
+ self.column_model.to(str(self.devices[0]))
+
+ if row_tokenizer is None:
+ try:
+ self.row_tokenizer = AutoTokenizer.from_pretrained(row_model_name_or_path)
+ # The existing RCI models on the model hub don't come with tokenizer vocab files.
+ except TypeError:
+ self.row_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
+ else:
+ self.row_tokenizer = AutoTokenizer.from_pretrained(row_tokenizer)
+
+ if column_tokenizer is None:
+ try:
+ self.column_tokenizer = AutoTokenizer.from_pretrained(column_model_name_or_path)
+ # The existing RCI models on the model hub don't come with tokenizer vocab files.
+ except TypeError:
+ self.column_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
+ else:
+ self.column_tokenizer = AutoTokenizer.from_pretrained(column_tokenizer)
+
+ self.top_k = top_k
+ self.max_seq_len = max_seq_len
+ self.return_no_answers = False
+
+ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
+ """
+ Use loaded RCI models to find answers for a query in the supplied list of Documents
+ of content_type ``'table'``.
+
+ Returns dictionary containing query and list of Answer objects sorted by (desc.) score.
+ The existing RCI models on the HF model hub don"t allow aggregation, therefore, the answer will always be
+ composed of a single cell.
+
+ :param query: Query string
+ :param documents: List of Document in which to search for the answer. Documents should be
+ of content_type ``'table'``.
+ :param top_k: The maximum number of answers to return
+ :return: Dict containing query and answers
+ """
+ if top_k is None:
+ top_k = self.top_k
+
+ answers = []
+ for document in documents:
+ if document.content_type != "table":
+ logger.warning(f"Skipping document with id {document.id} in RCIReader, as it is not of type table.")
+ continue
+
+ table: pd.DataFrame = document.content
+ table = table.astype(str)
+ # Create row and column representations
+ row_reps, column_reps = self._create_row_column_representations(table)
+
+ # Get row logits
+ row_inputs = self.row_tokenizer.batch_encode_plus(
+ batch_text_or_text_pairs=[(query, row_rep) for row_rep in row_reps],
+ max_length=self.max_seq_len,
+ return_tensors="pt",
+ add_special_tokens=True,
+ truncation=True,
+ padding=True
+ )
+ row_inputs.to(self.devices[0])
+ row_logits = self.row_model(**row_inputs)[0].detach().cpu().numpy()[:, 1]
+
+ # Get column logits
+ column_inputs = self.column_tokenizer.batch_encode_plus(
+ batch_text_or_text_pairs=[(query, column_rep) for column_rep in column_reps],
+ max_length=self.max_seq_len,
+ return_tensors="pt",
+ add_special_tokens=True,
+ truncation=True,
+ padding=True
+ )
+ column_inputs.to(self.devices[0])
+ column_logits = self.column_model(**column_inputs)[0].detach().cpu().numpy()[:, 1]
+
+ # Calculate cell scores
+ current_answers: List[Answer] = []
+ cell_scores_table: List[List[float]] = []
+ for row_idx, row_score in enumerate(row_logits):
+ cell_scores_table.append([])
+ for col_idx, col_score in enumerate(column_logits):
+ current_cell_score = float(row_score + col_score)
+ cell_scores_table[-1].append(current_cell_score)
+
+ answer_str = table.iloc[row_idx, col_idx]
+ answer_offsets = self._calculate_answer_offsets(row_idx, col_idx, table)
+ current_answers.append(
+ Answer(
+ answer=answer_str,
+ type="extractive",
+ score=current_cell_score,
+ context=table,
+ offsets_in_document=[answer_offsets],
+ offsets_in_context=[answer_offsets],
+ document_id=document.id,
+ )
+ )
+
+ # Add cell scores to Answers' meta to be able to use as heatmap
+ for answer in current_answers:
+ answer.meta = {"table_scores": cell_scores_table}
+ answers.extend(current_answers)
+
+ # Sort answers by score and select top-k answers
+ answers = sorted(answers, reverse=True)
+ answers = answers[:top_k]
+
+ results = {"query": query,
+ "answers": answers}
+
+ return results
+
+ @staticmethod
+ def _create_row_column_representations(table: pd.DataFrame) -> Tuple[List[str], List[str]]:
+ row_reps = []
+ column_reps = []
+ columns = table.columns
+
+ for idx, row in table.iterrows():
+ current_row_rep = " * ".join([header + " : " + cell for header, cell in zip(columns, row)])
+ row_reps.append(current_row_rep)
+
+ for col_name in columns:
+ current_column_rep = f"{col_name} * "
+ current_column_rep += " * ".join(table[col_name])
+ column_reps.append(current_column_rep)
+
+ return row_reps, column_reps
+
+ @staticmethod
+ def _calculate_answer_offsets(row_idx, column_index, table) -> Span:
+ n_rows, n_columns = table.shape
+ answer_cell_offset = (row_idx * n_columns) + column_index
+
+ return Span(start=answer_cell_offset, end=answer_cell_offset + 1)
+
+ def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
+ raise NotImplementedError("Batch prediction not yet available in RCIReader.")
diff --git a/test/conftest.py b/test/conftest.py
index 117f5063f..eccce83b7 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -34,7 +34,7 @@ from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.document_stores.sql import SQLDocumentStore
from haystack.nodes.reader.farm import FARMReader
from haystack.nodes.reader.transformers import TransformersReader
-from haystack.nodes.reader.table import TableReader
+from haystack.nodes.reader.table import TableReader, RCIReader
from haystack.nodes.summarizer.transformers import TransformersSummarizer
from haystack.nodes.translator import TransformersTranslator
from haystack.nodes.question_generator import QuestionGenerator
@@ -338,9 +338,13 @@ def reader(request):
)
-@pytest.fixture(scope="function")
-def table_reader():
- return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")
+@pytest.fixture(params=["tapas", "rci"], scope="function")
+def table_reader(request):
+ if request.param == "tapas":
+ return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")
+ elif request.param == "rci":
+ return RCIReader(row_model_name_or_path="michaelrglass/albert-base-rci-wikisql-row",
+ column_model_name_or_path="michaelrglass/albert-base-rci-wikisql-col")
@pytest.fixture(scope="function")
diff --git a/test/test_table_reader.py b/test/test_table_reader.py
index 6d5f52c79..a99865d13 100644
--- a/test/test_table_reader.py
+++ b/test/test_table_reader.py
@@ -1,4 +1,5 @@
import pandas as pd
+import pytest
from haystack.schema import Document
from haystack.pipelines.base import Pipeline
@@ -7,53 +8,39 @@ from haystack.pipelines.base import Pipeline
def test_table_reader(table_reader):
data = {
"actors": ["brad pitt", "leonardo di caprio", "george clooney"],
- "age": ["57", "46", "60"],
+ "age": ["58", "47", "60"],
"number of movies": ["87", "53", "69"],
- "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
+ "date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"],
}
table = pd.DataFrame(data)
- query = "When was DiCaprio born?"
+ query = "When was Di Caprio born?"
prediction = table_reader.predict(query=query, documents=[Document(content=table, content_type="table")])
- assert prediction["answers"][0].answer == "10 june 1996"
+ assert prediction["answers"][0].answer == "11 november 1974"
assert prediction["answers"][0].offsets_in_context[0].start == 7
assert prediction["answers"][0].offsets_in_context[0].end == 8
- # test aggregation
- query = "How old are DiCaprio and Pitt on average?"
- prediction = table_reader.predict(query=query, documents=[Document(content=table, content_type="table")])
- assert prediction["answers"][0].answer == "51.5"
- assert prediction["answers"][0].meta["answer_cells"] == ["57", "46"]
- assert prediction["answers"][0].meta["aggregation_operator"] == "AVERAGE"
- assert prediction["answers"][0].offsets_in_context[0].start == 1
- assert prediction["answers"][0].offsets_in_context[0].end == 2
- assert prediction["answers"][0].offsets_in_context[1].start == 5
- assert prediction["answers"][0].offsets_in_context[1].end == 6
-
def test_table_reader_in_pipeline(table_reader):
pipeline = Pipeline()
pipeline.add_node(table_reader, "TableReader", ["Query"])
data = {
"actors": ["brad pitt", "leonardo di caprio", "george clooney"],
- "age": ["57", "46", "60"],
+ "age": ["58", "47", "60"],
"number of movies": ["87", "53", "69"],
- "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
+ "date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"],
}
table = pd.DataFrame(data)
- query = "Which actors played in more than 60 movies?"
+ query = "When was Di Caprio born?"
prediction = pipeline.run(query=query, documents=[Document(content=table, content_type="table")])
- assert prediction["answers"][0].answer == "brad pitt, george clooney"
- assert prediction["answers"][0].meta["aggregation_operator"] == "NONE"
- assert prediction["answers"][0].offsets_in_context[0].start == 0
- assert prediction["answers"][0].offsets_in_context[0].end == 1
- assert prediction["answers"][0].offsets_in_context[1].start == 8
- assert prediction["answers"][0].offsets_in_context[1].end == 9
-
+ assert prediction["answers"][0].answer == "11 november 1974"
+ assert prediction["answers"][0].offsets_in_context[0].start == 7
+ assert prediction["answers"][0].offsets_in_context[0].end == 8
+@pytest.mark.parametrize("table_reader", ["tapas"], indirect=True)
def test_table_reader_aggregation(table_reader):
data = {
"Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
@@ -72,4 +59,3 @@ def test_table_reader_aggregation(table_reader):
assert prediction["answers"][0].answer == "43046.0 m"
assert prediction["answers"][0].meta["aggregation_operator"] == "SUM"
assert prediction["answers"][0].meta["answer_cells"] == ['8848m', '8,611 m', '8 586m', '8 516 m', '8,485m']
-